diff --git a/docs/source/telegram.ext.chatadminsrole.rst b/docs/source/telegram.ext.chatadminsrole.rst new file mode 100644 index 00000000000..799d130f094 --- /dev/null +++ b/docs/source/telegram.ext.chatadminsrole.rst @@ -0,0 +1,6 @@ +telegram.ext.ChatAdminsRole +=========================== + +.. autoclass:: telegram.ext.ChatAdminsRole + :members: + :show-inheritance: diff --git a/docs/source/telegram.ext.chatcreatorrole.rst b/docs/source/telegram.ext.chatcreatorrole.rst new file mode 100644 index 00000000000..1a8b6d8ae6b --- /dev/null +++ b/docs/source/telegram.ext.chatcreatorrole.rst @@ -0,0 +1,6 @@ +telegram.ext.ChatCreatorRole +============================ + +.. autoclass:: telegram.ext.ChatCreatorRole + :members: + :show-inheritance: diff --git a/docs/source/telegram.ext.role.rst b/docs/source/telegram.ext.role.rst new file mode 100644 index 00000000000..46cad7c0967 --- /dev/null +++ b/docs/source/telegram.ext.role.rst @@ -0,0 +1,6 @@ +telegram.ext.Role +================= + +.. autoclass:: telegram.ext.Role + :members: + :show-inheritance: diff --git a/docs/source/telegram.ext.roles.rst b/docs/source/telegram.ext.roles.rst new file mode 100644 index 00000000000..9e9a38c9532 --- /dev/null +++ b/docs/source/telegram.ext.roles.rst @@ -0,0 +1,6 @@ +telegram.ext.Roles +================== + +.. autoclass:: telegram.ext.Roles + :members: + :show-inheritance: diff --git a/docs/source/telegram.ext.rst b/docs/source/telegram.ext.rst index d5148bd6122..fd58c346741 100644 --- a/docs/source/telegram.ext.rst +++ b/docs/source/telegram.ext.rst @@ -13,6 +13,8 @@ telegram.ext package telegram.ext.delayqueue telegram.ext.callbackcontext telegram.ext.defaults + telegram.ext.role + telegram.ext.roles Handlers -------- @@ -43,4 +45,14 @@ Persistence telegram.ext.basepersistence telegram.ext.picklepersistence - telegram.ext.dictpersistence \ No newline at end of file + telegram.ext.dictpersistence + +Authentication +-------------- + +.. toctree:: + + telegram.ext.role + telegram.ext.roles + telegram.ext.chatadminsrole + telegram.ext.chatcreatorrole \ No newline at end of file diff --git a/telegram/ext/__init__.py b/telegram/ext/__init__.py index e77b5567334..9683760e068 100644 --- a/telegram/ext/__init__.py +++ b/telegram/ext/__init__.py @@ -44,6 +44,7 @@ from .pollanswerhandler import PollAnswerHandler from .pollhandler import PollHandler from .defaults import Defaults +from .roles import Role, Roles, ChatAdminsRole, ChatCreatorRole __all__ = ('Dispatcher', 'JobQueue', 'Job', 'Updater', 'CallbackQueryHandler', 'ChosenInlineResultHandler', 'CommandHandler', 'Handler', 'InlineQueryHandler', @@ -52,4 +53,4 @@ 'PreCheckoutQueryHandler', 'ShippingQueryHandler', 'MessageQueue', 'DelayQueue', 'DispatcherHandlerStop', 'run_async', 'CallbackContext', 'BasePersistence', 'PicklePersistence', 'DictPersistence', 'PrefixHandler', 'PollAnswerHandler', - 'PollHandler', 'Defaults') + 'PollHandler', 'Defaults', 'Role', 'Roles', 'ChatAdminsRole', 'ChatCreatorRole') diff --git a/telegram/ext/basepersistence.py b/telegram/ext/basepersistence.py index 23c42453b68..87cb875caec 100644 --- a/telegram/ext/basepersistence.py +++ b/telegram/ext/basepersistence.py @@ -31,6 +31,8 @@ class BasePersistence(object): :meth:`update_chat_data`. * If :attr:`store_user_data` is ``True`` you must overwrite :meth:`get_user_data` and :meth:`update_user_data`. + * If :attr:`store_roles` is ``True`` you must overwrite :meth:`get_roles` and + :meth:`update_roles`. * If you want to store conversation data with :class:`telegram.ext.ConversationHandler`, you must overwrite :meth:`get_conversations` and :meth:`update_conversation`. * :meth:`flush` will be called when the bot is shutdown. @@ -42,6 +44,8 @@ class BasePersistence(object): persistence class. store_bot_data (:obj:`bool`): Optional. Whether bot_data should be saved by this persistence class. + store_roles (:obj:`bool`): Optional. Whether roles should be saved by this persistence + class. Args: store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this @@ -50,12 +54,16 @@ class BasePersistence(object): persistence class. Default is ``True`` . store_bot_data (:obj:`bool`, optional): Whether bot_data should be saved by this persistence class. Default is ``True`` . + store_roles (:obj:`bool`, optional): Whether roles should be saved by this persistence + class. Default is ``True``. """ - def __init__(self, store_user_data=True, store_chat_data=True, store_bot_data=True): + def __init__(self, store_user_data=True, store_chat_data=True, store_bot_data=True, + store_roles=True): self.store_user_data = store_user_data self.store_chat_data = store_chat_data self.store_bot_data = store_bot_data + self.store_roles = store_roles def get_user_data(self): """"Will be called by :class:`telegram.ext.Dispatcher` upon creation with a @@ -87,6 +95,19 @@ def get_bot_data(self): """ raise NotImplementedError + def get_roles(self): + """"Will be called by :class:`telegram.ext.Dispatcher` upon creation with a + persistence object. + + Warning: + The produced roles instance usually will have no bot assigned. Use + :attr:`telegram.ext.Roles.set_bot` to set it. + + Returns: + :class:`telegram.ext.Roles`: The restored roles. + """ + raise NotImplementedError + def get_conversations(self, name): """"Will be called by :class:`telegram.ext.Dispatcher` when a :class:`telegram.ext.ConversationHandler` is added if @@ -141,6 +162,15 @@ def update_bot_data(self, data): """ raise NotImplementedError + def update_roles(self, data): + """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has + handled an update. + + Args: + data (:class:`telegram.ext.Roles`): The :attr:`telegram.ext.dispatcher.roles` . + """ + raise NotImplementedError + def flush(self): """Will be called by :class:`telegram.ext.Updater` upon receiving a stop signal. Gives the persistence a chance to finish up saving or close a database connection gracefully. If this diff --git a/telegram/ext/callbackcontext.py b/telegram/ext/callbackcontext.py index 95fcdb0c3fa..b4ff05b07ea 100644 --- a/telegram/ext/callbackcontext.py +++ b/telegram/ext/callbackcontext.py @@ -82,6 +82,7 @@ def __init__(self, dispatcher): raise ValueError('CallbackContext should not be used with a non context aware ' 'dispatcher!') self._dispatcher = dispatcher + self._roles = dispatcher.roles self._bot_data = dispatcher.bot_data self._chat_data = None self._user_data = None @@ -95,6 +96,14 @@ def dispatcher(self): """:class:`telegram.ext.Dispatcher`: The dispatcher associated with this context.""" return self._dispatcher + @property + def roles(self): + return self._roles + + @roles.setter + def roles(self, value): + raise AttributeError("You can not assign a new value to roles.") + @property def bot_data(self): return self._bot_data diff --git a/telegram/ext/callbackqueryhandler.py b/telegram/ext/callbackqueryhandler.py index 69f926bc08e..a5cb0003466 100644 --- a/telegram/ext/callbackqueryhandler.py +++ b/telegram/ext/callbackqueryhandler.py @@ -33,6 +33,8 @@ class CallbackQueryHandler(Handler): Attributes: callback (:obj:`callable`): The callback function for this handler. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. pass_update_queue (:obj:`bool`): Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Determines whether ``job_queue`` will be passed to @@ -66,6 +68,9 @@ class CallbackQueryHandler(Handler): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -104,13 +109,15 @@ def __init__(self, pass_groups=False, pass_groupdict=False, pass_user_data=False, - pass_chat_data=False): + pass_chat_data=False, + roles=None): super(CallbackQueryHandler, self).__init__( callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, - pass_chat_data=pass_chat_data) + pass_chat_data=pass_chat_data, + roles=roles) if isinstance(pattern, string_types): pattern = re.compile(pattern) diff --git a/telegram/ext/choseninlineresulthandler.py b/telegram/ext/choseninlineresulthandler.py index 349fd620709..9ead04ff8d4 100644 --- a/telegram/ext/choseninlineresulthandler.py +++ b/telegram/ext/choseninlineresulthandler.py @@ -27,6 +27,8 @@ class ChosenInlineResultHandler(Handler): Attributes: callback (:obj:`callable`): The callback function for this handler. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. pass_update_queue (:obj:`bool`): Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Determines whether ``job_queue`` will be passed to @@ -54,6 +56,9 @@ class ChosenInlineResultHandler(Handler): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` diff --git a/telegram/ext/commandhandler.py b/telegram/ext/commandhandler.py index 1387c18d662..34065f56eb3 100644 --- a/telegram/ext/commandhandler.py +++ b/telegram/ext/commandhandler.py @@ -47,6 +47,8 @@ class CommandHandler(Handler): callback (:obj:`callable`): The callback function for this handler. filters (:class:`telegram.ext.BaseFilter`): Optional. Only allow updates with these Filters. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. allow_edited (:obj:`bool`): Determines Whether the handler should also accept edited messages. pass_args (:obj:`bool`): Determines whether the handler should be passed @@ -85,6 +87,9 @@ class CommandHandler(Handler): :class:`telegram.ext.filters.BaseFilter`. Standard filters can be found in :class:`telegram.ext.filters.Filters`. Filters can be combined using bitwise operators (& for and, | for or, ~ for not). + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). allow_edited (:obj:`bool`, optional): Determines whether the handler should also accept edited messages. Default is ``False``. DEPRECATED: Edited is allowed by default. To change this behavior use @@ -124,13 +129,15 @@ def __init__(self, pass_update_queue=False, pass_job_queue=False, pass_user_data=False, - pass_chat_data=False): + pass_chat_data=False, + roles=None): super(CommandHandler, self).__init__( callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, - pass_chat_data=pass_chat_data) + pass_chat_data=pass_chat_data, + roles=roles) if isinstance(command, string_types): self.command = [command.lower()] @@ -231,6 +238,8 @@ class PrefixHandler(CommandHandler): callback (:obj:`callable`): The callback function for this handler. filters (:class:`telegram.ext.BaseFilter`): Optional. Only allow updates with these Filters. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. pass_args (:obj:`bool`): Determines whether the handler should be passed ``args``. pass_update_queue (:obj:`bool`): Determines whether ``update_queue`` will be @@ -267,6 +276,9 @@ class PrefixHandler(CommandHandler): :class:`telegram.ext.filters.BaseFilter`. Standard filters can be found in :class:`telegram.ext.filters.Filters`. Filters can be combined using bitwise operators (& for and, | for or, ~ for not). + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). pass_args (:obj:`bool`, optional): Determines whether the handler should be passed the arguments passed to the command as a keyword argument called ``args``. It will contain a list of strings, which is the text following the command split on single or @@ -300,7 +312,8 @@ def __init__(self, pass_update_queue=False, pass_job_queue=False, pass_user_data=False, - pass_chat_data=False): + pass_chat_data=False, + roles=None): self._prefix = list() self._command = list() @@ -311,7 +324,8 @@ def __init__(self, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, - pass_chat_data=pass_chat_data) + pass_chat_data=pass_chat_data, + roles=roles) self.prefix = prefix self.command = command diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py index d6114646201..b8e2bb3c1ec 100644 --- a/telegram/ext/conversationhandler.py +++ b/telegram/ext/conversationhandler.py @@ -107,6 +107,8 @@ class ConversationHandler(Handler): map_to_parent (Dict[:obj:`object`, :obj:`object`]): Optional. A :obj:`dict` that can be used to instruct a nested conversationhandler to transition into a mapped state on its parent conversationhandler in place of a specified nested state. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. Args: entry_points (List[:class:`telegram.ext.Handler`]): A list of ``Handler`` objects that can @@ -142,6 +144,9 @@ class ConversationHandler(Handler): map_to_parent (Dict[:obj:`object`, :obj:`object`], optional): A :obj:`dict` that can be used to instruct a nested conversationhandler to transition into a mapped state on its parent conversationhandler in place of a specified nested state. + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). Raises: ValueError @@ -166,7 +171,8 @@ def __init__(self, conversation_timeout=None, name=None, persistent=False, - map_to_parent=None): + map_to_parent=None, + roles=None): self._entry_points = entry_points self._states = states @@ -185,6 +191,7 @@ def __init__(self, """:obj:`telegram.ext.BasePersistance`: The persistence used to store conversations. Set by dispatcher""" self._map_to_parent = map_to_parent + self.roles = roles self.timeout_jobs = dict() self._timeout_jobs_lock = Lock() diff --git a/telegram/ext/dictpersistence.py b/telegram/ext/dictpersistence.py index 42a2eca18fa..36f48d42092 100644 --- a/telegram/ext/dictpersistence.py +++ b/telegram/ext/dictpersistence.py @@ -28,6 +28,7 @@ import json from collections import defaultdict from telegram.ext import BasePersistence +from .roles import Roles class DictPersistence(BasePersistence): @@ -40,6 +41,8 @@ class DictPersistence(BasePersistence): persistence class. store_bot_data (:obj:`bool`): Whether bot_data should be saved by this persistence class. + store_roles (:obj:`bool`): Optional. Whether roles should be saved by this persistence + class. Args: store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this @@ -48,6 +51,8 @@ class DictPersistence(BasePersistence): persistence class. Default is ``True``. store_bot_data (:obj:`bool`, optional): Whether bot_data should be saved by this persistence class. Default is ``True`` . + store_roles (:obj:`bool`, optional): Whether roles should be saved by this persistence + class. Default is ``True``. user_data_json (:obj:`str`, optional): Json string that will be used to reconstruct user_data on creating this persistence. Default is ``""``. chat_data_json (:obj:`str`, optional): Json string that will be used to reconstruct @@ -56,6 +61,8 @@ class DictPersistence(BasePersistence): bot_data on creating this persistence. Default is ``""``. conversations_json (:obj:`str`, optional): Json string that will be used to reconstruct conversation on creating this persistence. Default is ``""``. + roles_json (:obj:`str`, optional): Json string that will be used to reconstruct + roles on creating this persistence. Default is ``""``. """ def __init__(self, @@ -65,18 +72,23 @@ def __init__(self, user_data_json='', chat_data_json='', bot_data_json='', - conversations_json=''): + conversations_json='', + store_roles=True, + roles_json=''): super(DictPersistence, self).__init__(store_user_data=store_user_data, store_chat_data=store_chat_data, - store_bot_data=store_bot_data) + store_bot_data=store_bot_data, + store_roles=store_roles) self._user_data = None self._chat_data = None self._bot_data = None self._conversations = None + self._roles = None self._user_data_json = None self._chat_data_json = None self._bot_data_json = None self._conversations_json = None + self._roles_json = None if user_data_json: try: self._user_data = decode_user_chat_data_from_json(user_data_json) @@ -105,6 +117,13 @@ def __init__(self, except (ValueError, AttributeError): raise TypeError("Unable to deserialize conversations_json. Not valid JSON") + if roles_json: + try: + self._roles = Roles.decode_from_json(roles_json, None) + self._roles_json = roles_json + except (ValueError, AttributeError, TypeError): + raise TypeError("Unable to deserialize roles_json. Not valid JSON") + @property def user_data(self): """:obj:`dict`: The user_data as a dict""" @@ -144,6 +163,20 @@ def bot_data_json(self): else: return json.dumps(self.bot_data) + @property + def roles(self): + """:class:`telegram.ext.Roles`: The roles. Doesn't have a bot assigned. Use + :attr:`telegram.ext.Roles.set_bot` to set it.""" + return self._roles + + @property + def roles_json(self): + """:obj:`str`: The roles serialized as a JSON-string.""" + if self._roles_json: + return self._roles_json + else: + return self._roles.encode_to_json() + @property def conversations(self): """:obj:`dict`: The conversations as a dict""" @@ -193,6 +226,25 @@ def get_bot_data(self): self._bot_data = {} return deepcopy(self.bot_data) + def get_roles(self): + """Returns the roles created from the ``roles_json`` or an empty + :class:`telegram.ext.Roles` instance. + + Warning: + The produced roles instance will have no bot assigned. Use + :attr:`telegram.ext.Roles.set_bot` to set it. + + Returns: + :class:`telegram.ext.Roles`: The restored roles. + """ + if self.roles: + pass + elif self._roles: + self._roles = Roles.decode_from_json(self._roles_json, None) + else: + self._roles = Roles(None) + return deepcopy(self.roles) + def get_conversations(self, name): """Returns the conversations created from the ``conversations_json`` or an empty defaultdict. @@ -257,3 +309,14 @@ def update_bot_data(self, data): return self._bot_data = data.copy() self._bot_data_json = None + + def update_roles(self, data): + """Will update the roles (if changed). + + Args: + data (:class:`telegram.ext.Roles`): The :attr:`telegram.ext.dispatcher.roles` . + """ + if self._roles == data: + return + self._roles = deepcopy(data) + self._roles_json = None diff --git a/telegram/ext/dispatcher.py b/telegram/ext/dispatcher.py index e3a3aee2e0c..cfe6d90774c 100644 --- a/telegram/ext/dispatcher.py +++ b/telegram/ext/dispatcher.py @@ -37,6 +37,7 @@ from telegram.utils.deprecate import TelegramDeprecationWarning from telegram.utils.promise import Promise from telegram.ext import BasePersistence +from .roles import Roles logging.getLogger(__name__).addHandler(logging.NullHandler()) DEFAULT_GROUP = 0 @@ -80,6 +81,7 @@ class Dispatcher(object): user_data (:obj:`defaultdict`): A dictionary handlers can use to store data for the user. chat_data (:obj:`defaultdict`): A dictionary handlers can use to store data for the chat. bot_data (:obj:`dict`): A dictionary handlers can use to store data for the bot. + roles (:class:`telegram.ext.Roles`): An object you can use to restrict access to handlers. persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to store data that should be persistent over restarts @@ -124,6 +126,7 @@ def __init__(self, self.user_data = defaultdict(dict) self.chat_data = defaultdict(dict) self.bot_data = {} + self.roles = Roles(self.bot) if persistence: if not isinstance(persistence, BasePersistence): raise TypeError("persistence should be based on telegram.ext.BasePersistence") @@ -140,6 +143,11 @@ def __init__(self, self.bot_data = self.persistence.get_bot_data() if not isinstance(self.bot_data, dict): raise ValueError("bot_data must be of type dict") + if self.persistence.store_roles: + self.roles = self.persistence.get_roles() + if not isinstance(self.roles, Roles): + raise ValueError("roles must be of type Roles") + self.roles.set_bot(self.bot_data) else: self.persistence = None @@ -425,7 +433,8 @@ def remove_handler(self, handler, group=DEFAULT_GROUP): self.groups.remove(group) def update_persistence(self, update=None): - """Update :attr:`user_data`, :attr:`chat_data` and :attr:`bot_data` in :attr:`persistence`. + """Update :attr:`user_data`, :attr:`chat_data`, :attr:`bot_data` and :attr:`roles` in + :attr:`persistence`. Args: update (:class:`telegram.Update`, optional): The update to process. If passed, only the @@ -445,6 +454,17 @@ def update_persistence(self, update=None): else: user_ids = [] + if self.persistence.store_roles: + try: + self.persistence.update_roles(self.roles) + except Exception as e: + try: + self.dispatch_error(update, e) + except Exception: + message = 'Saving roles raised an error and an ' \ + 'uncaught error was raised while handling ' \ + 'the error with an error_handler' + self.logger.exception(message) if self.persistence.store_bot_data: try: self.persistence.update_bot_data(self.bot_data) diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index c02e88822ae..da7e525690b 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -21,6 +21,7 @@ import re from future.utils import string_types +from threading import Lock from telegram import Chat, Update, MessageEntity @@ -850,10 +851,15 @@ class user(BaseFilter): Examples: ``MessageHandler(Filters.user(1234), callback_method)`` + Attributes: + user_ids(set(:obj:`int`), optional): Which user ID(s) to allow through. + usernames(set(:obj:`str`), optional): Which username(s) (without leading '@') to allow + through. Args: - user_id(:obj:`int` | List[:obj:`int`], optional): Which user ID(s) to allow through. - username(:obj:`str` | List[:obj:`str`], optional): Which username(s) to allow through. - If username starts with '@' symbol, it will be ignored. + user_id(:obj:`int` | iterable(:obj:`int`), optional): Which user ID(s) to allow + through. + username(:obj:`str` | iterable(:obj:`str`), optional): Which username(s) to allow + through. If username starts with '@' symbol, it will be ignored. Raises: ValueError: If chat_id and username are both present, or neither is. @@ -861,18 +867,52 @@ class user(BaseFilter): """ def __init__(self, user_id=None, username=None): - if not (bool(user_id) ^ bool(username)): + if (user_id is None) == (username is None): raise ValueError('One and only one of user_id or username must be used') - if user_id is not None and isinstance(user_id, int): - self.user_ids = [user_id] - else: - self.user_ids = user_id - if username is None: - self.usernames = username - elif isinstance(username, string_types): - self.usernames = [username.replace('@', '')] - else: - self.usernames = [user.replace('@', '') for user in username] + + self._user_ids_lock = Lock() + self._usernames_lock = Lock() + + # Initialize in a way that will not fail the first setter calls + self._user_ids = user_id + self._usernames = username + # Actually initialize + self.user_ids = user_id + self.usernames = username + + @property + def user_ids(self): + with self._user_ids_lock: + return self._user_ids + + @user_ids.setter + def user_ids(self, user_id): + if (user_id is None) == (self.usernames is None): + raise ValueError('One and only one of user_id or username must be used') + with self._user_ids_lock: + if user_id is None: + self._user_ids = None + elif isinstance(user_id, int): + self._user_ids = set([user_id]) + else: + self._user_ids = set(user_id) + + @property + def usernames(self): + with self._usernames_lock: + return self._usernames + + @usernames.setter + def usernames(self, username): + if (username is None) == (self.user_ids is None): + raise ValueError('One and only one of user_id or username must be used') + with self._usernames_lock: + if username is None: + self._usernames = None + elif isinstance(username, str): + self._usernames = set([username.replace('@', '')]) + else: + self._usernames = set([user.replace('@', '') for user in username]) def filter(self, message): """""" # remove method from docs diff --git a/telegram/ext/handler.py b/telegram/ext/handler.py index b01aa58b74e..651d33787aa 100644 --- a/telegram/ext/handler.py +++ b/telegram/ext/handler.py @@ -24,6 +24,8 @@ class Handler(object): Attributes: callback (:obj:`callable`): The callback function for this handler. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. pass_update_queue (:obj:`bool`): Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Determines whether ``job_queue`` will be passed to @@ -51,6 +53,9 @@ class Handler(object): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -70,17 +75,34 @@ class Handler(object): """ + def __new__(cls, *args, **kwargs): + instance = super(Handler, cls).__new__(cls) + check_update = instance.check_update + + def check_update_with_filters(update): + if instance.roles: + if instance.roles(update): + return check_update(update) + else: + return False + return check_update(update) + + instance.check_update = check_update_with_filters + return instance + def __init__(self, callback, pass_update_queue=False, pass_job_queue=False, pass_user_data=False, - pass_chat_data=False): + pass_chat_data=False, + roles=None): self.callback = callback self.pass_update_queue = pass_update_queue self.pass_job_queue = pass_job_queue self.pass_user_data = pass_user_data self.pass_chat_data = pass_chat_data + self.roles = roles def check_update(self, update): """ diff --git a/telegram/ext/inlinequeryhandler.py b/telegram/ext/inlinequeryhandler.py index adef02cc7c5..dba24fad0b8 100644 --- a/telegram/ext/inlinequeryhandler.py +++ b/telegram/ext/inlinequeryhandler.py @@ -32,6 +32,8 @@ class InlineQueryHandler(Handler): Attributes: callback (:obj:`callable`): The callback function for this handler. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. pass_update_queue (:obj:`bool`): Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Determines whether ``job_queue`` will be passed to @@ -65,6 +67,9 @@ class InlineQueryHandler(Handler): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -103,13 +108,15 @@ def __init__(self, pass_groups=False, pass_groupdict=False, pass_user_data=False, - pass_chat_data=False): + pass_chat_data=False, + roles=None): super(InlineQueryHandler, self).__init__( callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, - pass_chat_data=pass_chat_data) + pass_chat_data=pass_chat_data, + roles=roles) if isinstance(pattern, string_types): pattern = re.compile(pattern) diff --git a/telegram/ext/messagehandler.py b/telegram/ext/messagehandler.py index b9bc0487ad8..33528136a5f 100644 --- a/telegram/ext/messagehandler.py +++ b/telegram/ext/messagehandler.py @@ -34,6 +34,8 @@ class MessageHandler(Handler): filters (:obj:`Filter`): Only allow updates with these Filters. See :mod:`telegram.ext.filters` for a full list of all available filters. callback (:obj:`callable`): The callback function for this handler. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. pass_update_queue (:obj:`bool`): Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Determines whether ``job_queue`` will be passed to @@ -75,6 +77,9 @@ class MessageHandler(Handler): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -115,14 +120,16 @@ def __init__(self, pass_chat_data=False, message_updates=None, channel_post_updates=None, - edited_updates=None): + edited_updates=None, + roles=None): super(MessageHandler, self).__init__( callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, - pass_chat_data=pass_chat_data) + pass_chat_data=pass_chat_data, + roles=roles) if message_updates is False and channel_post_updates is False and edited_updates is False: raise ValueError( 'message_updates, channel_post_updates and edited_updates are all False') diff --git a/telegram/ext/picklepersistence.py b/telegram/ext/picklepersistence.py index 55e5e55f201..2159b740fb4 100644 --- a/telegram/ext/picklepersistence.py +++ b/telegram/ext/picklepersistence.py @@ -22,6 +22,7 @@ from copy import deepcopy from telegram.ext import BasePersistence +from .roles import Roles class PicklePersistence(BasePersistence): @@ -36,10 +37,12 @@ class PicklePersistence(BasePersistence): persistence class. store_bot_data (:obj:`bool`): Optional. Whether bot_data should be saved by this persistence class. - single_file (:obj:`bool`): Optional. When ``False`` will store 3 sperate files of - `filename_user_data`, `filename_chat_data` and `filename_conversations`. Default is - ``True``. - on_flush (:obj:`bool`, optional): When ``True`` will only save to file when :meth:`flush` + store_roles (:obj:`bool`): Optional. Whether roles should be saved by this persistence + class. + single_file (:obj:`bool`): Optional. When ``False`` will store 5 sperate files of + `filename_bot_data`, `filename_user_data`, `filename_chat_data`, + `filename_conversations` and `filename_roles`. Default is ``True``. + on_flush (:obj:`bool`): Optional. When ``True`` will only save to file when :meth:`flush` is called and keep data in memory until that happens. When ``False`` will store data on any transaction *and* on call fo :meth:`flush`. Default is ``False``. @@ -52,9 +55,11 @@ class PicklePersistence(BasePersistence): persistence class. Default is ``True``. store_bot_data (:obj:`bool`, optional): Whether bot_data should be saved by this persistence class. Default is ``True`` . - single_file (:obj:`bool`, optional): When ``False`` will store 3 sperate files of - `filename_user_data`, `filename_chat_data` and `filename_conversations`. Default is - ``True``. + store_roles (:obj:`bool`, optional): Whether roles should be saved by this persistence + class. Default is ``True``. + single_file (:obj:`bool`, optional): When ``False`` will store 5 sperate files of + `filename_bot_data`, `filename_user_data`, `filename_chat_data`, + `filename_conversations` and `filename_roles`. Default is ``True``. on_flush (:obj:`bool`, optional): When ``True`` will only save to file when :meth:`flush` is called and keep data in memory until that happens. When ``False`` will store data on any transaction *and* on call fo :meth:`flush`. Default is ``False``. @@ -65,16 +70,19 @@ def __init__(self, filename, store_chat_data=True, store_bot_data=True, single_file=True, - on_flush=False): + on_flush=False, + store_roles=True): super(PicklePersistence, self).__init__(store_user_data=store_user_data, store_chat_data=store_chat_data, - store_bot_data=store_bot_data) + store_bot_data=store_bot_data, + store_roles=store_roles) self.filename = filename self.single_file = single_file self.on_flush = on_flush self.user_data = None self.chat_data = None self.bot_data = None + self.roles = None self.conversations = None def load_singlefile(self): @@ -86,12 +94,16 @@ def load_singlefile(self): self.chat_data = defaultdict(dict, data['chat_data']) # For backwards compatibility with files not containing bot data self.bot_data = data.get('bot_data', {}) + self.roles = data.get('roles', Roles(None)) + if self.roles: + self.roles = Roles.decode_from_json(self.roles, None) self.conversations = data['conversations'] except IOError: self.conversations = {} self.user_data = defaultdict(dict) self.chat_data = defaultdict(dict) self.bot_data = {} + self.roles = Roles(None) except pickle.UnpicklingError: raise TypeError("File {} does not contain valid pickle data".format(filename)) except Exception: @@ -111,7 +123,9 @@ def load_file(self, filename): def dump_singlefile(self): with open(self.filename, "wb") as f: data = {'conversations': self.conversations, 'user_data': self.user_data, - 'chat_data': self.chat_data, 'bot_data': self.bot_data} + 'chat_data': self.chat_data, 'bot_data': self.bot_data, + # Roles have locks, so we just use the json encoding + 'roles': self.roles.encode_to_json()} pickle.dump(data, f) def dump_file(self, filename, data): @@ -176,6 +190,31 @@ def get_bot_data(self): self.load_singlefile() return deepcopy(self.bot_data) + def get_roles(self): + """Returns the roles created from the pickle file if it exists or an empty + :class:`telegram.ext.Roles` instance. + + Warning: + The produced roles instance usually will have no bot assigned. Use + :attr:`telegram.ext.Roles.set_bot` to set it. + + Returns: + :class:`telegram.ext.Roles`: The restored roles. + """ + if self.roles: + pass + elif not self.single_file: + filename = "{}_roles".format(self.filename) + data = self.load_file(filename) + if not data: + data = Roles(None) + else: + data = Roles.decode_from_json(data, None) + self.roles = data + else: + self.load_singlefile() + return deepcopy(self.roles) + def get_conversations(self, name): """Returns the conversations from the pickle file if it exsists or an empty defaultdict. @@ -273,11 +312,28 @@ def update_bot_data(self, data): else: self.dump_singlefile() + def update_roles(self, data): + """Will update the roles (if changed) and depending on :attr:`on_flush` save the + pickle file. + + Args: + data (:class:`telegram.ext.Roles`): The :attr:`telegram.ext.dispatcher.roles` . + """ + if self.roles == data: + return + self.roles = deepcopy(data) + if not self.on_flush: + if not self.single_file: + filename = "{}_roles".format(self.filename) + self.dump_file(filename, self.roles.encode_to_json()) + else: + self.dump_singlefile() + def flush(self): """ Will save all data in memory to pickle file(s). """ if self.single_file: - if self.user_data or self.chat_data or self.conversations: + if self.user_data or self.chat_data or self.conversations or self.roles: self.dump_singlefile() else: if self.user_data: @@ -286,5 +342,7 @@ def flush(self): self.dump_file("{}_chat_data".format(self.filename), self.chat_data) if self.bot_data: self.dump_file("{}_bot_data".format(self.filename), self.bot_data) + if self.roles: + self.dump_file("{}_roles".format(self.filename), self.roles.encode_to_json()) if self.conversations: self.dump_file("{}_conversations".format(self.filename), self.conversations) diff --git a/telegram/ext/pollanswerhandler.py b/telegram/ext/pollanswerhandler.py index 7a7ccfed129..ba0bde5f6e1 100644 --- a/telegram/ext/pollanswerhandler.py +++ b/telegram/ext/pollanswerhandler.py @@ -26,6 +26,8 @@ class PollAnswerHandler(Handler): Attributes: callback (:obj:`callable`): The callback function for this handler. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. pass_update_queue (:obj:`bool`): Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Determines whether ``job_queue`` will be passed to @@ -53,6 +55,9 @@ class PollAnswerHandler(Handler): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` diff --git a/telegram/ext/precheckoutqueryhandler.py b/telegram/ext/precheckoutqueryhandler.py index 900eedf8aea..2b68a2f1a6b 100644 --- a/telegram/ext/precheckoutqueryhandler.py +++ b/telegram/ext/precheckoutqueryhandler.py @@ -27,6 +27,8 @@ class PreCheckoutQueryHandler(Handler): Attributes: callback (:obj:`callable`): The callback function for this handler. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. pass_update_queue (:obj:`bool`): Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Determines whether ``job_queue`` will be passed to @@ -54,6 +56,9 @@ class PreCheckoutQueryHandler(Handler): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` DEPRECATED: Please switch to context based callbacks. diff --git a/telegram/ext/regexhandler.py b/telegram/ext/regexhandler.py index 874125c6dc3..ec60653cdac 100644 --- a/telegram/ext/regexhandler.py +++ b/telegram/ext/regexhandler.py @@ -36,6 +36,8 @@ class RegexHandler(MessageHandler): Attributes: pattern (:obj:`str` | :obj:`Pattern`): The regex pattern. callback (:obj:`callable`): The callback function for this handler. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. pass_groups (:obj:`bool`): Determines whether ``groups`` will be passed to the callback function. pass_groupdict (:obj:`bool`): Determines whether ``groupdict``. will be passed to @@ -64,6 +66,9 @@ class RegexHandler(MessageHandler): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). pass_groups (:obj:`bool`, optional): If the callback should be passed the result of ``re.match(pattern, data).groups()`` as a keyword argument called ``groups``. Default is ``False`` @@ -106,7 +111,8 @@ def __init__(self, allow_edited=False, message_updates=True, channel_post_updates=False, - edited_updates=False): + edited_updates=False, + roles=None): warnings.warn('RegexHandler is deprecated. See https://git.io/fxJuV for more info', TelegramDeprecationWarning, stacklevel=2) @@ -118,7 +124,8 @@ def __init__(self, pass_chat_data=pass_chat_data, message_updates=message_updates, channel_post_updates=channel_post_updates, - edited_updates=edited_updates) + edited_updates=edited_updates, + roles=roles) self.pass_groups = pass_groups self.pass_groupdict = pass_groupdict diff --git a/telegram/ext/roles.py b/telegram/ext/roles.py new file mode 100644 index 00000000000..117efc5cbbe --- /dev/null +++ b/telegram/ext/roles.py @@ -0,0 +1,646 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2020 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains the class Role, which allows to restrict access to handlers.""" +import time + +try: + import ujson as json +except ImportError: + import json + +from copy import deepcopy + +from telegram import ChatMember, TelegramError, Bot +from .filters import Filters + + +class Role(Filters.user): + """This class represents a security level used by :class:`telegram.ext.Roles`. Roles have a + hierarchy, i.e. a role can do everthing, its child roles can do. To compare two roles you may + use the following syntax:: + + role_1 < role_2 + role 2 >= role_3 + + ``role_1 < role_2`` will be true, if ``role_2`` is a parent of ``role_1`` or a parents of one + of ``role_1`` s parents and similarly for ``role_1 < role_2``. + ``role_2 >= role_3`` will be true, if ``role_3`` is ``role_2`` or ``role_2 > role_3`` and + similarly for ``role_2 <= role_3``. + + Note: + If two roles are not related, i.e. neither is a (indirect) parent of the other, comparing + the roles will always yield ``False``. + + Warning: + ``role_1 == role_2`` does not test for the hierarchical order of the roles, but in fact if + both roles are the same object. To test for equality in terms of hierarchical order, i.e. + if :attr:`child_roles` and :attr:`chat_ids` coincide, use :attr:`equals`. + + Roles can be combined using bitwise operators: + + And: + + >>> (Roles(name='group_1') & Roles(name='user_2')) + + Grants access only for ``user_2`` within the chat ``group_1``. + + Or: + + >>> (Roles(name='group_1') | Roles(name='user_2')) + + Grants access for ``user_2`` and the whole chat ``group_1``. + + Not: + + >>> ~ Roles(name='user_1') + + Grants access to everyone except ``user_1`` + + Note: + Negated roles do `not` exclude their parent roles. E.g. with + + >>> ~ Roles(name='user_1', parent_roles=Role(name='user_2')) + + ``user_2`` will still have access, where ``user_1`` is restricted. Child roles, however + will be excluded. + + Also works with more than two roles: + + >>> (Roles(name='group_1') & (Roles(name='user_2') | Roles(name='user_3'))) + >>> Roles(name='group_1') & (~ FRoles(name='user_2')) + + Note: + Roles use the same short circuiting logic that pythons `and`, `or` and `not`. + This means that for example: + + >>> Role(chat_ids=123) | Role(chat_ids=456) + + With an update from user ``123``, will only ever evaluate the first role. + + Attributes: + chat_ids (set(:obj:`int`)): The ids of the users/chats of this role. Updates + will only be parsed, if the id of :attr:`telegram.Update.effective_user` or + :attr:`telegram.Update.effective_chat` respectiveley is listed here. May be empty. + parent_roles (set(:class:`telegram.ext.Role`)): Parent roles of this role. All the parent + roles can do anything, this role can do. May be empty. + child_roles (set(:class:`telegram.ext.Role`)): Child roles of this role. This role can do + anything, its child roles can do. May be empty. + + Args: + chat_ids (:obj:`int` | iterable(:obj:`int`), optional): The ids of the users/chats of this + role. Updates will only be parsed, if the id of :attr:`telegram.Update.effective_user` + or :attr:`telegram.Update.effective_chat` respectiveley is listed here. + parent_roles (:class:`telegram.ext.Role` | set(:class:`telegram.ext.Role`)), optional): + Parent roles of this role. + child_roles (:class:`telegram.ext.Role` | set(:class:`telegram.ext.Role`)), optional): + Child roles of this role. + name (:obj:`str`, optional): A name for this role. + + """ + update_filter = True + + def __init__(self, chat_ids=None, parent_roles=None, child_roles=None, name=None): + if chat_ids is None: + chat_ids = set() + super(Role, self).__init__(chat_ids) + self.name = name + + self.parent_roles = set() + if isinstance(parent_roles, Role): + self.add_parent_role(parent_roles) + elif parent_roles is not None: + for pr in parent_roles: + self.add_parent_role(pr) + + self.child_roles = set() + if isinstance(child_roles, Role): + self.add_child_role(child_roles) + elif child_roles is not None: + for cr in child_roles: + self.add_child_role(cr) + + self._inverted = False + + def __invert__(self): + self._inverted = True + return super(Role, self).__invert__() + + @property + def chat_ids(self): + return self.user_ids + + @chat_ids.setter + def chat_ids(self, chat_id): + self.user_ids = chat_id + + def filter_children(self, user, chat): + # filters only downward + if user and user.id in self.chat_ids: + return True + if chat and chat.id in self.chat_ids: + return True + if any([child.filter_children(user, chat) for child in self.child_roles]): + return True + + def filter(self, update): + user = update.effective_user + chat = update.effective_chat + if user and user.id in self.chat_ids: + return True + if chat and chat.id in self.chat_ids: + return True + if user or chat: + if self._inverted: + # If this is an inverted role (i.e. ~role) and we arrived here, the user is + # either ... + if self.filter_children(user, chat): + # ... in a child role of this. In this case, and must be excluded. Since the + # output of this will be negated, return True + return True + # ... not in a child role of this and must *nto* be excluded. In particular, we + # dont want to exclude the parents (see below). Since the output of this will be + # negated, return False + return False + else: + return any([parent(update) for parent in self.parent_roles]) + + def add_member(self, chat_id): + """Adds a user/chat to this role. Will do nothing, if user/chat is already present. + + Args: + chat_id (:obj:`int`): The users/chats id + """ + self.chat_ids.add(chat_id) + + def kick_member(self, chat_id): + """Kicks a user/chat to from role. Will do nothing, if user/chat is not present. + + Args: + chat_id (:obj:`int`): The users/chats id + """ + self.chat_ids.discard(chat_id) + + def add_parent_role(self, parent_role): + """Adds a parent role to this role. Also adds this role to the parents child roles. Will do + nothing, if parent role is already present. + + Args: + parent_role (:class:`telegram.ext.Role`): The parent role + """ + if self is parent_role: + raise ValueError('You must not add a role as its own parent!') + if self >= parent_role: + raise ValueError('You must not add a child role as a parent!') + self.parent_roles.add(parent_role) + parent_role.child_roles.add(self) + + def remove_parent_role(self, parent_role): + """Removes a parent role from this role. Also removes this role from the parents child + roles. Will do nothing, if parent role is not present. + + Args: + parent_role (:class:`telegram.ext.Role`): The parent role + """ + self.parent_roles.discard(parent_role) + parent_role.child_roles.discard(self) + + def add_child_role(self, child_role): + """Adds a child role to this role. Also adds this role to the childs parent roles. Will do + nothing, if child role is already present. + + Args: + child_role (:class:`telegram.ext.Role`): The child role + """ + if self is child_role: + raise ValueError('You must not add a role as its own child!') + if self <= child_role: + raise ValueError('You must not add a parent role as a child!') + self.child_roles.add(child_role) + child_role.parent_roles.add(self) + + def remove_child_role(self, child_role): + """Removes a child role from this role. Also removes this role from the childs parent + roles. Will do nothing, if child role is not present. + + Args: + child_role (:class:`telegram.ext.Role`): The child role + """ + self.child_roles.discard(child_role) + child_role.parent_roles.discard(self) + + def __lt__(self, other): + # Test for hierarchical order + if isinstance(other, Role): + return any([pr <= other for pr in self.parent_roles]) + return False + + def __le__(self, other): + # Test for hierarchical order + return self is other or self < other + + def __gt__(self, other): + # Test for hierarchical order + if isinstance(other, Role): + return any([self >= pr for pr in other.parent_roles]) + return False + + def __ge__(self, other): + # Test for hierarchical order + return self is other or self > other + + def __eq__(self, other): + return self is other + + def __ne__(self, other): + return not self == other + + def equals(self, other): + """Test if two roles are equal in terms of hierarchy. Returns ``True``, if the chat_ids + coincide and the child roles are equal in terms of this method. Note, that the result of + this comparison may change by adding or removing child/parent roles or members. + + Args: + other (:class:`telegram.ext.Role`): + + Returns: + :obj:`bool`: + """ + if self.chat_ids == other.chat_ids: + if len(self.child_roles) == len(other.child_roles): + if len(self.child_roles) == 0: + return True + for cr in self.child_roles: + if not any([cr.equals(ocr) for ocr in other.child_roles]): + return False + for ocr in other.child_roles: + if not any([ocr.equals(cr) for cr in self.child_roles]): + return False + return True + return False + + def __hash__(self): + return id(self) + + def __deepcopy__(self, memo): + new_role = Role(chat_ids=self.chat_ids, name=self.name) + memo[id(self)] = new_role + for pr in self.parent_roles: + new_role.add_parent_role(deepcopy(pr, memo)) + for cr in self.child_roles: + new_role.add_child_role(deepcopy(cr, memo)) + return new_role + + def __repr__(self): + if self.name: + return 'Role({})'.format(self.name) + elif self.chat_ids: + return 'Role({})'.format(self.chat_ids) + else: + return 'Role({})' + + +class ChatAdminsRole(Role): + """A :class:`telegram.ext.Role` that allows only the administrators of a chat. Private chats + always allowed. To minimize the number of API calls, for each chat the admins will be cached. + + Attributes: + parent_roles (set(:class:`telegram.ext.Role`)): Parent roles of this role. All the parent + roles can do anything, this role can do. May be empty. + child_roles (set(:class:`telegram.ext.Role`)): Child roles of this role. This role can do + anything, its child roles can do. May be empty. + timeout (:obj:`int`): The caching timeout in seconds. For each chat, the admins will be + cached and refreshed only after this timeout. + + Args: + bot (:class:`telegram.Bot`): A bot to use for getting the administrators of a chat. + timeout (:obj:`int`, optional): The caching timeout in seconds. For each chat, the admins + will be cached and refreshed only after this timeout. Defaults to ``1800`` (half an + hour). + + """ + def __init__(self, bot, timeout=1800): + super(ChatAdminsRole, self).__init__(name='chat_admins') + self.bot = bot + self.cache = {} + self.timeout = timeout + + def filter(self, update): + user = update.effective_user + chat = update.effective_chat + if user and chat: + # Always true in private chats + if user.id == chat.id: + return True + # Check for cached info first + if (self.cache.get(chat.id, None) + and (time.time() - self.cache[chat.id][0]) < self.timeout): + return user.id in self.cache[chat.id][1] + admins = [m.user.id for m in self.bot.get_chat_administrators(chat.id)] + self.cache[chat.id] = (time.time(), admins) + return user.id in admins + + def __deepcopy__(self, memo): + new_role = super(ChatAdminsRole, self).__deepcopy__(memo) + new_role.bot = self.bot + new_role.cache = self.cache + new_role.timeout = self.timeout + return new_role + + +class ChatCreatorRole(Role): + """A :class:`telegram.ext.Role` that allows only the creator of a chat. Private chats are + always allowed. To minimize the number of API calls, for each chat the creator will be saved. + + Attributes: + parent_roles (set(:class:`telegram.ext.Role`)): Parent roles of this role. All the parent + roles can do anything, this role can do. May be empty. + child_roles (set(:class:`telegram.ext.Role`)): Child roles of this role. This role can do + anything, its child roles can do. May be empty. + + Args: + bot (:class:`telegram.Bot`): A bot to use for getting the creator of a chat. + + """ + def __init__(self, bot): + super(ChatCreatorRole, self).__init__(name='chat_creator') + self.bot = bot + self.cache = {} + + def filter(self, update): + user = update.effective_user + chat = update.effective_chat + if user and chat: + # Always true in private chats + if user.id == chat.id: + return True + # Check for cached info first + if self.cache.get(chat.id, None): + return user.id == self.cache[chat.id] + try: + member = self.bot.get_chat_member(chat.id, user.id) + if member.status == ChatMember.CREATOR: + self.cache[chat.id] = user.id + return True + return False + except TelegramError: + # user is not a chat member or bot has no access + return False + + def __deepcopy__(self, memo): + new_role = super(ChatCreatorRole, self).__deepcopy__(memo) + new_role.bot = self.bot + new_role.cache = self.cache + return new_role + + +class Roles(dict): + """This class represents a collection of :class:`telegram.ext.Role` s that can be used to + manage access control to functionality of a bot. Each role can be accessed by its name, e.g.:: + + roles.add_role('my_role') + role = roles['my_role'] + + Note: + In fact, :class:`telegram.ext.Roles` inherits from :obj:`dict` and thus provides most + methods needed for the common use cases. Methods that are *not* supported are: + ``__delitem__``, ``__setitem__``, ``setdefault``, ``update``, ``pop``, ``popitem``, + ``clear`` and ``copy``. + Please use :attr:`add_role` and :attr:`remove_role` instead. + + Attributes: + ADMINS (:class:`telegram.ext.Role`): A role reserved for administrators of the bot. All + roles added to this instance will be child roles of :attr:`ADMINS`. + CHAT_ADMINS (:class:`telegram.ext.roles.ChatAdminsRole`): Use this role to restrict access + to admins of a chat. Handlers with this role wont handle updates that don't have an + ``effective_chat``. Admins are cached for each chat. + CHAT_CREATOR (:class:`telegram.ext.roles.ChatCreatorRole`): Use this role to restrict + access to the creator of a chat. Handlers with this role wont handle updates that don't + have an ``effective_chat``. + + Args: + bot (:class:`telegram.Bot`): A bot associated with this instance. + + """ + + def __init__(self, bot): + super(Roles, self).__init__() + self.bot = bot + self.ADMINS = Role(name='admins') + self.CHAT_ADMINS = ChatAdminsRole(bot=self.bot) + self.CHAT_CREATOR = ChatCreatorRole(bot=self.bot) + + def set_bot(self, bot): + """If for some reason you can't pass the bot on initialization, you can set it with this + method. Make sure to set the bot before the first call of :attr:`CHAT_ADMINS` or + :attr:`CHAT_CREATOR`. + + Args: + bot (:class:`telegram.Bot`): The bot to set. + + Raises: + ValueError + """ + if isinstance(self.bot, Bot): + raise ValueError('Bot is already set for this Roles instance') + self.bot = bot + + def __delitem__(self, key): + """""" # Remove method from docs + raise NotImplementedError('Please use remove_role.') + + def __setitem__(self, key, value): + """""" # Remove method from docs + raise ValueError('Roles are immutable!') + + def setitem(self, key, value): + super(Roles, self).__setitem__(key, value) + + def setdefault(self, key, value=None): + """""" # Remove method from docs + raise ValueError('Roles are immutable!') + + def update(self, other): + """""" # Remove method from docs + raise ValueError('Roles are immutable!') + + def pop(self, key, default=None): + """""" # Remove method from docs + raise NotImplementedError('Please use remove_role.') + + def _pop(self, key, default=None): + return super(Roles, self).pop(key, default) + + def popitem(self, key): + """""" # Remove method from docs + raise NotImplementedError('Please use remove_role.') + + def clear(self): + """""" # Remove method from docs + raise NotImplementedError('Please use remove_role.') + + def copy(self): + """""" # Remove method from docs + raise NotImplementedError + + def add_admin(self, chat_id): + """Adds a user/chat to the :attr:`ADMINS` role. Will do nothing if user/chat is already + present. + + Args: + chat_id (:obj:`int`): The users id + """ + self.ADMINS.add_member(chat_id) + + def kick_admin(self, chat_id): + """Kicks a user/chat from the :attr:`ADMINS` role. Will do nothing if user/chat is not + present. + + Args: + chat_id (:obj:`int`): The users/chats id + """ + self.ADMINS.kick_member(chat_id) + + def add_role(self, name, chat_ids=None, parent_roles=None, child_roles=None): + """Creates and registers a new role. :attr:`ADMINS` will automatically be added to the + roles parent roles, i.e. admins can do everything. The role can be accessed by it's + name. + + Args: + name (:obj:`str`, optional): A name for this role. + chat_ids (:obj:`int` | iterable(:obj:`int`), optional): The ids of the users/chats of + this role. + parent_roles (:class:`telegram.ext.Role` | set(:class:`telegram.ext.Role`), optional): + Parent roles of this role. + child_roles (:class:`telegram.ext.Role` | set(:class:`telegram.ext.Role`), optional): + Child roles of this role. + + Raises: + ValueError + """ + if name in self: + raise ValueError('Role name is already taken.') + role = Role(chat_ids=chat_ids, parent_roles=parent_roles, + child_roles=child_roles, name=name) + self.setitem(name, role) + role.add_parent_role(self.ADMINS) + + def remove_role(self, name): + """Removes a role. + + Args: + name (:obj:`str`): The name of the role to be removed + """ + role = self._pop(name, None) + role.remove_parent_role(self.ADMINS) + + def __eq__(self, other): + if isinstance(other, Roles): + for name, role in self.items(): + orole = other.get(name, None) + if not orole: + return False + if not role.equals(orole): + return False + if any([self.get(name, None) is None for name in other]): + return False + return self.ADMINS.equals(other.ADMINS) + return False + + def __ne__(self, other): + return not self == other + + def __deepcopy__(self, memo): + new_roles = Roles(self.bot) + new_roles.CHAT_ADMINS.timeout = self.CHAT_ADMINS.timeout + memo[id(self)] = new_roles + for chat_id in self.ADMINS.chat_ids: + new_roles.add_admin(chat_id) + for role in self.values(): + new_roles.add_role(name=role.name, chat_ids=role.chat_ids) + for pr in role.parent_roles: + if pr is not self.ADMINS: + new_roles[role.name].add_parent_role(deepcopy(pr, memo)) + for cr in role.child_roles: + new_roles[role.name].add_child_role(deepcopy(cr, memo)) + return new_roles + + def encode_to_json(self): + """Helper method to encode a roles object to a JSON-serializable way. Use + :attr:`decode_from_json` to decode. + + Args: + roles (:class:`telegram.ext.Roles`): The roles object to transofrm to JSON. + + Returns: + :obj:`str`: The JSON-serialized roles object + """ + def _encode_role_to_json(role, memo, trace): + id_ = id(role) + if id_ not in memo and id_ not in trace: + trace.append(id_) + inner_tmp = {'name': role.name, 'chat_ids': sorted(role.chat_ids)} + inner_tmp['parent_roles'] = [ + _encode_role_to_json(pr, memo, trace) for pr in role.parent_roles + ] + inner_tmp['child_roles'] = [ + _encode_role_to_json(cr, memo, trace) for cr in role.child_roles + ] + memo[id_] = inner_tmp + return id_ + + tmp = {'admins': id(self.ADMINS), 'admins_timeout': self.CHAT_ADMINS.timeout, + 'roles': [], 'memo': {}} + tmp['roles'] = [_encode_role_to_json(self[name], tmp['memo'], []) for name in self] + return json.dumps(tmp) + + @staticmethod + def decode_from_json(json_string, bot): + """Helper method to decode a roles object to a JSON-string created with + :attr:`encode_roles_to_json`. + + Args: + json_string (:obj:`str`): The roles object as JSON string. + bot (:class:`telegram.Bot`): The bot to be passed to the roles object. + + Returns: + :class:`telegram.ext.Roles`: The roles object after decoding + """ + def _decode_role_from_json(id_, memo): + id_ = str(id_) + if isinstance(memo[id_], Role): + return memo[id_] + + tmp = memo[id_] + role = Role(name=tmp['name'], chat_ids=tmp['chat_ids']) + memo[id_] = role + for pid in tmp['parent_roles']: + role.add_parent_role(_decode_role_from_json(pid, memo)) + for cid in tmp['child_roles']: + role.add_child_role(_decode_role_from_json(cid, memo)) + return role + + tmp = json.loads(json_string) + memo = tmp['memo'] + roles = Roles(bot) + roles.ADMINS = _decode_role_from_json(tmp['admins'], memo) + roles.CHAT_ADMINS.timeout = tmp['admins_timeout'] + for id_ in tmp['roles']: + role = _decode_role_from_json(id_, memo) + roles.setitem(role.name, role) + return roles diff --git a/telegram/ext/shippingqueryhandler.py b/telegram/ext/shippingqueryhandler.py index a6d2603f126..5755fd48b6d 100644 --- a/telegram/ext/shippingqueryhandler.py +++ b/telegram/ext/shippingqueryhandler.py @@ -27,6 +27,8 @@ class ShippingQueryHandler(Handler): Attributes: callback (:obj:`callable`): The callback function for this handler. + roles (:obj:`telegram.ext.Role`): Optional. A user role used to restrict access to the + handler. pass_update_queue (:obj:`bool`): Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Determines whether ``job_queue`` will be passed to @@ -54,6 +56,9 @@ class ShippingQueryHandler(Handler): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. + roles (:obj:`telegram.ext.Role`, optional): A user role used to restrict access to the + handler. Roles can be combined using bitwise operators (& for and, | for or, ~ for + not). pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` diff --git a/tests/conftest.py b/tests/conftest.py index a40fb9756c0..df8a5dc54d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ from telegram import (Bot, Message, User, Chat, MessageEntity, Update, InlineQuery, CallbackQuery, ShippingQuery, PreCheckoutQuery, ChosenInlineResult) -from telegram.ext import Dispatcher, JobQueue, Updater, BaseFilter, Defaults +from telegram.ext import Dispatcher, JobQueue, Updater, BaseFilter, Defaults, Role from telegram.utils.helpers import _UtcOffsetTimezone from tests.bots import get_bot @@ -163,6 +163,11 @@ def class_thumb_file(): f.close() +@pytest.fixture(scope='function') +def role(): + yield Role(0) + + def pytest_configure(config): if sys.version_info >= (3,): config.addinivalue_line('filterwarnings', 'ignore::ResourceWarning') diff --git a/tests/test_callbackcontext.py b/tests/test_callbackcontext.py index ce76b5079ce..f7fbd475625 100644 --- a/tests/test_callbackcontext.py +++ b/tests/test_callbackcontext.py @@ -36,6 +36,7 @@ def test_from_job(self, cdp): assert callback_context.chat_data is None assert callback_context.user_data is None assert callback_context.bot_data is cdp.bot_data + assert callback_context.roles is cdp.roles assert callback_context.bot is cdp.bot assert callback_context.job_queue is cdp.job_queue assert callback_context.update_queue is cdp.update_queue @@ -49,6 +50,7 @@ def test_from_update(self, cdp): assert callback_context.user_data == {} assert callback_context.bot_data is cdp.bot_data assert callback_context.bot is cdp.bot + assert callback_context.roles is cdp.roles assert callback_context.job_queue is cdp.job_queue assert callback_context.update_queue is cdp.update_queue @@ -58,6 +60,7 @@ def test_from_update(self, cdp): callback_context.chat_data['test'] = 'chat' callback_context.user_data['test'] = 'user' + assert callback_context_same_user_chat.roles is callback_context.roles assert callback_context_same_user_chat.bot_data is callback_context.bot_data assert callback_context_same_user_chat.chat_data is callback_context.chat_data assert callback_context_same_user_chat.user_data is callback_context.user_data @@ -67,6 +70,7 @@ def test_from_update(self, cdp): callback_context_other_user_chat = CallbackContext.from_update(update_other_user_chat, cdp) + assert callback_context_other_user_chat.roles is callback_context.roles assert callback_context_other_user_chat.bot_data is callback_context.bot_data assert callback_context_other_user_chat.chat_data is not callback_context.chat_data assert callback_context_other_user_chat.user_data is not callback_context.user_data @@ -78,6 +82,7 @@ def test_from_update_not_update(self, cdp): assert callback_context.user_data is None assert callback_context.bot_data is cdp.bot_data assert callback_context.bot is cdp.bot + assert callback_context.roles is cdp.roles assert callback_context.job_queue is cdp.job_queue assert callback_context.update_queue is cdp.update_queue @@ -86,6 +91,7 @@ def test_from_update_not_update(self, cdp): assert callback_context.chat_data is None assert callback_context.user_data is None assert callback_context.bot_data is cdp.bot_data + assert callback_context.roles is cdp.roles assert callback_context.bot is cdp.bot assert callback_context.job_queue is cdp.job_queue assert callback_context.update_queue is cdp.update_queue @@ -101,6 +107,7 @@ def test_from_error(self, cdp): assert callback_context.chat_data == {} assert callback_context.user_data == {} assert callback_context.bot_data is cdp.bot_data + assert callback_context.roles is cdp.roles assert callback_context.bot is cdp.bot assert callback_context.job_queue is cdp.job_queue assert callback_context.update_queue is cdp.update_queue @@ -125,6 +132,8 @@ def test_data_assignment(self, cdp): callback_context.user_data = {} with pytest.raises(AttributeError): callback_context.chat_data = "test" + with pytest.raises(AttributeError): + callback_context.roles = "test" def test_dispatcher_attribute(self, cdp): callback_context = CallbackContext(cdp) diff --git a/tests/test_callbackqueryhandler.py b/tests/test_callbackqueryhandler.py index 66fe5359e8f..232d65fae92 100644 --- a/tests/test_callbackqueryhandler.py +++ b/tests/test_callbackqueryhandler.py @@ -91,6 +91,7 @@ def callback_context(self, update, context): and isinstance(context.user_data, dict) and context.chat_data is None and isinstance(context.bot_data, dict) + and isinstance(context.roles, dict) and isinstance(update.callback_query, CallbackQuery)) def callback_context_pattern(self, update, context): @@ -116,6 +117,13 @@ def test_with_pattern(self, callback_query): callback_query.callback_query.data = 'nothing here' assert not handler.check_update(callback_query) + def test_with_role(self, callback_query, role): + handler = CallbackQueryHandler(self.callback_basic, roles=role) + assert not handler.check_update(callback_query) + + role.chat_ids = 1 + assert handler.check_update(callback_query) + def test_with_passing_group_dict(self, dp, callback_query): handler = CallbackQueryHandler(self.callback_group, pattern='(?P.*)est(?P.*)', diff --git a/tests/test_choseninlineresulthandler.py b/tests/test_choseninlineresulthandler.py index c8e0711ebec..f03b5bd431c 100644 --- a/tests/test_choseninlineresulthandler.py +++ b/tests/test_choseninlineresulthandler.py @@ -88,6 +88,7 @@ def callback_context(self, update, context): and isinstance(context.user_data, dict) and context.chat_data is None and isinstance(context.bot_data, dict) + and isinstance(context.roles, dict) and isinstance(update.chosen_inline_result, ChosenInlineResult)) def test_basic(self, dp, chosen_inline_result): @@ -98,6 +99,13 @@ def test_basic(self, dp, chosen_inline_result): dp.process_update(chosen_inline_result) assert self.test_flag + def test_with_role(self, chosen_inline_result, role): + handler = ChosenInlineResultHandler(self.callback_basic, roles=role) + assert not handler.check_update(chosen_inline_result) + + role.chat_ids = 1 + assert handler.check_update(chosen_inline_result) + def test_pass_user_or_chat_data(self, dp, chosen_inline_result): handler = ChosenInlineResultHandler(self.callback_data_1, pass_user_data=True) diff --git a/tests/test_commandhandler.py b/tests/test_commandhandler.py index b37c76594c2..2295b2ed6d7 100644 --- a/tests/test_commandhandler.py +++ b/tests/test_commandhandler.py @@ -89,6 +89,7 @@ def callback_context(self, update, context): and isinstance(context.user_data, dict) and isinstance(context.chat_data, dict) and isinstance(context.bot_data, dict) + and isinstance(context.roles, dict) and isinstance(update.message, Message)) def callback_context_args(self, update, context): @@ -272,6 +273,13 @@ def test_context_multiple_regex(self, cdp, command): filters=Filters.regex('one') & Filters.regex('two')) self._test_context_args_or_regex(cdp, handler, command) + def test_with_role(self, command, role): + handler = self.make_default_handler(roles=role) + assert not is_match(handler, make_command_update('/test')) + + role.chat_ids = 1 + assert is_match(handler, make_command_update('/test')) + # ----------------------------- PrefixHandler ----------------------------- @@ -421,3 +429,12 @@ def test_context_multiple_regex(self, cdp, prefix_message_text): filters=Filters.regex('one') & Filters.regex( 'two')) self._test_context_args_or_regex(cdp, handler, prefix_message_text) + + def test_with_role(self, dp, prefix, command, role): + handler = self.make_default_handler(roles=role) + dp.add_handler(handler) + text = prefix + command + assert not self.response(dp, make_message_update(text)) + + role.chat_ids = 1 + assert self.response(dp, make_message_update(text)) diff --git a/tests/test_conversationhandler.py b/tests/test_conversationhandler.py index ff890628c6f..4e49fee8d01 100644 --- a/tests/test_conversationhandler.py +++ b/tests/test_conversationhandler.py @@ -849,6 +849,23 @@ def bye(bot, update): " since inline queries have no chat context." ) + def test_conversationhandler_with_role(self, dp, bot, user1, role): + handler = ConversationHandler(entry_points=self.entry_points, states=self.states, + fallbacks=self.fallbacks, roles=role) + dp.add_handler(handler) + + # User one, starts the state machine. + message = Message(0, user1, None, self.second_group, text='/start', + entities=[MessageEntity(type=MessageEntity.BOT_COMMAND, + offset=0, length=len('/start'))], + bot=bot) + dp.process_update(Update(update_id=0, message=message)) + assert user1.id not in self.current_state + + role.chat_ids = 123 + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.THIRSTY + def test_nested_conversation_handler(self, dp, bot, user1, user2): self.nested_states[self.DRINKING] = [ConversationHandler( entry_points=self.drinking_entry_points, diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index b3e1c3eb32b..8179d9b3c43 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -25,7 +25,7 @@ from telegram import TelegramError, Message, User, Chat, Update, Bot, MessageEntity from telegram.ext import (MessageHandler, Filters, CommandHandler, CallbackContext, - JobQueue, BasePersistence) + JobQueue, BasePersistence, Roles) from telegram.ext.dispatcher import run_async, Dispatcher, DispatcherHandlerStop from telegram.utils.deprecate import TelegramDeprecationWarning from tests.conftest import create_dp @@ -350,10 +350,17 @@ def test_error_while_saving_chat_data(self, dp, bot): class OwnPersistence(BasePersistence): def __init__(self): - super(BasePersistence, self).__init__() + super(OwnPersistence, self).__init__() self.store_user_data = True self.store_chat_data = True self.store_bot_data = True + self.store_roles = True + + def get_roles(self): + return Roles(None) + + def update_roles(self, data): + raise Exception def get_bot_data(self): return dict() @@ -393,7 +400,7 @@ def error(b, u, e): dp.add_handler(CommandHandler('start', start1)) dp.add_error_handler(error) dp.process_update(update) - assert increment == ["error", "error", "error"] + assert increment == ["error", "error", "error", "error"] def test_flow_stop_in_error_handler(self, dp, bot): passed = [] diff --git a/tests/test_filters.py b/tests/test_filters.py index f081fed087c..a1dd3682e5c 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -563,6 +563,10 @@ def test_filters_user(self): with pytest.raises(ValueError, match='user_id or username'): Filters.user() + def test_filters_user_empty_args(self, update): + assert not Filters.user(user_id=[])(update) + assert not Filters.user(username=[])(update) + def test_filters_user_id(self, update): assert not Filters.user(user_id=1)(update) update.message.from_user.id = 1 @@ -580,6 +584,30 @@ def test_filters_username(self, update): assert Filters.user(username=['user1', 'user', 'user2'])(update) assert not Filters.user(username=['@username', '@user_2'])(update) + def test_filters_user_change_id(self, update): + f = Filters.user(user_id=1) + update.message.from_user.id = 1 + assert f(update) + update.message.from_user.id = 2 + assert not f(update) + f.user_ids = 2 + assert f(update) + + with pytest.raises(ValueError, match='user_id or username'): + f.usernames = 'user' + + def test_filters_user_change_username(self, update): + f = Filters.user(username='user') + update.message.from_user.username = 'user' + assert f(update) + update.message.from_user.username = 'User' + assert not f(update) + f.usernames = 'User' + assert f(update) + + with pytest.raises(ValueError, match='user_id or username'): + f.user_ids = 1 + def test_filters_chat(self): with pytest.raises(ValueError, match='chat_id or username'): Filters.chat(chat_id=-1, username='chat') diff --git a/tests/test_inlinequeryhandler.py b/tests/test_inlinequeryhandler.py index 9e3e7a95159..2fe29e0875b 100644 --- a/tests/test_inlinequeryhandler.py +++ b/tests/test_inlinequeryhandler.py @@ -95,6 +95,7 @@ def callback_context(self, update, context): and isinstance(context.user_data, dict) and context.chat_data is None and isinstance(context.bot_data, dict) + and isinstance(context.roles, dict) and isinstance(update.inline_query, InlineQuery)) def callback_context_pattern(self, update, context): @@ -121,6 +122,13 @@ def test_with_pattern(self, inline_query): inline_query.inline_query.query = 'nothing here' assert not handler.check_update(inline_query) + def test_with_role(self, inline_query, role): + handler = InlineQueryHandler(self.callback_basic, roles=role) + assert not handler.check_update(inline_query) + + role.chat_ids = 2 + assert handler.check_update(inline_query) + def test_with_passing_group_dict(self, dp, inline_query): handler = InlineQueryHandler(self.callback_group, pattern='(?P.*)est(?P.*)', diff --git a/tests/test_messagehandler.py b/tests/test_messagehandler.py index 7e2f5fb63ab..4076a398cc4 100644 --- a/tests/test_messagehandler.py +++ b/tests/test_messagehandler.py @@ -84,6 +84,7 @@ def callback_context(self, update, context): and isinstance(context.job_queue, JobQueue) and isinstance(context.chat_data, dict) and isinstance(context.bot_data, dict) + and isinstance(context.roles, dict) and ((isinstance(context.user_data, dict) and (isinstance(update.message, Message) or isinstance(update.edited_message, Message))) @@ -172,6 +173,13 @@ def test_specific_filters(self, message): assert not handler.check_update(Update(0, channel_post=message)) assert handler.check_update(Update(0, edited_channel_post=message)) + def test_with_role(self, message, role): + handler = MessageHandler(None, self.callback_basic, roles=role) + assert not handler.check_update(Update(0, message)) + + role.chat_ids = 1 + assert handler.check_update(Update(0, message)) + def test_pass_user_or_chat_data(self, dp, message): handler = MessageHandler(None, self.callback_data_1, pass_user_data=True) diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 20fe75d5783..b3b6764abd3 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -29,13 +29,14 @@ import os import pickle from collections import defaultdict +from copy import deepcopy from time import sleep import pytest from telegram import Update, Message, User, Chat, MessageEntity from telegram.ext import BasePersistence, Updater, ConversationHandler, MessageHandler, Filters, \ - PicklePersistence, CommandHandler, DictPersistence, TypeHandler, JobQueue + PicklePersistence, CommandHandler, DictPersistence, TypeHandler, Roles, Role, JobQueue @pytest.fixture(autouse=True) @@ -76,15 +77,26 @@ def conversations(): 'name3': {(123, 321): 1, (890, 890): 2}} +@pytest.fixture(scope='function') +def roles(): + roles = Roles(None) + roles.add_admin(12345) + roles.add_role(name='parent_role', chat_ids=[456]) + roles.add_role(name='role', chat_ids=[123], parent_roles=roles['parent_role']) + return roles + + @pytest.fixture(scope="function") def updater(bot, base_persistence): base_persistence.store_chat_data = False base_persistence.store_bot_data = False base_persistence.store_user_data = False + base_persistence.store_roles = False u = Updater(bot=bot, persistence=base_persistence) base_persistence.store_bot_data = True base_persistence.store_chat_data = True base_persistence.store_user_data = True + base_persistence.store_roles = True return u @@ -100,12 +112,16 @@ class TestBasePersistence(object): def test_creation(self, base_persistence): assert base_persistence.store_chat_data assert base_persistence.store_user_data + assert base_persistence.store_bot_data + assert base_persistence.store_roles with pytest.raises(NotImplementedError): base_persistence.get_bot_data() with pytest.raises(NotImplementedError): base_persistence.get_chat_data() with pytest.raises(NotImplementedError): base_persistence.get_user_data() + with pytest.raises(NotImplementedError): + base_persistence.get_roles() with pytest.raises(NotImplementedError): base_persistence.get_conversations("test") with pytest.raises(NotImplementedError): @@ -114,6 +130,8 @@ def test_creation(self, base_persistence): base_persistence.update_chat_data(None, None) with pytest.raises(NotImplementedError): base_persistence.update_user_data(None, None) + with pytest.raises(NotImplementedError): + base_persistence.update_roles(None) with pytest.raises(NotImplementedError): base_persistence.update_conversation(None, None, None) @@ -131,7 +149,7 @@ def test_conversationhandler_addition(self, dp, base_persistence): dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler")) def test_dispatcher_integration_init(self, bot, base_persistence, chat_data, user_data, - bot_data): + bot_data, roles): def get_user_data(): return "test" @@ -141,9 +159,13 @@ def get_chat_data(): def get_bot_data(): return "test" + def get_roles(): + return "test" + base_persistence.get_user_data = get_user_data base_persistence.get_chat_data = get_chat_data base_persistence.get_bot_data = get_bot_data + base_persistence.get_roles = get_roles with pytest.raises(ValueError, match="user_data must be of type defaultdict"): u = Updater(bot=bot, persistence=base_persistence) @@ -166,6 +188,13 @@ def get_bot_data(): return bot_data base_persistence.get_bot_data = get_bot_data + with pytest.raises(ValueError, match="roles must be of type Roles"): + u = Updater(bot=bot, persistence=base_persistence) + + def get_roles(): + return roles + + base_persistence.get_roles = get_roles u = Updater(bot=bot, persistence=base_persistence) assert u.dispatcher.bot_data == bot_data assert u.dispatcher.chat_data == chat_data @@ -174,7 +203,7 @@ def get_bot_data(): assert u.dispatcher.chat_data[442233]['test5'] == 'test6' def test_dispatcher_integration_handlers(self, caplog, bot, base_persistence, - chat_data, user_data, bot_data): + chat_data, user_data, bot_data, roles): def get_user_data(): return user_data @@ -184,9 +213,13 @@ def get_chat_data(): def get_bot_data(): return bot_data + def get_roles(): + return roles + base_persistence.get_user_data = get_user_data base_persistence.get_chat_data = get_chat_data base_persistence.get_bot_data = get_bot_data + base_persistence.get_roles = get_roles # base_persistence.update_chat_data = lambda x: x # base_persistence.update_user_data = lambda x: x updater = Updater(bot=bot, persistence=base_persistence, use_context=True) @@ -197,12 +230,16 @@ def callback_known_user(update, context): pytest.fail('user_data corrupt') if not context.bot_data == bot_data: pytest.fail('bot_data corrupt') + if not context.roles == roles: + pytest.fail('roles corrupt') def callback_known_chat(update, context): if not context.chat_data['test3'] == 'test4': pytest.fail('chat_data corrupt') if not context.bot_data == bot_data: pytest.fail('bot_data corrupt') + if not context.roles == roles: + pytest.fail('roles corrupt') def callback_unknown_user_or_chat(update, context): if not context.user_data == {}: @@ -211,9 +248,12 @@ def callback_unknown_user_or_chat(update, context): pytest.fail('chat_data corrupt') if not context.bot_data == bot_data: pytest.fail('bot_data corrupt') + if not context.roles == roles: + pytest.fail('roles corrupt') context.user_data[1] = 'test7' context.chat_data[2] = 'test8' context.bot_data['test0'] = 'test0' + context.roles.add_role(name='test0', chat_ids=[1, 2, 3]) known_user = MessageHandler(Filters.user(user_id=12345), callback_known_user, pass_chat_data=True, pass_user_data=True) @@ -257,14 +297,21 @@ def save_user_data(data): if 54321 not in data: pytest.fail() + def save_roles(data): + if 'test0' not in data: + pytest.fail() + base_persistence.update_chat_data = save_chat_data base_persistence.update_user_data = save_user_data base_persistence.update_bot_data = save_bot_data + base_persistence.update_roles = save_roles dp.process_update(u) assert dp.user_data[54321][1] == 'test7' assert dp.chat_data[-987654][2] == 'test8' assert dp.bot_data['test0'] == 'test0' + assert dp.roles['test0'].equals(Role(name='test0', chat_ids=[1, 2, 3], + parent_roles=dp.roles.ADMINS)) def test_persistence_dispatcher_arbitrary_update_types(self, dp, base_persistence, caplog): # Updates used with TypeHandler doesn't necessarily have the proper attributes for @@ -288,6 +335,18 @@ def pickle_persistence(): store_user_data=True, store_chat_data=True, store_bot_data=True, + store_roles=True, + single_file=False, + on_flush=False) + + +@pytest.fixture(scope='function') +def pickle_persistence_only_roles(): + return PicklePersistence(filename='pickletest', + store_user_data=False, + store_chat_data=False, + store_bot_data=False, + store_roles=True, single_file=False, on_flush=False) @@ -298,6 +357,7 @@ def pickle_persistence_only_bot(): store_user_data=False, store_chat_data=False, store_bot_data=True, + store_roles=False, single_file=False, on_flush=False) @@ -308,6 +368,7 @@ def pickle_persistence_only_chat(): store_user_data=False, store_chat_data=True, store_bot_data=False, + store_roles=False, single_file=False, on_flush=False) @@ -318,6 +379,7 @@ def pickle_persistence_only_user(): store_user_data=True, store_chat_data=False, store_bot_data=False, + store_roles=False, single_file=False, on_flush=False) @@ -325,22 +387,27 @@ def pickle_persistence_only_user(): @pytest.fixture(scope='function') def bad_pickle_files(): for name in ['pickletest_user_data', 'pickletest_chat_data', 'pickletest_bot_data', - 'pickletest_conversations', 'pickletest']: + 'pickletest_conversations', 'pickletest_roles', 'pickletest']: with open(name, 'w') as f: f.write('(())') yield True + for name in ['pickletest_user_data', 'pickletest_chat_data', 'pickletest_bot_data', + 'pickletest_conversations', 'pickletest_roles', 'pickletest']: + os.remove(name) @pytest.fixture(scope='function') -def good_pickle_files(user_data, chat_data, bot_data, conversations): +def good_pickle_files(user_data, chat_data, bot_data, conversations, roles): data = {'user_data': user_data, 'chat_data': chat_data, - 'bot_data': bot_data, 'conversations': conversations} + 'bot_data': bot_data, 'conversations': conversations, 'roles': roles.encode_to_json()} with open('pickletest_user_data', 'wb') as f: pickle.dump(user_data, f) with open('pickletest_chat_data', 'wb') as f: pickle.dump(chat_data, f) with open('pickletest_bot_data', 'wb') as f: pickle.dump(bot_data, f) + with open('pickletest_roles', 'wb') as f: + pickle.dump(roles.encode_to_json(), f) with open('pickletest_conversations', 'wb') as f: pickle.dump(conversations, f) with open('pickletest', 'wb') as f: @@ -378,14 +445,22 @@ def test_no_files_present_multi_file(self, pickle_persistence): assert pickle_persistence.get_chat_data() == defaultdict(dict) assert pickle_persistence.get_bot_data() == {} assert pickle_persistence.get_bot_data() == {} + assert pickle_persistence.get_roles() == Roles(None) + assert pickle_persistence.get_roles() == Roles(None) assert pickle_persistence.get_conversations('noname') == {} assert pickle_persistence.get_conversations('noname') == {} def test_no_files_present_single_file(self, pickle_persistence): pickle_persistence.single_file = True assert pickle_persistence.get_user_data() == defaultdict(dict) + assert pickle_persistence.get_user_data() == defaultdict(dict) assert pickle_persistence.get_chat_data() == defaultdict(dict) - assert pickle_persistence.get_chat_data() == {} + assert pickle_persistence.get_chat_data() == defaultdict(dict) + assert pickle_persistence.get_bot_data() == {} + assert pickle_persistence.get_bot_data() == {} + assert pickle_persistence.get_roles() == Roles(None) + assert pickle_persistence.get_roles() == Roles(None) + assert pickle_persistence.get_conversations('noname') == {} assert pickle_persistence.get_conversations('noname') == {} def test_with_bad_multi_file(self, pickle_persistence, bad_pickle_files): @@ -395,6 +470,8 @@ def test_with_bad_multi_file(self, pickle_persistence, bad_pickle_files): pickle_persistence.get_chat_data() with pytest.raises(TypeError, match='pickletest_bot_data'): pickle_persistence.get_bot_data() + with pytest.raises(TypeError, match='pickletest_roles'): + pickle_persistence.get_roles() with pytest.raises(TypeError, match='pickletest_conversations'): pickle_persistence.get_conversations('name') @@ -406,6 +483,8 @@ def test_with_bad_single_file(self, pickle_persistence, bad_pickle_files): pickle_persistence.get_chat_data() with pytest.raises(TypeError, match='pickletest'): pickle_persistence.get_bot_data() + with pytest.raises(TypeError, match='pickletest'): + pickle_persistence.get_roles() with pytest.raises(TypeError, match='pickletest'): pickle_persistence.get_conversations('name') @@ -428,6 +507,15 @@ def test_with_good_multi_file(self, pickle_persistence, good_pickle_files): assert bot_data['test3']['test4'] == 'test5' assert 'test0' not in bot_data + roles = pickle_persistence.get_roles() + assert isinstance(roles, Roles) + pr = Role(name='parent_role', chat_ids=456) + r = Role(name='role', chat_ids=123, parent_roles=pr) + assert roles.ADMINS.equals(Role(name='admins', chat_ids=12345, child_roles=[r, pr])) + assert roles['parent_role'].equals(pr) + assert roles['role'].equals(r) + assert not roles.get('test', None) + conversation1 = pickle_persistence.get_conversations('name1') assert isinstance(conversation1, dict) assert conversation1[(123, 123)] == 3 @@ -461,6 +549,15 @@ def test_with_good_single_file(self, pickle_persistence, good_pickle_files): assert bot_data['test3']['test4'] == 'test5' assert 'test0' not in bot_data + roles = pickle_persistence.get_roles() + assert isinstance(roles, Roles) + pr = Role(name='parent_role', chat_ids=456) + r = Role(name='role', chat_ids=123, parent_roles=pr) + assert roles.ADMINS.equals(Role(name='admins', chat_ids=12345, child_roles=[r, pr])) + assert roles['parent_role'].equals(pr) + assert roles['role'].equals(r) + assert not roles.get('test', None) + conversation1 = pickle_persistence.get_conversations('name1') assert isinstance(conversation1, dict) assert conversation1[(123, 123)] == 3 @@ -550,6 +647,16 @@ def test_updating_multi_file(self, pickle_persistence, good_pickle_files): bot_data_test = pickle.load(f) assert bot_data_test == bot_data + roles = pickle_persistence.get_roles() + roles.add_role(name='new_role', chat_ids=10) + assert not pickle_persistence.roles == roles + pickle_persistence.update_roles(roles) + assert pickle_persistence.roles == roles + with open('pickletest_roles', 'rb') as f: + roles_test = pickle.load(f) + roles_test = Roles.decode_from_json(roles_test, None) + assert roles_test == roles + conversation1 = pickle_persistence.get_conversations('name1') conversation1[(123, 123)] = 5 assert not pickle_persistence.conversations['name1'] == conversation1 @@ -589,6 +696,16 @@ def test_updating_single_file(self, pickle_persistence, good_pickle_files): bot_data_test = pickle.load(f)['bot_data'] assert bot_data_test == bot_data + roles = pickle_persistence.get_roles() + roles.add_role(name='new_role', chat_ids=10) + assert not pickle_persistence.roles == roles + pickle_persistence.update_roles(roles) + assert pickle_persistence.roles == roles + with open('pickletest', 'rb') as f: + roles_test = pickle.load(f)['roles'] + roles_test = Roles.decode_from_json(roles_test, None) + assert roles_test == roles + conversation1 = pickle_persistence.get_conversations('name1') conversation1[(123, 123)] = 5 assert not pickle_persistence.conversations['name1'] == conversation1 @@ -636,6 +753,18 @@ def test_save_on_flush_multi_files(self, pickle_persistence, good_pickle_files): bot_data_test = pickle.load(f) assert not bot_data_test == bot_data + roles = pickle_persistence.get_roles() + roles.add_role(name='new_role', chat_ids=10) + assert not pickle_persistence.roles == roles + + pickle_persistence.update_roles(roles) + assert pickle_persistence.roles == roles + + with open('pickletest_roles', 'rb') as f: + roles_test = pickle.load(f) + roles_test = Roles.decode_from_json(roles_test, None) + assert not roles_test == bot_data + conversation1 = pickle_persistence.get_conversations('name1') conversation1[(123, 123)] = 5 assert not pickle_persistence.conversations['name1'] == conversation1 @@ -660,6 +789,11 @@ def test_save_on_flush_multi_files(self, pickle_persistence, good_pickle_files): bot_data_test = pickle.load(f) assert bot_data_test == bot_data + with open('pickletest_roles', 'rb') as f: + roles_test = pickle.load(f) + roles_test = Roles.decode_from_json(roles_test, None) + assert roles_test == roles + with open('pickletest_conversations', 'rb') as f: conversations_test = defaultdict(dict, pickle.load(f)) assert conversations_test['name1'] == conversation1 @@ -698,6 +832,18 @@ def test_save_on_flush_single_files(self, pickle_persistence, good_pickle_files) bot_data_test = pickle.load(f)['bot_data'] assert not bot_data_test == bot_data + roles = pickle_persistence.get_roles() + roles.add_role(name='new_role', chat_ids=10) + assert not pickle_persistence.roles == roles + + pickle_persistence.update_roles(roles) + assert pickle_persistence.roles == roles + + with open('pickletest', 'rb') as f: + roles_test = pickle.load(f)['roles'] + roles_test = Roles.decode_from_json(roles_test, None) + assert not roles_test == bot_data + conversation1 = pickle_persistence.get_conversations('name1') conversation1[(123, 123)] = 5 assert not pickle_persistence.conversations['name1'] == conversation1 @@ -720,11 +866,17 @@ def test_save_on_flush_single_files(self, pickle_persistence, good_pickle_files) bot_data_test = pickle.load(f)['bot_data'] assert bot_data_test == bot_data + with open('pickletest', 'rb') as f: + roles_test = pickle.load(f)['roles'] + roles_test = Roles.decode_from_json(roles_test, None) + assert roles_test == roles + with open('pickletest', 'rb') as f: conversations_test = defaultdict(dict, pickle.load(f)['conversations']) assert conversations_test['name1'] == conversation1 - def test_with_handler(self, bot, update, bot_data, pickle_persistence, good_pickle_files): + def test_with_handler(self, bot, update, bot_data, roles, pickle_persistence, + good_pickle_files): u = Updater(bot=bot, persistence=pickle_persistence, use_context=True) dp = u.dispatcher @@ -735,9 +887,12 @@ def first(update, context): pytest.fail() if not context.bot_data == bot_data: pytest.fail() + if not context.roles == roles: + pytest.fail() context.user_data['test1'] = 'test2' context.chat_data['test3'] = 'test4' context.bot_data['test1'] = 'test0' + context.roles.add_role(name='test2', chat_ids=[4, 5]) def second(update, context): if not context.user_data['test1'] == 'test2': @@ -746,6 +901,8 @@ def second(update, context): pytest.fail() if not context.bot_data['test1'] == 'test0': pytest.fail() + if not context.roles['test2'].user_ids == set([4, 5]): + pytest.fail() h1 = MessageHandler(None, first, pass_user_data=True, pass_chat_data=True) h2 = MessageHandler(None, second, pass_user_data=True, pass_chat_data=True) @@ -758,6 +915,7 @@ def second(update, context): store_user_data=True, store_chat_data=True, store_bot_data=True, + store_roles=True, single_file=False, on_flush=False) u = Updater(bot=bot, persistence=pickle_persistence_2) @@ -771,7 +929,8 @@ def test_flush_on_stop(self, bot, update, pickle_persistence): u.running = True dp.user_data[4242424242]['my_test'] = 'Working!' dp.chat_data[-4242424242]['my_test2'] = 'Working2!' - dp.bot_data['test'] = 'Working3!' + dp.bot_data['my_test3'] = 'Working3!' + dp.roles.add_role(name='Working4!', chat_ids=[4, 5]) u.signal_handler(signal.SIGINT, None) del (dp) del (u) @@ -779,11 +938,38 @@ def test_flush_on_stop(self, bot, update, pickle_persistence): pickle_persistence_2 = PicklePersistence(filename='pickletest', store_user_data=True, store_chat_data=True, + store_bot_data=True, + store_roles=True, single_file=False, on_flush=False) assert pickle_persistence_2.get_user_data()[4242424242]['my_test'] == 'Working!' assert pickle_persistence_2.get_chat_data()[-4242424242]['my_test2'] == 'Working2!' - assert pickle_persistence_2.get_bot_data()['test'] == 'Working3!' + assert pickle_persistence_2.get_bot_data()['my_test3'] == 'Working3!' + assert pickle_persistence_2.get_roles()['Working4!'].chat_ids == set([4, 5]) + + def test_flush_on_stop_only_roles(self, bot, update, pickle_persistence_only_roles): + u = Updater(bot=bot, persistence=pickle_persistence_only_roles) + dp = u.dispatcher + u.running = True + dp.user_data[4242424242]['my_test'] = 'Working!' + dp.chat_data[-4242424242]['my_test2'] = 'Working2!' + dp.bot_data['test'] = 'Working3!' + dp.roles.add_role(name='Working5!', chat_ids=[4, 5]) + u.signal_handler(signal.SIGINT, None) + del (dp) + del (u) + del (pickle_persistence_only_roles) + pickle_persistence_2 = PicklePersistence(filename='pickletest', + store_user_data=False, + store_chat_data=False, + store_bot_data=False, + store_roles=True, + single_file=False, + on_flush=False) + assert pickle_persistence_2.get_user_data() == {} + assert pickle_persistence_2.get_chat_data() == {} + assert pickle_persistence_2.get_bot_data() == {} + assert pickle_persistence_2.get_roles()['Working5!'].chat_ids == set([4, 5]) def test_flush_on_stop_only_bot(self, bot, update, pickle_persistence_only_bot): u = Updater(bot=bot, persistence=pickle_persistence_only_bot) @@ -792,6 +978,7 @@ def test_flush_on_stop_only_bot(self, bot, update, pickle_persistence_only_bot): dp.user_data[4242424242]['my_test'] = 'Working!' dp.chat_data[-4242424242]['my_test2'] = 'Working2!' dp.bot_data['my_test3'] = 'Working3!' + dp.roles.add_role(name='Working4!', chat_ids=[4, 5]) u.signal_handler(signal.SIGINT, None) del (dp) del (u) @@ -800,11 +987,13 @@ def test_flush_on_stop_only_bot(self, bot, update, pickle_persistence_only_bot): store_user_data=False, store_chat_data=False, store_bot_data=True, + store_roles=False, single_file=False, on_flush=False) assert pickle_persistence_2.get_user_data() == {} assert pickle_persistence_2.get_chat_data() == {} assert pickle_persistence_2.get_bot_data()['my_test3'] == 'Working3!' + assert pickle_persistence_2.get_roles() == Roles(None) def test_flush_on_stop_only_chat(self, bot, update, pickle_persistence_only_chat): u = Updater(bot=bot, persistence=pickle_persistence_only_chat) @@ -812,6 +1001,8 @@ def test_flush_on_stop_only_chat(self, bot, update, pickle_persistence_only_chat u.running = True dp.user_data[4242424242]['my_test'] = 'Working!' dp.chat_data[-4242424242]['my_test2'] = 'Working2!' + dp.bot_data['my_test3'] = 'Working3!' + dp.roles.add_role(name='Working4!', chat_ids=[4, 5]) u.signal_handler(signal.SIGINT, None) del (dp) del (u) @@ -820,11 +1011,13 @@ def test_flush_on_stop_only_chat(self, bot, update, pickle_persistence_only_chat store_user_data=False, store_chat_data=True, store_bot_data=False, + store_roles=False, single_file=False, on_flush=False) assert pickle_persistence_2.get_user_data() == {} assert pickle_persistence_2.get_chat_data()[-4242424242]['my_test2'] == 'Working2!' assert pickle_persistence_2.get_bot_data() == {} + assert pickle_persistence_2.get_roles() == Roles(None) def test_flush_on_stop_only_user(self, bot, update, pickle_persistence_only_user): u = Updater(bot=bot, persistence=pickle_persistence_only_user) @@ -832,6 +1025,8 @@ def test_flush_on_stop_only_user(self, bot, update, pickle_persistence_only_user u.running = True dp.user_data[4242424242]['my_test'] = 'Working!' dp.chat_data[-4242424242]['my_test2'] = 'Working2!' + dp.bot_data['my_test3'] = 'Working3!' + dp.roles.add_role(name='Working4!', chat_ids=[4, 5]) u.signal_handler(signal.SIGINT, None) del (dp) del (u) @@ -840,11 +1035,13 @@ def test_flush_on_stop_only_user(self, bot, update, pickle_persistence_only_user store_user_data=True, store_chat_data=False, store_bot_data=False, + store_roles=False, single_file=False, on_flush=False) assert pickle_persistence_2.get_user_data()[4242424242]['my_test'] == 'Working!' assert pickle_persistence_2.get_chat_data()[-4242424242] == {} assert pickle_persistence_2.get_bot_data() == {} + assert pickle_persistence_2.get_roles() == Roles(None) def test_with_conversationHandler(self, dp, update, good_pickle_files, pickle_persistence): dp.persistence = pickle_persistence @@ -933,6 +1130,7 @@ def job_callback(context): context.bot_data['test1'] = '456' context.dispatcher.chat_data[123]['test2'] = '789' context.dispatcher.user_data[789]['test3'] = '123' + context.roles.add_role(name='test2', chat_ids=[4, 5]) cdp.persistence = pickle_persistence job_queue.set_dispatcher(cdp) @@ -945,6 +1143,8 @@ def job_callback(context): assert chat_data[123] == {'test2': '789'} user_data = pickle_persistence.get_user_data() assert user_data[789] == {'test3': '123'} + roles = pickle_persistence.get_roles() + assert roles['test2'].user_ids == set([4, 5]) @pytest.fixture(scope='function') @@ -969,18 +1169,25 @@ def conversations_json(conversations): {"[123, 321]": 1, "[890, 890]": 2}}""" +@pytest.fixture(scope='function') +def roles_json(roles): + return roles.encode_to_json() + + class TestDictPersistence(object): def test_no_json_given(self): dict_persistence = DictPersistence() assert dict_persistence.get_user_data() == defaultdict(dict) assert dict_persistence.get_chat_data() == defaultdict(dict) assert dict_persistence.get_bot_data() == {} + assert dict_persistence.get_roles() == Roles(None) assert dict_persistence.get_conversations('noname') == {} def test_bad_json_string_given(self): bad_user_data = 'thisisnojson99900()))(' bad_chat_data = 'thisisnojson99900()))(' bad_bot_data = 'thisisnojson99900()))(' + bad_roles = 'thisisnojson99900()))(' bad_conversations = 'thisisnojson99900()))(' with pytest.raises(TypeError, match='user_data'): DictPersistence(user_data_json=bad_user_data) @@ -988,6 +1195,8 @@ def test_bad_json_string_given(self): DictPersistence(chat_data_json=bad_chat_data) with pytest.raises(TypeError, match='bot_data'): DictPersistence(bot_data_json=bad_bot_data) + with pytest.raises(TypeError, match='roles'): + DictPersistence(roles_json=bad_roles) with pytest.raises(TypeError, match='conversations'): DictPersistence(conversations_json=bad_conversations) @@ -995,6 +1204,7 @@ def test_invalid_json_string_given(self, pickle_persistence, bad_pickle_files): bad_user_data = '["this", "is", "json"]' bad_chat_data = '["this", "is", "json"]' bad_bot_data = '["this", "is", "json"]' + bad_roles = '["this", "is", "json"]' bad_conversations = '["this", "is", "json"]' with pytest.raises(TypeError, match='user_data'): DictPersistence(user_data_json=bad_user_data) @@ -1002,14 +1212,17 @@ def test_invalid_json_string_given(self, pickle_persistence, bad_pickle_files): DictPersistence(chat_data_json=bad_chat_data) with pytest.raises(TypeError, match='bot_data'): DictPersistence(bot_data_json=bad_bot_data) + with pytest.raises(TypeError, match='roles'): + DictPersistence(roles_json=bad_roles) with pytest.raises(TypeError, match='conversations'): DictPersistence(conversations_json=bad_conversations) - def test_good_json_input(self, user_data_json, chat_data_json, bot_data_json, + def test_good_json_input(self, user_data_json, chat_data_json, bot_data_json, roles_json, conversations_json): dict_persistence = DictPersistence(user_data_json=user_data_json, chat_data_json=chat_data_json, bot_data_json=bot_data_json, + roles_json=roles_json, conversations_json=conversations_json) user_data = dict_persistence.get_user_data() assert isinstance(user_data, defaultdict) @@ -1029,6 +1242,15 @@ def test_good_json_input(self, user_data_json, chat_data_json, bot_data_json, assert bot_data['test3']['test4'] == 'test5' assert 'test6' not in bot_data + roles = dict_persistence.get_roles() + assert isinstance(roles, Roles) + pr = Role(name='parent_role', chat_ids=456) + r = Role(name='role', chat_ids=123, parent_roles=pr) + assert roles.ADMINS.equals(Role(name='admins', chat_ids=12345, child_roles=[r, pr])) + assert roles['parent_role'].equals(pr) + assert roles['role'].equals(r) + assert not roles.get('test', None) + conversation1 = dict_persistence.get_conversations('name1') assert isinstance(conversation1, dict) assert conversation1[(123, 123)] == 3 @@ -1043,35 +1265,41 @@ def test_good_json_input(self, user_data_json, chat_data_json, bot_data_json, conversation2[(123, 123)] def test_dict_outputs(self, user_data, user_data_json, chat_data, chat_data_json, - bot_data, bot_data_json, + bot_data, bot_data_json, roles, roles_json, conversations, conversations_json): dict_persistence = DictPersistence(user_data_json=user_data_json, chat_data_json=chat_data_json, bot_data_json=bot_data_json, + roles_json=roles_json, conversations_json=conversations_json) assert dict_persistence.user_data == user_data assert dict_persistence.chat_data == chat_data assert dict_persistence.bot_data == bot_data + assert dict_persistence.roles == roles assert dict_persistence.conversations == conversations @pytest.mark.skipif(sys.version_info < (3, 6), reason="dicts are not ordered in py<=3.5") - def test_json_outputs(self, user_data_json, chat_data_json, bot_data_json, conversations_json): + def test_json_outputs(self, user_data_json, chat_data_json, bot_data_json, roles_json, + conversations_json): dict_persistence = DictPersistence(user_data_json=user_data_json, chat_data_json=chat_data_json, bot_data_json=bot_data_json, + roles_json=roles_json, conversations_json=conversations_json) assert dict_persistence.user_data_json == user_data_json assert dict_persistence.chat_data_json == chat_data_json assert dict_persistence.bot_data_json == bot_data_json + assert dict_persistence.roles_json == roles_json assert dict_persistence.conversations_json == conversations_json @pytest.mark.skipif(sys.version_info < (3, 6), reason="dicts are not ordered in py<=3.5") def test_json_changes(self, user_data, user_data_json, chat_data, chat_data_json, - bot_data, bot_data_json, + bot_data, bot_data_json, roles, roles_json, conversations, conversations_json): dict_persistence = DictPersistence(user_data_json=user_data_json, chat_data_json=chat_data_json, bot_data_json=bot_data_json, + roles_json=roles_json, conversations_json=conversations_json) user_data_two = user_data.copy() user_data_two.update({4: {5: 6}}) @@ -1095,6 +1323,14 @@ def test_json_changes(self, user_data, user_data_json, chat_data, chat_data_json assert dict_persistence.bot_data_json != bot_data_json assert dict_persistence.bot_data_json == json.dumps(bot_data_two) + roles_two = deepcopy(roles) + roles_two.add_role(name='role_two', chat_ids=[7, 8]) + roles.add_role(name='role_two', chat_ids=[7, 8]) + dict_persistence.update_roles(roles) + assert dict_persistence.roles == roles_two + assert dict_persistence.roles_json != roles_json + assert Roles.decode_from_json(dict_persistence.roles_json, None) == roles + conversations_two = conversations.copy() conversations_two.update({'name4': {(1, 2): 3}}) dict_persistence.update_conversation('name4', (1, 2), 3) @@ -1118,6 +1354,7 @@ def first(update, context): context.user_data['test1'] = 'test2' context.chat_data[3] = 'test4' context.bot_data['test1'] = 'test2' + context.roles.add_role(name='test2', chat_ids=[4, 5]) def second(update, context): if not context.user_data['test1'] == 'test2': @@ -1126,6 +1363,8 @@ def second(update, context): pytest.fail() if not context.bot_data['test1'] == 'test2': pytest.fail() + if not context.roles['test2'].user_ids == set([4, 5]): + pytest.fail() h1 = MessageHandler(None, first, pass_user_data=True, pass_chat_data=True) h2 = MessageHandler(None, second, pass_user_data=True, pass_chat_data=True) @@ -1234,6 +1473,7 @@ def job_callback(context): context.bot_data['test1'] = '456' context.dispatcher.chat_data[123]['test2'] = '789' context.dispatcher.user_data[789]['test3'] = '123' + context.roles.add_role(name='test3', chat_ids=[4, 5]) dict_persistence = DictPersistence() cdp.persistence = dict_persistence @@ -1247,3 +1487,5 @@ def job_callback(context): assert chat_data[123] == {'test2': '789'} user_data = dict_persistence.get_user_data() assert user_data[789] == {'test3': '123'} + roles = dict_persistence.get_roles() + assert roles['test3'].user_ids == set([4, 5]) diff --git a/tests/test_pollanswerhandler.py b/tests/test_pollanswerhandler.py index d16c403fba7..72a15bbd470 100644 --- a/tests/test_pollanswerhandler.py +++ b/tests/test_pollanswerhandler.py @@ -86,6 +86,7 @@ def callback_context(self, update, context): and isinstance(context.user_data, dict) and context.chat_data is None and isinstance(context.bot_data, dict) + and isinstance(context.roles, dict) and isinstance(update.poll_answer, PollAnswer)) def test_basic(self, dp, poll_answer): @@ -97,6 +98,14 @@ def test_basic(self, dp, poll_answer): dp.process_update(poll_answer) assert self.test_flag + def test_with_role(self, poll_answer, role): + handler = PollAnswerHandler(self.callback_basic, roles=role) + print(role.chat_ids) + assert not handler.check_update(poll_answer) + + role.chat_ids = 2 + assert handler.check_update(poll_answer) + def test_pass_user_or_chat_data(self, dp, poll_answer): handler = PollAnswerHandler(self.callback_data_1, pass_user_data=True) dp.add_handler(handler) diff --git a/tests/test_pollhandler.py b/tests/test_pollhandler.py index 2c3012756a0..a2d54438fab 100644 --- a/tests/test_pollhandler.py +++ b/tests/test_pollhandler.py @@ -87,6 +87,7 @@ def callback_context(self, update, context): and context.user_data is None and context.chat_data is None and isinstance(context.bot_data, dict) + and isinstance(context.roles, dict) and isinstance(update.poll, Poll)) def test_basic(self, dp, poll): diff --git a/tests/test_precheckoutqueryhandler.py b/tests/test_precheckoutqueryhandler.py index 72dad97dc7e..c71ecc629d9 100644 --- a/tests/test_precheckoutqueryhandler.py +++ b/tests/test_precheckoutqueryhandler.py @@ -88,6 +88,7 @@ def callback_context(self, update, context): and isinstance(context.user_data, dict) and context.chat_data is None and isinstance(context.bot_data, dict) + and isinstance(context.roles, dict) and isinstance(update.pre_checkout_query, PreCheckoutQuery)) def test_basic(self, dp, pre_checkout_query): @@ -98,6 +99,18 @@ def test_basic(self, dp, pre_checkout_query): dp.process_update(pre_checkout_query) assert self.test_flag + def test_with_role(self, dp, pre_checkout_query, role): + handler = PreCheckoutQueryHandler(self.callback_basic, roles=role) + dp.add_handler(handler) + assert not handler.check_update(pre_checkout_query) + dp.process_update(pre_checkout_query) + assert not self.test_flag + + role.chat_ids = 1 + assert handler.check_update(pre_checkout_query) + dp.process_update(pre_checkout_query) + assert self.test_flag + def test_pass_user_or_chat_data(self, dp, pre_checkout_query): handler = PreCheckoutQueryHandler(self.callback_data_1, pass_user_data=True) diff --git a/tests/test_regexhandler.py b/tests/test_regexhandler.py index ae6b614e5a6..2f4449b17b5 100644 --- a/tests/test_regexhandler.py +++ b/tests/test_regexhandler.py @@ -89,6 +89,7 @@ def callback_context(self, update, context): and isinstance(context.user_data, dict) and isinstance(context.chat_data, dict) and isinstance(context.bot_data, dict) + and isinstance(context.roles, dict) and isinstance(update.message, Message)) def callback_context_pattern(self, update, context): @@ -117,6 +118,13 @@ def test_pattern(self, message): handler = RegexHandler('.*not in here.*', self.callback_basic) assert not handler.check_update(Update(0, message)) + def test_with_role(self, message, role): + handler = RegexHandler('.*est.*', self.callback_basic, roles=role) + assert not handler.check_update(Update(0, message)) + + role.chat_ids = 1 + assert handler.check_update(Update(0, message)) + def test_with_passing_group_dict(self, dp, message): handler = RegexHandler('(?P.*)est(?P.*)', self.callback_group, pass_groups=True) diff --git a/tests/test_roles.py b/tests/test_roles.py new file mode 100644 index 00000000000..1e467306394 --- /dev/null +++ b/tests/test_roles.py @@ -0,0 +1,706 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2020 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import datetime + +import pytest +import sys +import time + +from copy import deepcopy +from telegram import Message, User, InlineQuery, Update, ChatMember, Chat, TelegramError +from telegram.ext import (Role, Roles, MessageHandler, InlineQueryHandler, ChatAdminsRole, + ChatCreatorRole) + + +@pytest.fixture(scope='function') +def update(): + return Update(0, Message(0, User(0, 'Testuser', False), datetime.datetime.utcnow(), + Chat(0, 'private'))) + + +@pytest.fixture(scope='function') +def roles(bot): + return Roles(bot) + + +@pytest.fixture(scope='function') +def parent_role(): + return Role(name='parent_role') + + +@pytest.fixture(scope='function') +def role(): + return Role(name='role') + + +@pytest.fixture(scope='function') +def chat_admins_role(bot): + return ChatAdminsRole(bot, 0.05) + + +@pytest.fixture(scope='function') +def chat_creator_role(bot): + return ChatCreatorRole(bot) + + +class TestRole(object): + def test_creation(self, parent_role): + r = Role(parent_roles=[parent_role, parent_role]) + assert r.chat_ids == set() + assert str(r) == 'Role({})' + assert r.parent_roles == set([parent_role]) + + r = Role(child_roles=[parent_role, parent_role]) + assert r.chat_ids == set() + assert str(r) == 'Role({})' + assert r.child_roles == set([parent_role]) + + parent_role_2 = Role(name='parent_role_2') + r = Role(parent_roles=[parent_role, parent_role_2]) + assert r.chat_ids == set() + assert str(r) == 'Role({})' + assert r.parent_roles == set([parent_role, parent_role_2]) + + r = Role(1) + assert r.chat_ids == set([1]) + assert str(r) == 'Role({1})' + assert r.parent_roles == set() + + r = Role([1, 2]) + assert r.chat_ids == set([1, 2]) + assert str(r) == 'Role({1, 2})' + assert r.parent_roles == set() + + r = Role([1, 2], name='role') + assert r.chat_ids == set([1, 2]) + assert str(r) == 'Role(role)' + assert r.parent_roles == set() + + def test_chat_ids_property(self, role): + assert role.chat_ids is role.user_ids + role.chat_ids = 5 + assert role.chat_ids == set([5]) + + def test_add_member(self, role): + assert role.chat_ids == set() + role.add_member(1) + assert role.chat_ids == set([1]) + role.add_member(2) + assert role.chat_ids == set([1, 2]) + role.add_member(1) + assert role.chat_ids == set([1, 2]) + + def test_kick_member(self, role): + assert role.chat_ids == set() + role.add_member(1) + role.add_member(2) + assert role.chat_ids == set([1, 2]) + role.kick_member(1) + assert role.chat_ids == set([2]) + role.kick_member(1) + assert role.chat_ids == set([2]) + role.kick_member(2) + assert role.chat_ids == set() + + def test_add_remove_parent_role(self, role, parent_role): + assert role.parent_roles == set() + parent_role_2 = Role(chat_ids=456, name='pr2') + role.add_parent_role(parent_role) + assert role.parent_roles == set([parent_role]) + role.add_parent_role(parent_role_2) + assert role.parent_roles == set([parent_role, parent_role_2]) + + role.remove_parent_role(parent_role) + assert role.parent_roles == set([parent_role_2]) + role.remove_parent_role(parent_role_2) + assert role.parent_roles == set() + + with pytest.raises(ValueError, match='You must not add a role as its own parent!'): + role.add_parent_role(role) + + parent_role.add_parent_role(role) + with pytest.raises(ValueError, match='You must not add a child role as a parent!'): + role.add_parent_role(parent_role) + + def test_add_remove_child_role(self, role, parent_role): + assert role.child_roles == set() + parent_role_2 = Role(chat_ids=456, name='pr2') + role.add_child_role(parent_role) + assert role.child_roles == set([parent_role]) + role.add_child_role(parent_role_2) + assert role.child_roles == set([parent_role, parent_role_2]) + + role.remove_child_role(parent_role) + assert role.child_roles == set([parent_role_2]) + role.remove_child_role(parent_role_2) + assert role.child_roles == set() + + with pytest.raises(ValueError, match='You must not add a role as its own child!'): + role.add_child_role(role) + + parent_role.add_child_role(role) + with pytest.raises(ValueError, match='You must not add a parent role as a child!'): + role.add_child_role(parent_role) + + def test_equals(self, role, parent_role): + r = Role(name='test1') + r2 = Role(name='test2') + r3 = Role(name='test3', chat_ids=[1, 2]) + r4 = Role(name='test4') + assert role.equals(parent_role) + role.add_child_role(r) + assert not role.equals(parent_role) + parent_role.add_child_role(r2) + assert role.equals(parent_role) + parent_role.add_child_role(r3) + role.add_child_role(r4) + assert not role.equals(parent_role) + role.remove_child_role(r4) + parent_role.remove_child_role(r3) + + role.add_member(1) + assert not role.equals(parent_role) + parent_role.add_member(1) + assert role.equals(parent_role) + role.add_member(2) + assert not role.equals(parent_role) + parent_role.add_member(2) + assert role.equals(parent_role) + role.kick_member(2) + assert not role.equals(parent_role) + parent_role.kick_member(2) + assert role.equals(parent_role) + + r.add_member(1) + assert not role.equals(parent_role) + r2.add_member(1) + assert role.equals(parent_role) + + def test_comparison(self, role, parent_role): + assert not role <= 1 + assert not role >= 1 + + assert not role < parent_role + assert not parent_role < role + assert role <= role + assert role >= role + assert parent_role <= parent_role + assert parent_role >= parent_role + + role.add_parent_role(parent_role) + assert role < parent_role + assert role <= parent_role + assert parent_role >= role + assert parent_role > role + + role.remove_parent_role(parent_role) + assert not role < parent_role + assert not parent_role < role + + role.add_parent_role(parent_role) + assert role < parent_role + assert role <= parent_role + assert parent_role >= role + assert parent_role > role + + def test_hash(self, role, parent_role): + assert role != parent_role + assert hash(role) != hash(parent_role) + + assert role == role + assert hash(role) == hash(role) + + assert parent_role == parent_role + assert hash(parent_role) == hash(parent_role) + + def test_deepcopy(self, role, parent_role): + role.add_parent_role(parent_role) + child = Role(name='cr', chat_ids=[1, 2, 3], parent_roles=role) + crole = deepcopy(role) + + assert role is not crole + assert role.equals(crole) + assert role.chat_ids is not crole.chat_ids + assert role.chat_ids == crole.chat_ids + assert role.parent_roles is not crole.parent_roles + parent = role.parent_roles.pop() + cparent = crole.parent_roles.pop() + assert parent is not cparent + assert parent.equals(cparent) + cchild = crole.child_roles.pop() + assert child is not cchild + assert child.equals(cchild) + + def test_handler_user(self, update, role, parent_role): + handler = MessageHandler(role, None) + assert not handler.check_update(update) + + role.add_member(0) + parent_role.add_member(1) + + assert handler.check_update(update) + update.message.from_user.id = 1 + update.message.chat.id = 1 + assert not handler.check_update(update) + role.add_parent_role(parent_role) + assert handler.check_update(update) + + def test_handler_chat(self, update, role, parent_role): + handler = MessageHandler(role, None) + update.message.chat.id = 5 + assert not handler.check_update(update) + + role.add_member(5) + parent_role.add_member(6) + + assert handler.check_update(update) + update.message.chat.id = 6 + assert not handler.check_update(update) + role.add_parent_role(parent_role) + assert handler.check_update(update) + + def test_handler_merged_roles(self, update, role): + role.add_member(0) + r = Role(0) + + handler = MessageHandler(None, None, roles=role & (~r)) + assert not handler.check_update(update) + + r = Role(1) + handler = MessageHandler(None, None, roles=role & r) + assert not handler.check_update(update) + handler = MessageHandler(None, None, roles=role | r) + assert handler.check_update(update) + + def test_handler_allow_parent(self, update, role, parent_role): + role.add_member(0) + parent_role.add_member(1) + role.add_parent_role(parent_role) + + handler = MessageHandler(None, None, roles=~role) + assert not handler.check_update(update) + update.message.from_user.id = 1 + update.message.chat.id = 1 + assert handler.check_update(update) + + def test_handler_exclude_children(self, update, role, parent_role): + role.add_parent_role(parent_role) + parent_role.add_member(0) + role.add_member(1) + + handler = MessageHandler(None, None, roles=~parent_role) + assert not handler.check_update(update) + update.message.from_user.id = 1 + update.message.chat.id = 1 + assert not handler.check_update(update) + update.message.from_user.id = 2 + update.message.chat.id = 1 + assert not handler.check_update(update) + + def test_handler_without_user(self, update, role): + handler = MessageHandler(role, None) + role.add_member(0) + update.message = None + update.channel_post = Message(0, None, datetime.datetime.utcnow(), Chat(-1, 'channel')) + assert not handler.check_update(update) + + role.add_member(-1) + assert handler.check_update(update) + + +class TestChatAdminsRole(object): + def test_creation(self, bot): + admins = ChatAdminsRole(bot, timeout=7) + assert admins.timeout == 7 + assert admins.bot is bot + + def test_deepcopy(self, chat_admins_role): + chat_admins_role.cache.update({1: 2, 3: 4}) + cadmins = deepcopy(chat_admins_role) + + assert chat_admins_role is not cadmins + assert chat_admins_role.equals(cadmins) + assert chat_admins_role.chat_ids is not cadmins.chat_ids + assert chat_admins_role.chat_ids == cadmins.chat_ids + assert chat_admins_role.parent_roles is not cadmins.parent_roles + assert chat_admins_role.child_roles is not cadmins.child_roles + + assert chat_admins_role.bot is cadmins.bot + assert chat_admins_role.cache == cadmins.cache + assert chat_admins_role.timeout == cadmins.timeout + + def test_simple(self, chat_admins_role, update, monkeypatch): + def admins(*args, **kwargs): + return [ChatMember(User(0, 'TestUser0', False), 'administrator'), + ChatMember(User(1, 'TestUser1', False), 'creator')] + + monkeypatch.setattr(chat_admins_role.bot, 'get_chat_administrators', admins) + handler = MessageHandler(None, None, roles=chat_admins_role) + + update.message.from_user.id = 2 + assert not handler.check_update(update) + update.message.from_user.id = 1 + assert handler.check_update(update) + update.message.from_user.id = 0 + assert handler.check_update(update) + + def test_private_chat(self, chat_admins_role, update): + update.message.from_user.id = 2 + update.message.chat.id = 2 + handler = MessageHandler(None, None, roles=chat_admins_role) + + assert handler.check_update(update) + + def test_no_chat(self, chat_admins_role, update): + update.message = None + update.inline_query = InlineQuery(1, User(0, 'TestUser', False), 'query', 0) + handler = InlineQueryHandler(None, roles=chat_admins_role) + + assert not handler.check_update(update) + + def test_no_user(self, chat_admins_role, update): + update.message = None + update.channel_post = Message(1, None, datetime.datetime.utcnow(), Chat(0, 'channel')) + handler = InlineQueryHandler(None, roles=chat_admins_role) + + assert not handler.check_update(update) + + def test_caching(self, chat_admins_role, update, monkeypatch): + def admins(*args, **kwargs): + return [ChatMember(User(0, 'TestUser0', False), 'administrator'), + ChatMember(User(1, 'TestUser1', False), 'creator')] + + monkeypatch.setattr(chat_admins_role.bot, 'get_chat_administrators', admins) + handler = MessageHandler(None, None, roles=chat_admins_role) + + update.message.from_user.id = 2 + assert not handler.check_update(update) + assert isinstance(chat_admins_role.cache[0], tuple) + assert pytest.approx(chat_admins_role.cache[0][0]) == time.time() + assert chat_admins_role.cache[0][1] == [0, 1] + + def admins(*args, **kwargs): + raise ValueError('This method should not be called!') + + monkeypatch.setattr(chat_admins_role.bot, 'get_chat_administrators', admins) + + update.message.from_user.id = 1 + assert handler.check_update(update) + + time.sleep(0.05) + + def admins(*args, **kwargs): + return [ChatMember(User(2, 'TestUser0', False), 'administrator')] + + monkeypatch.setattr(chat_admins_role.bot, 'get_chat_administrators', admins) + + update.message.from_user.id = 2 + assert handler.check_update(update) + assert isinstance(chat_admins_role.cache[0], tuple) + assert pytest.approx(chat_admins_role.cache[0][0]) == time.time() + assert chat_admins_role.cache[0][1] == [2] + + +class TestChatCreatorRole(object): + def test_creation(self, bot): + creator = ChatCreatorRole(bot) + assert creator.bot is bot + + def test_deepcopy(self, chat_creator_role): + chat_creator_role.cache.update({1: 2, 3: 4}) + ccreator = deepcopy(chat_creator_role) + + assert chat_creator_role is not ccreator + assert chat_creator_role.equals(ccreator) + assert chat_creator_role.chat_ids is not ccreator.chat_ids + assert chat_creator_role.chat_ids == ccreator.chat_ids + assert chat_creator_role.parent_roles is not ccreator.parent_roles + assert chat_creator_role.child_roles is not ccreator.child_roles + + assert chat_creator_role.bot is ccreator.bot + assert chat_creator_role.cache == ccreator.cache + + def test_simple(self, chat_creator_role, monkeypatch, update): + def member(*args, **kwargs): + if args[1] == 0: + return ChatMember(User(0, 'TestUser0', False), 'administrator') + if args[1] == 1: + return ChatMember(User(1, 'TestUser1', False), 'creator') + raise TelegramError('User is not a member') + + monkeypatch.setattr(chat_creator_role.bot, 'get_chat_member', member) + handler = MessageHandler(None, None, roles=chat_creator_role) + + update.message.from_user.id = 0 + update.message.chat.id = -1 + assert not handler.check_update(update) + update.message.from_user.id = 1 + update.message.chat.id = 1 + assert handler.check_update(update) + update.message.from_user.id = 2 + update.message.chat.id = -2 + assert not handler.check_update(update) + + def test_no_chat(self, chat_creator_role, update): + update.message = None + update.inline_query = InlineQuery(1, User(0, 'TestUser', False), 'query', 0) + handler = InlineQueryHandler(None, roles=chat_creator_role) + + assert not handler.check_update(update) + + def test_no_user(self, chat_creator_role, update): + update.message = None + update.channel_post = Message(1, None, datetime.datetime.utcnow(), Chat(0, 'channel')) + handler = InlineQueryHandler(None, roles=chat_creator_role) + + assert not handler.check_update(update) + + def test_private_chat(self, chat_creator_role, update): + update.message.from_user.id = 2 + update.message.chat.id = 2 + handler = MessageHandler(None, None, roles=chat_creator_role) + + assert handler.check_update(update) + + def test_caching(self, chat_creator_role, monkeypatch, update): + def member(*args, **kwargs): + if args[1] == 0: + return ChatMember(User(0, 'TestUser0', False), 'administrator') + if args[1] == 1: + return ChatMember(User(1, 'TestUser1', False), 'creator') + raise TelegramError('User is not a member') + + monkeypatch.setattr(chat_creator_role.bot, 'get_chat_member', member) + handler = MessageHandler(None, None, roles=chat_creator_role) + + update.message.from_user.id = 1 + assert handler.check_update(update) + assert chat_creator_role.cache == {0: 1} + + def member(*args, **kwargs): + raise ValueError('This method should not be called!') + + monkeypatch.setattr(chat_creator_role.bot, 'get_chat_member', member) + + update.message.from_user.id = 1 + assert handler.check_update(update) + + update.message.from_user.id = 2 + assert not handler.check_update(update) + + +class TestRoles(object): + def test_creation(self, bot): + roles = Roles(bot) + assert isinstance(roles, dict) + assert isinstance(roles.ADMINS, Role) + assert isinstance(roles.CHAT_ADMINS, Role) + assert roles.CHAT_ADMINS.bot is bot + assert isinstance(roles.CHAT_CREATOR, Role) + assert roles.CHAT_CREATOR.bot is bot + assert roles.bot is bot + + def test_set_bot(self, bot): + roles = Roles(1) + assert roles.bot == 1 + roles.set_bot(2) + assert roles.bot == 2 + roles.set_bot(bot) + assert roles.bot is bot + with pytest.raises(ValueError, match='already set'): + roles.set_bot(bot) + + def test_add_kick_admin(self, roles): + assert roles.ADMINS.chat_ids == set() + roles.add_admin(1) + assert roles.ADMINS.chat_ids == set([1]) + roles.add_admin(2) + assert roles.ADMINS.chat_ids == set([1, 2]) + roles.kick_admin(1) + assert roles.ADMINS.chat_ids == set([2]) + roles.kick_admin(2) + assert roles.ADMINS.chat_ids == set() + + def test_equality(self, parent_role, roles, bot): + parent_role_2 = deepcopy(parent_role) + child_role = Role(name='child_role') + child_role_2 = deepcopy(child_role) + + roles2 = Roles(bot) + assert roles == roles2 + + roles.add_admin(1) + assert roles != roles2 + + roles2.add_admin(1) + assert roles == roles2 + + roles.add_role('test_role', chat_ids=123) + assert roles != roles2 + + roles2.add_role('test_role', chat_ids=123) + assert roles == roles2 + + roles.add_role('test_role_2', chat_ids=456) + assert roles != roles2 + + roles2.add_role('test_role_2', chat_ids=456) + assert roles == roles2 + + roles['test_role'].add_parent_role(parent_role) + roles2['test_role'].add_parent_role(parent_role_2) + assert roles == roles2 + + roles['test_role'].add_child_role(child_role) + assert roles != roles2 + + roles2['test_role'].add_child_role(child_role_2) + assert roles == roles2 + + def test_raise_errors(self, roles): + with pytest.raises(NotImplementedError, match='remove_role'): + del roles['test'] + with pytest.raises(ValueError, match='immutable'): + roles['test'] = True + with pytest.raises(ValueError, match='immutable'): + roles.setdefault('test', None) + with pytest.raises(ValueError, match='immutable'): + roles.update({'test': None}) + with pytest.raises(NotImplementedError, match='remove_role'): + roles.pop('test', None) + with pytest.raises(NotImplementedError, match='remove_role'): + roles.popitem('test') + with pytest.raises(NotImplementedError, match='remove_role'): + roles.clear() + with pytest.raises(NotImplementedError): + roles.copy() + + @pytest.mark.skipif(sys.version_info < (3, 6), reason="dicts are not ordered in py<=3.5") + def test_dict_functionality(self, roles): + roles.add_role('role0', 0) + roles.add_role('role1', 1) + roles.add_role('role2', 2) + + assert 'role2' in roles + assert 'role3' not in roles + + a = set([name for name in roles]) + assert a == set(['role{}'.format(k) for k in range(3)]) + + b = {name: role.chat_ids for name, role in roles.items()} + assert b == {'role{}'.format(k): set([k]) for k in range(3)} + + c = [name for name in roles.keys()] + assert c == ['role{}'.format(k) for k in range(3)] + + d = [r.chat_ids for r in roles.values()] + assert d == [set([k]) for k in range(3)] + + def test_deepcopy(self, roles, parent_role): + roles.add_admin(123) + roles.CHAT_ADMINS.timeout = 7 + child_role = Role(name='child_role') + roles.add_role(name='test', chat_ids=[1, 2], parent_roles=parent_role, + child_roles=child_role) + roles.add_role(name='test2', chat_ids=[3, 4], child_roles=roles['test']) + croles = deepcopy(roles) + + assert croles is not roles + assert croles == roles + assert roles.ADMINS is not croles.ADMINS + assert roles.ADMINS.equals(croles.ADMINS) + assert roles.CHAT_ADMINS.timeout == croles.CHAT_ADMINS.timeout + assert roles['test'] is not croles['test'] + assert roles['test'].equals(croles['test']) + assert roles['test2'] is not croles['test2'] + assert roles['test2'].equals(croles['test2']) + + def test_add_remove_role(self, roles, parent_role): + roles.add_role('role', parent_roles=parent_role) + role = roles['role'] + assert role.chat_ids == set() + assert role.parent_roles == set([parent_role, roles.ADMINS]) + assert str(role) == 'Role(role)' + assert roles.ADMINS in role.parent_roles + + with pytest.raises(ValueError, match='Role name is already taken.'): + roles.add_role('role', parent_roles=parent_role) + + roles.remove_role('role') + assert not roles.get('role', None) + assert roles.ADMINS not in role.parent_roles + + def test_handler_admins(self, roles, update): + roles.add_role('role', 0) + roles.add_admin(1) + handler = MessageHandler(None, None, roles=roles['role']) + assert handler.check_update(update) + update.message.from_user.id = 1 + update.message.chat.id = 1 + assert handler.check_update(update) + roles.kick_admin(1) + assert not handler.check_update(update) + + def test_handler_admins_merged(self, roles, update): + roles.add_role('role_1', 0) + roles.add_role('role_2', 1) + roles.add_admin(2) + handler = MessageHandler(None, None, roles=roles['role_1'] & ~roles['role_2']) + assert handler.check_update(update) + update.message.from_user.id = 2 + update.message.chat.id = 2 + assert handler.check_update(update) + roles.kick_admin(2) + assert not handler.check_update(update) + + def test_json_encoding_decoding(self, roles, parent_role, bot): + child_role = Role(name='child_role') + roles.add_role('role_1', chat_ids=[1, 2, 3]) + roles.add_role('role_2', chat_ids=[4, 5, 6], parent_roles=parent_role, + child_roles=child_role) + roles.add_role('role_3', chat_ids=[7, 8], parent_roles=parent_role, child_roles=child_role) + roles.add_admin(9) + roles.add_admin(10) + roles.CHAT_ADMINS.timeout = 7 + + json_str = roles.encode_to_json() + assert isinstance(json_str, str) + assert json_str != '' + + rroles = Roles.decode_from_json(json_str, bot) + assert rroles == roles + assert rroles.bot is bot + for name in rroles: + assert rroles[name] <= rroles.ADMINS + assert rroles.ADMINS.chat_ids == set([9, 10]) + assert rroles.ADMINS.equals(roles.ADMINS) + assert rroles.CHAT_ADMINS.timeout == roles.CHAT_ADMINS.timeout + assert rroles['role_1'].chat_ids == set([1, 2, 3]) + assert rroles['role_1'].equals(Role(name='role_1', chat_ids=[1, 2, 3])) + assert rroles['role_2'].chat_ids == set([4, 5, 6]) + assert rroles['role_2'].equals(Role(name='role_2', chat_ids=[4, 5, 6], + parent_roles=parent_role, child_roles=child_role)) + assert rroles['role_3'].chat_ids == set([7, 8]) + assert rroles['role_3'].equals(Role(name='role_3', chat_ids=[7, 8], + parent_roles=parent_role, child_roles=child_role)) + for name in rroles: + assert rroles[name] <= rroles.ADMINS + assert rroles[name] < rroles.ADMINS + assert rroles.ADMINS >= rroles[name] + assert rroles.ADMINS > rroles[name] diff --git a/tests/test_shippingqueryhandler.py b/tests/test_shippingqueryhandler.py index 65870c76a85..af8b3031006 100644 --- a/tests/test_shippingqueryhandler.py +++ b/tests/test_shippingqueryhandler.py @@ -89,6 +89,7 @@ def callback_context(self, update, context): and isinstance(context.user_data, dict) and context.chat_data is None and isinstance(context.bot_data, dict) + and isinstance(context.roles, dict) and isinstance(update.shipping_query, ShippingQuery)) def test_basic(self, dp, shiping_query): @@ -99,6 +100,19 @@ def test_basic(self, dp, shiping_query): dp.process_update(shiping_query) assert self.test_flag + def test_with_role(self, dp, shiping_query, role): + handler = ShippingQueryHandler(self.callback_basic, roles=role) + dp.add_handler(handler) + + assert not handler.check_update(shiping_query) + dp.process_update(shiping_query) + assert not self.test_flag + + role.chat_ids = 1 + assert handler.check_update(shiping_query) + dp.process_update(shiping_query) + assert self.test_flag + def test_pass_user_or_chat_data(self, dp, shiping_query): handler = ShippingQueryHandler(self.callback_data_1, pass_user_data=True)