diff --git a/telegram/ext/basepersistence.py b/telegram/ext/basepersistence.py index f62491ee903..0a3a5e11cbd 100644 --- a/telegram/ext/basepersistence.py +++ b/telegram/ext/basepersistence.py @@ -33,6 +33,12 @@ class BasePersistence(object): must overwrite :meth:`get_conversations` and :meth:`update_conversation`. * :meth:`flush` will be called when the bot is shutdown. + Note: + It may be benifitial to check if data has changed, before persisting it. Therefore + :meth:`get_chat_data`, :meth:`get_user_data` and :meth:`get_conversations` should *not* + return the data stored in the instance of your persistence class but rather a deep copy + of it. + Attributes: store_user_data (:obj:`bool`): Optional, Whether user_data should be saved by this persistence class. diff --git a/telegram/ext/dictpersistence.py b/telegram/ext/dictpersistence.py index 28ca4727331..1b9fcc22fa8 100644 --- a/telegram/ext/dictpersistence.py +++ b/telegram/ext/dictpersistence.py @@ -25,6 +25,7 @@ except ImportError: import json from collections import defaultdict +from copy import deepcopy from telegram.ext import BasePersistence @@ -36,11 +37,13 @@ class DictPersistence(BasePersistence): persistence class. store_chat_data (:obj:`bool`): Whether chat_data should be saved by this persistence class. + on_update (:obj:`bool`): Optional. When ``True`` will only save to file, if data has + changed. When ``False`` will save to file on every update. Default is ``False``. Args: store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this persistence class. Default is ``True``. - store_chat_data (:obj:`bool`, optional): Whether user_data should be saved by this + store_chat_data (:obj:`bool`, optional): Whether chat_data should be saved by this persistence class. Default is ``True``. user_data_json (:obj:`str`, optional): Json string that will be used to reconstruct user_data on creating this persistence. Default is ``""``. @@ -48,10 +51,12 @@ class DictPersistence(BasePersistence): chat_data on creating this persistence. Default is ``""``. conversations_json (:obj:`str`, optional): Json string that will be used to reconstruct conversation on creating this persistence. Default is ``""``. + on_update (:obj:`bool`, optional): When ``True`` will only save to file, if data has + changed. When ``False`` will save to file on every update. Default is ``False``. """ def __init__(self, store_user_data=True, store_chat_data=True, user_data_json='', - chat_data_json='', conversations_json=''): + chat_data_json='', conversations_json='', on_update=False): self.store_user_data = store_user_data self.store_chat_data = store_chat_data self._user_data = None @@ -60,6 +65,7 @@ def __init__(self, store_user_data=True, store_chat_data=True, user_data_json='' self._user_data_json = None self._chat_data_json = None self._conversations_json = None + self.on_update = on_update if user_data_json: try: self._user_data = decode_user_chat_data_from_json(user_data_json) @@ -129,7 +135,10 @@ def get_user_data(self): pass else: self._user_data = defaultdict(dict) - return self.user_data.copy() + if self.on_update: + return deepcopy(self.user_data) + else: + return self.user_data def get_chat_data(self): """Returns the chat_data created from the ``chat_data_json`` or an empty defaultdict. @@ -141,7 +150,10 @@ def get_chat_data(self): pass else: self._chat_data = defaultdict(dict) - return self.chat_data.copy() + if self.on_update: + return deepcopy(self.chat_data) + else: + return self.chat_data def get_conversations(self, name): """Returns the conversations created from the ``conversations_json`` or an empty @@ -154,7 +166,10 @@ def get_conversations(self, name): pass else: self._conversations = {} - return self.conversations.get(name, {}).copy() + if self.on_update: + return deepcopy(self.conversations.get(name, {})) + else: + return self.conversations.get(name, {}) def update_conversation(self, name, key, new_state): """Will update the conversations for the given handler. @@ -164,9 +179,12 @@ def update_conversation(self, name, key, new_state): key (:obj:`tuple`): The key the state is changed for. new_state (:obj:`tuple` | :obj:`any`): The new state for the given key. """ - if self._conversations.setdefault(name, {}).get(key) == new_state: - return - self._conversations[name][key] = new_state + if self.on_update: + if self._conversations.setdefault(name, {}).get(key) == new_state: + return + self._conversations[name][key] = deepcopy(new_state) + else: + self._conversations[name][key] = new_state self._conversations_json = None def update_user_data(self, user_id, data): @@ -176,8 +194,12 @@ def update_user_data(self, user_id, data): user_id (:obj:`int`): The user the data might have been changed for. data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id]. """ - if self._user_data.get(user_id) == data: - return + if self.on_update: + if self._user_data.get(user_id) == data: + return + self._user_data[user_id] = deepcopy(data) + else: + self._user_data[user_id] = data self._user_data[user_id] = data self._user_data_json = None @@ -188,7 +210,11 @@ def update_chat_data(self, chat_id, data): chat_id (:obj:`int`): The chat the data might have been changed for. data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[chat_id]. """ - if self._chat_data.get(chat_id) == data: - return + if self.on_update: + if self._chat_data.get(chat_id) == data: + return + self._chat_data[chat_id] = deepcopy(data) + else: + self._chat_data[chat_id] = data self._chat_data[chat_id] = data self._chat_data_json = None diff --git a/telegram/ext/picklepersistence.py b/telegram/ext/picklepersistence.py index ed3c06cf900..384fb195b17 100644 --- a/telegram/ext/picklepersistence.py +++ b/telegram/ext/picklepersistence.py @@ -19,6 +19,7 @@ """This module contains the PicklePersistence class.""" import pickle from collections import defaultdict +from copy import deepcopy from telegram.ext import BasePersistence @@ -31,7 +32,7 @@ class PicklePersistence(BasePersistence): is false this will be used as a prefix. store_user_data (:obj:`bool`): Optional. Whether user_data should be saved by this persistence class. - store_chat_data (:obj:`bool`): Optional. Whether user_data should be saved by this + store_chat_data (:obj:`bool`): Optional. Whether chat_data should be saved by this persistence class. single_file (:obj:`bool`): Optional. When ``False`` will store 3 sperate files of `filename_user_data`, `filename_chat_data` and `filename_conversations`. Default is @@ -39,6 +40,8 @@ class PicklePersistence(BasePersistence): on_flush (:obj:`bool`): Optional. When ``True`` will only save to file when :meth:`flush` is called and keep data in memory until that happens. When False will store data on any transaction. Default is ``False``. + on_update (:obj:`bool`): Optional. When ``True`` will only save to file, if data has + changed. When ``False`` will save to file on every update. Default is ``False``. Args: filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file` @@ -53,15 +56,18 @@ class PicklePersistence(BasePersistence): on_flush (:obj:`bool`, optional): When ``True`` will only save to file when :meth:`flush` is called and keep data in memory until that happens. When False will store data on any transaction. Default is ``False``. + on_update (:obj:`bool`, optional): When ``True`` will only save to file, if data has + changed. When ``False`` will save to file on every update. Default is ``False``. """ def __init__(self, filename, store_user_data=True, store_chat_data=True, singe_file=True, - on_flush=False): + on_flush=False, on_update=False): self.filename = filename self.store_user_data = store_user_data self.store_chat_data = store_chat_data self.single_file = singe_file self.on_flush = on_flush + self.on_update = on_update self.user_data = None self.chat_data = None self.conversations = None @@ -122,7 +128,10 @@ def get_user_data(self): self.user_data = data else: self.load_singlefile() - return self.user_data.copy() + if self.on_update: + return deepcopy(self.user_data) + else: + return self.user_data def get_chat_data(self): """Returns the chat_data from the pickle file if it exsists or an empty defaultdict. @@ -142,7 +151,10 @@ def get_chat_data(self): self.chat_data = data else: self.load_singlefile() - return self.chat_data.copy() + if self.on_update: + return deepcopy(self.chat_data) + else: + return self.chat_data def get_conversations(self, name): """Returns the conversations from the pickle file if it exsists or an empty defaultdict. @@ -163,7 +175,10 @@ def get_conversations(self, name): self.conversations = data else: self.load_singlefile() - return self.conversations.get(name, {}).copy() + if self.on_update: + return deepcopy(self.conversations.get(name, {})) + else: + return self.conversations.get(name, {}) def update_conversation(self, name, key, new_state): """Will update the conversations for the given handler and depending on :attr:`on_flush` @@ -174,9 +189,12 @@ def update_conversation(self, name, key, new_state): key (:obj:`tuple`): The key the state is changed for. new_state (:obj:`tuple` | :obj:`any`): The new state for the given key. """ - if self.conversations.setdefault(name, {}).get(key) == new_state: - return - self.conversations[name][key] = new_state + if self.on_update: + if self.conversations.setdefault(name, {}).get(key) == new_state: + return + self.conversations[name][key] = deepcopy(new_state) + else: + self.conversations[name][key] = new_state if not self.on_flush: if not self.single_file: filename = "{}_conversations".format(self.filename) @@ -192,9 +210,12 @@ def update_user_data(self, user_id, data): user_id (:obj:`int`): The user the data might have been changed for. data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id]. """ - if self.user_data.get(user_id) == data: - return - self.user_data[user_id] = data + if self.on_update: + if self.user_data.get(user_id) == data: + return + self.user_data[user_id] = deepcopy(data) + else: + self.user_data[user_id] = data if not self.on_flush: if not self.single_file: filename = "{}_user_data".format(self.filename) @@ -210,9 +231,12 @@ def update_chat_data(self, chat_id, data): chat_id (:obj:`int`): The chat the data might have been changed for. data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[chat_id]. """ - if self.chat_data.get(chat_id) == data: - return - self.chat_data[chat_id] = data + if self.on_update: + if self.chat_data.get(chat_id) == data: + return + self.chat_data[chat_id] = deepcopy(data) + else: + self.chat_data[chat_id] = data if not self.on_flush: if not self.single_file: filename = "{}_chat_data".format(self.filename) diff --git a/tests/test_persistence.py b/tests/test_persistence.py index c43c160580b..3a20c60ab16 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -223,7 +223,8 @@ def pickle_persistence(): store_user_data=True, store_chat_data=True, singe_file=False, - on_flush=False) + on_flush=False, + on_update=True) @pytest.fixture(scope='function') @@ -598,7 +599,7 @@ def conversations_json(conversations): class TestDictPersistence(object): def test_no_json_given(self): - dict_persistence = DictPersistence() + dict_persistence = DictPersistence(on_update=True) assert dict_persistence.get_user_data() == defaultdict(dict) assert dict_persistence.get_chat_data() == defaultdict(dict) assert dict_persistence.get_conversations('noname') == {} @@ -608,27 +609,28 @@ def test_bad_json_string_given(self): bad_chat_data = 'thisisnojson99900()))(' bad_conversations = 'thisisnojson99900()))(' with pytest.raises(TypeError, match='user_data'): - DictPersistence(user_data_json=bad_user_data) + DictPersistence(user_data_json=bad_user_data, on_update=True) with pytest.raises(TypeError, match='chat_data'): - DictPersistence(chat_data_json=bad_chat_data) + DictPersistence(chat_data_json=bad_chat_data, on_update=True) with pytest.raises(TypeError, match='conversations'): - DictPersistence(conversations_json=bad_conversations) + DictPersistence(conversations_json=bad_conversations, on_update=True) def test_invalid_json_string_given(self, pickle_persistence, bad_pickle_files): bad_user_data = '["this", "is", "json"]' bad_chat_data = '["this", "is", "json"]' bad_conversations = '["this", "is", "json"]' with pytest.raises(TypeError, match='user_data'): - DictPersistence(user_data_json=bad_user_data) + DictPersistence(user_data_json=bad_user_data, on_update=True) with pytest.raises(TypeError, match='chat_data'): - DictPersistence(chat_data_json=bad_chat_data) + DictPersistence(chat_data_json=bad_chat_data, on_update=True) with pytest.raises(TypeError, match='conversations'): - DictPersistence(conversations_json=bad_conversations) + DictPersistence(conversations_json=bad_conversations, on_update=True) def test_good_json_input(self, user_data_json, chat_data_json, conversations_json): dict_persistence = DictPersistence(user_data_json=user_data_json, chat_data_json=chat_data_json, - conversations_json=conversations_json) + conversations_json=conversations_json, + on_update=True) user_data = dict_persistence.get_user_data() assert isinstance(user_data, defaultdict) assert user_data[12345]['test1'] == 'test2' @@ -658,7 +660,8 @@ def test_dict_outputs(self, user_data, user_data_json, chat_data, chat_data_json conversations, conversations_json): dict_persistence = DictPersistence(user_data_json=user_data_json, chat_data_json=chat_data_json, - conversations_json=conversations_json) + conversations_json=conversations_json, + on_update=True) assert dict_persistence.user_data == user_data assert dict_persistence.chat_data == chat_data assert dict_persistence.conversations == conversations @@ -666,7 +669,8 @@ def test_dict_outputs(self, user_data, user_data_json, chat_data, chat_data_json def test_json_outputs(self, user_data_json, chat_data_json, conversations_json): dict_persistence = DictPersistence(user_data_json=user_data_json, chat_data_json=chat_data_json, - conversations_json=conversations_json) + conversations_json=conversations_json, + on_update=True) assert dict_persistence.user_data_json == user_data_json assert dict_persistence.chat_data_json == chat_data_json assert dict_persistence.conversations_json == conversations_json @@ -675,7 +679,8 @@ def test_json_changes(self, user_data, user_data_json, chat_data, chat_data_json conversations, conversations_json): dict_persistence = DictPersistence(user_data_json=user_data_json, chat_data_json=chat_data_json, - conversations_json=conversations_json) + conversations_json=conversations_json, + on_update=True) user_data_two = user_data.copy() user_data_two.update({4: {5: 6}}) dict_persistence.update_user_data(4, {5: 6}) @@ -699,7 +704,7 @@ def test_json_changes(self, user_data, user_data_json, chat_data, chat_data_json conversations_two) def test_with_handler(self, bot, update): - dict_persistence = DictPersistence() + dict_persistence = DictPersistence(on_update=True) u = Updater(bot=bot, persistence=dict_persistence) dp = u.dispatcher @@ -727,7 +732,8 @@ def second(bot, update, user_data, chat_data): chat_data = dict_persistence.chat_data_json del (dict_persistence) dict_persistence_2 = DictPersistence(user_data_json=user_data, - chat_data_json=chat_data) + chat_data_json=chat_data, + on_update=True) u = Updater(bot=bot, persistence=dict_persistence_2) dp = u.dispatcher @@ -735,7 +741,8 @@ def second(bot, update, user_data, chat_data): dp.process_update(update) def test_with_conversationHandler(self, dp, update, conversations_json): - dict_persistence = DictPersistence(conversations_json=conversations_json) + dict_persistence = DictPersistence(conversations_json=conversations_json, + on_update=True) dp.persistence = dict_persistence NEXT, NEXT2 = range(2)