diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index 43d2fe3eccc..f10b66f8772 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -20,8 +20,9 @@ import re -from future.utils import string_types from abc import ABC, abstractmethod +from future.utils import string_types +from threading import Lock from telegram import Chat, Update, MessageEntity @@ -892,38 +893,166 @@ class user(BaseFilter): Examples: ``MessageHandler(Filters.user(1234), callback_method)`` + Warning: + :attr:`user_ids` will give a *copy* of the saved user ids as :class:`frozenset`. This + is to ensure thread safety. To add/remove a user, you should use :meth:`add_usernames`, + :meth:`add_user_ids`, :meth:`remove_usernames` and :meth:`remove_user_ids`. Only update + the entire set by ``filter.user_ids/usernames = new_set``, if you are entirely sure + that it is not causing race conditions, as this will complete replace the current set + of allowed users. + + Attributes: + user_ids(set(:obj:`int`), optional): Which user ID(s) to allow through. + usernames(set(:obj:`str`), optional): Which username(s) (without leading '@') to allow + through. + allow_empty(:obj:`bool`, optional): Whether updates should be processed, if no user + is specified in :attr:`user_ids` and :attr:`usernames`. + Args: - user_id(:obj:`int` | List[:obj:`int`], optional): Which user ID(s) to allow through. - username(:obj:`str` | List[:obj:`str`], optional): Which username(s) to allow through. - If username starts with '@' symbol, it will be ignored. + user_id(:obj:`int` | List[:obj:`int`], optional): Which user ID(s) to allow + through. + username(:obj:`str` | List[:obj:`str`], optional): Which username(s) to allow + through. Leading '@'s in usernames will be discarded. + allow_empty(:obj:`bool`, optional): Whether updates should be processed, if no user + is specified in :attr:`user_ids` and :attr:`usernames`. Defaults to :obj:`False` Raises: - ValueError: If chat_id and username are both present, or neither is. + RuntimeError: If user_id and username are both present. """ - def __init__(self, user_id=None, username=None): - if not (bool(user_id) ^ bool(username)): - raise ValueError('One and only one of user_id or username must be used') - if user_id is not None and isinstance(user_id, int): - self.user_ids = [user_id] - else: - self.user_ids = user_id + def __init__(self, user_id=None, username=None, allow_empty=False): + self.allow_empty = allow_empty + self.__lock = Lock() + + self._user_ids = set() + self._usernames = set() + + self._set_user_ids(user_id) + self._set_usernames(username) + + @staticmethod + def _parse_user_id(user_id): + if user_id is None: + return set() + if isinstance(user_id, int): + return {user_id} + return set(user_id) + + @staticmethod + def _parse_username(username): if username is None: - self.usernames = username - elif isinstance(username, string_types): - self.usernames = [username.replace('@', '')] - else: - self.usernames = [user.replace('@', '') for user in username] + return set() + if isinstance(username, str): + return {username[1:] if username.startswith('@') else username} + return {user[1:] if user.startswith('@') else user for user in username} + + def _set_user_ids(self, user_id): + with self.__lock: + if user_id and self._usernames: + raise RuntimeError("Can't set user_id in conjunction with (already set) " + "usernames.") + self._user_ids = self._parse_user_id(user_id) + + def _set_usernames(self, username): + with self.__lock: + if username and self._user_ids: + raise RuntimeError("Can't set username in conjunction with (already set) " + "user_ids.") + self._usernames = self._parse_username(username) + + @property + def user_ids(self): + with self.__lock: + return frozenset(self._user_ids) + + @user_ids.setter + def user_ids(self, user_id): + self._set_user_ids(user_id) + + @property + def usernames(self): + with self.__lock: + return frozenset(self._usernames) + + @usernames.setter + def usernames(self, username): + self._set_usernames(username) + + def add_usernames(self, username): + """ + Add one or more users to the allowed usernames. + + Args: + username(:obj:`str` | List[:obj:`str`], optional): Which username(s) to allow + through. Leading '@'s in usernames will be discarded. + """ + with self.__lock: + if self._user_ids: + raise RuntimeError("Can't set username in conjunction with (already set) " + "user_ids.") + + username = self._parse_username(username) + self._usernames |= username + + def add_user_ids(self, user_id): + """ + Add one or more users to the allowed user ids. + + Args: + user_id(:obj:`int` | List[:obj:`int`], optional): Which user ID(s) to allow + through. + """ + with self.__lock: + if self._usernames: + raise RuntimeError("Can't set user_id in conjunction with (already set) " + "usernames.") + + user_id = self._parse_user_id(user_id) + + self._user_ids |= user_id + + def remove_usernames(self, username): + """ + Remove one or more users from allowed usernames. + + Args: + username(:obj:`str` | List[:obj:`str`], optional): Which username(s) to disallow + through. Leading '@'s in usernames will be discarded. + """ + with self.__lock: + if self._user_ids: + raise RuntimeError("Can't set username in conjunction with (already set) " + "user_ids.") + + username = self._parse_username(username) + self._usernames -= username + + def remove_user_ids(self, user_id): + """ + Remove one or more users from allowed user ids. + + Args: + user_id(:obj:`int` | List[:obj:`int`], optional): Which user ID(s) to disallow + through. + """ + with self.__lock: + if self._usernames: + raise RuntimeError("Can't set user_id in conjunction with (already set) " + "usernames.") + user_id = self._parse_user_id(user_id) + self._user_ids -= user_id def filter(self, message): """""" # remove method from docs - if self.user_ids is not None: - return bool(message.from_user and message.from_user.id in self.user_ids) - else: - # self.usernames is not None - return bool(message.from_user and message.from_user.username + if message.from_user: + if self.user_ids: + return message.from_user.id in self.user_ids + if self.usernames: + return (message.from_user.username and message.from_user.username in self.usernames) + return self.allow_empty + return False class chat(BaseFilter): """Filters messages to allow only those which are from specified chat ID. @@ -931,37 +1060,166 @@ class chat(BaseFilter): Examples: ``MessageHandler(Filters.chat(-1234), callback_method)`` + Warning: + :attr:`chat_ids` will give a *copy* of the saved chat ids as :class:`frozenset`. This + is to ensure thread safety. To add/remove a chat, you should use :meth:`add_usernames`, + :meth:`add_chat_ids`, :meth:`remove_usernames` and :meth:`remove_chat_ids`. Only update + the entire set by ``filter.chat_ids/usernames = new_set``, if you are entirely sure + that it is not causing race conditions, as this will complete replace the current set + of allowed chats. + + Attributes: + chat_ids(set(:obj:`int`), optional): Which chat ID(s) to allow through. + usernames(set(:obj:`str`), optional): Which username(s) (without leading '@') to allow + through. + allow_empty(:obj:`bool`, optional): Whether updates should be processed, if no chat + is specified in :attr:`chat_ids` and :attr:`usernames`. + Args: - chat_id(:obj:`int` | List[:obj:`int`], optional): Which chat ID(s) to allow through. - username(:obj:`str` | List[:obj:`str`], optional): Which username(s) to allow through. - If username start swith '@' symbol, it will be ignored. + chat_id(:obj:`int` | List[:obj:`int`], optional): Which chat ID(s) to allow + through. + username(:obj:`str` | List[:obj:`str`], optional): Which username(s) to allow + through. Leading '@'s in usernames will be discarded. + allow_empty(:obj:`bool`, optional): Whether updates should be processed, if no chat + is specified in :attr:`chat_ids` and :attr:`usernames`. Defaults to :obj:`False` Raises: - ValueError: If chat_id and username are both present, or neither is. + RuntimeError: If chat_id and username are both present. """ - def __init__(self, chat_id=None, username=None): - if not (bool(chat_id) ^ bool(username)): - raise ValueError('One and only one of chat_id or username must be used') - if chat_id is not None and isinstance(chat_id, int): - self.chat_ids = [chat_id] - else: - self.chat_ids = chat_id + def __init__(self, chat_id=None, username=None, allow_empty=False): + self.allow_empty = allow_empty + self.__lock = Lock() + + self._chat_ids = set() + self._usernames = set() + + self._set_chat_ids(chat_id) + self._set_usernames(username) + + @staticmethod + def _parse_chat_id(chat_id): + if chat_id is None: + return set() + if isinstance(chat_id, int): + return {chat_id} + return set(chat_id) + + @staticmethod + def _parse_username(username): if username is None: - self.usernames = username - elif isinstance(username, string_types): - self.usernames = [username.replace('@', '')] - else: - self.usernames = [chat.replace('@', '') for chat in username] + return set() + if isinstance(username, str): + return {username[1:] if username.startswith('@') else username} + return {chat[1:] if chat.startswith('@') else chat for chat in username} + + def _set_chat_ids(self, chat_id): + with self.__lock: + if chat_id and self._usernames: + raise RuntimeError("Can't set chat_id in conjunction with (already set) " + "usernames.") + self._chat_ids = self._parse_chat_id(chat_id) + + def _set_usernames(self, username): + with self.__lock: + if username and self._chat_ids: + raise RuntimeError("Can't set username in conjunction with (already set) " + "chat_ids.") + self._usernames = self._parse_username(username) + + @property + def chat_ids(self): + with self.__lock: + return frozenset(self._chat_ids) + + @chat_ids.setter + def chat_ids(self, chat_id): + self._set_chat_ids(chat_id) + + @property + def usernames(self): + with self.__lock: + return frozenset(self._usernames) + + @usernames.setter + def usernames(self, username): + self._set_usernames(username) + + def add_usernames(self, username): + """ + Add one or more chats to the allowed usernames. + + Args: + username(:obj:`str` | List[:obj:`str`], optional): Which username(s) to allow + through. Leading '@'s in usernames will be discarded. + """ + with self.__lock: + if self._chat_ids: + raise RuntimeError("Can't set username in conjunction with (already set) " + "chat_ids.") + + username = self._parse_username(username) + self._usernames |= username + + def add_chat_ids(self, chat_id): + """ + Add one or more chats to the allowed chat ids. + + Args: + chat_id(:obj:`int` | List[:obj:`int`], optional): Which chat ID(s) to allow + through. + """ + with self.__lock: + if self._usernames: + raise RuntimeError("Can't set chat_id in conjunction with (already set) " + "usernames.") + + chat_id = self._parse_chat_id(chat_id) + + self._chat_ids |= chat_id + + def remove_usernames(self, username): + """ + Remove one or more chats from allowed usernames. + + Args: + username(:obj:`str` | List[:obj:`str`], optional): Which username(s) to disallow + through. Leading '@'s in usernames will be discarded. + """ + with self.__lock: + if self._chat_ids: + raise RuntimeError("Can't set username in conjunction with (already set) " + "chat_ids.") + + username = self._parse_username(username) + self._usernames -= username + + def remove_chat_ids(self, chat_id): + """ + Remove one or more chats from allowed chat ids. + + Args: + chat_id(:obj:`int` | List[:obj:`int`], optional): Which chat ID(s) to disallow + through. + """ + with self.__lock: + if self._usernames: + raise RuntimeError("Can't set chat_id in conjunction with (already set) " + "usernames.") + chat_id = self._parse_chat_id(chat_id) + self._chat_ids -= chat_id def filter(self, message): """""" # remove method from docs - if self.chat_ids is not None: - return bool(message.chat_id in self.chat_ids) - else: - # self.usernames is not None - return bool(message.chat.username and message.chat.username in self.usernames) + if message.chat: + if self.chat_ids: + return message.chat.id in self.chat_ids + if self.usernames: + return (message.chat.username + and message.chat.username in self.usernames) + return self.allow_empty + return False class _Invoice(BaseFilter): name = 'Filters.invoice' diff --git a/tests/test_filters.py b/tests/test_filters.py index 2fa1adad80e..6d8917d8727 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -557,11 +557,13 @@ def test_group_filter(self, update): update.message.chat.type = 'supergroup' assert Filters.group(update) - def test_filters_user(self): - with pytest.raises(ValueError, match='user_id or username'): + def test_filters_user_init(self): + with pytest.raises(RuntimeError, match='in conjunction with'): Filters.user(user_id=1, username='user') - with pytest.raises(ValueError, match='user_id or username'): - Filters.user() + + def test_filters_user_allow_empty(self, update): + assert not Filters.user()(update) + assert Filters.user(allow_empty=True)(update) def test_filters_user_id(self, update): assert not Filters.user(user_id=1)(update) @@ -570,37 +572,240 @@ def test_filters_user_id(self, update): update.message.from_user.id = 2 assert Filters.user(user_id=[1, 2])(update) assert not Filters.user(user_id=[3, 4])(update) + update.message.from_user = None + assert not Filters.user(user_id=[3, 4])(update) def test_filters_username(self, update): assert not Filters.user(username='user')(update) assert not Filters.user(username='Testuser')(update) - update.message.from_user.username = 'user' - assert Filters.user(username='@user')(update) - assert Filters.user(username='user')(update) - assert Filters.user(username=['user1', 'user', 'user2'])(update) + update.message.from_user.username = 'user@' + assert Filters.user(username='@user@')(update) + assert Filters.user(username='user@')(update) + assert Filters.user(username=['user1', 'user@', 'user2'])(update) assert not Filters.user(username=['@username', '@user_2'])(update) + update.message.from_user = None + assert not Filters.user(username=['@username', '@user_2'])(update) + + def test_filters_user_change_id(self, update): + f = Filters.user(user_id=1) + update.message.from_user.id = 1 + assert f(update) + update.message.from_user.id = 2 + assert not f(update) + f.user_ids = 2 + assert f(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.usernames = 'user' + + def test_filters_user_change_username(self, update): + f = Filters.user(username='user') + update.message.from_user.username = 'user' + assert f(update) + update.message.from_user.username = 'User' + assert not f(update) + f.usernames = 'User' + assert f(update) + + with pytest.raises(RuntimeError, match='user_id in conjunction'): + f.user_ids = 1 + + def test_filters_user_add_user_by_name(self, update): + users = ['user_a', 'user_b', 'user_c'] + f = Filters.user() + + for user in users: + update.message.from_user.username = user + assert not f(update) + + f.add_usernames('user_a') + f.add_usernames(['user_b', 'user_c']) + + for user in users: + update.message.from_user.username = user + assert f(update) + + with pytest.raises(RuntimeError, match='user_id in conjunction'): + f.add_user_ids(1) + + def test_filters_user_add_user_by_id(self, update): + users = [1, 2, 3] + f = Filters.user() + + for user in users: + update.message.from_user.id = user + assert not f(update) + + f.add_user_ids(1) + f.add_user_ids([2, 3]) + + for user in users: + update.message.from_user.username = user + assert f(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.add_usernames('user') + + def test_filters_user_remove_user_by_name(self, update): + users = ['user_a', 'user_b', 'user_c'] + f = Filters.user(username=users) + + with pytest.raises(RuntimeError, match='user_id in conjunction'): + f.remove_user_ids(1) + + for user in users: + update.message.from_user.username = user + assert f(update) + + f.remove_usernames('user_a') + f.remove_usernames(['user_b', 'user_c']) + + for user in users: + update.message.from_user.username = user + assert not f(update) + + def test_filters_user_remove_user_by_id(self, update): + users = [1, 2, 3] + f = Filters.user(user_id=users) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.remove_usernames('user') - def test_filters_chat(self): - with pytest.raises(ValueError, match='chat_id or username'): - Filters.chat(chat_id=-1, username='chat') - with pytest.raises(ValueError, match='chat_id or username'): - Filters.chat() + for user in users: + update.message.from_user.id = user + assert f(update) + + f.remove_user_ids(1) + f.remove_user_ids([2, 3]) + + for user in users: + update.message.from_user.username = user + assert not f(update) + + def test_filters_chat_init(self): + with pytest.raises(RuntimeError, match='in conjunction with'): + Filters.chat(chat_id=1, username='chat') + + def test_filters_chat_allow_empty(self, update): + assert not Filters.chat()(update) + assert Filters.chat(allow_empty=True)(update) def test_filters_chat_id(self, update): - assert not Filters.chat(chat_id=-1)(update) - update.message.chat.id = -1 - assert Filters.chat(chat_id=-1)(update) - update.message.chat.id = -2 - assert Filters.chat(chat_id=[-1, -2])(update) - assert not Filters.chat(chat_id=[-3, -4])(update) + assert not Filters.chat(chat_id=1)(update) + update.message.chat.id = 1 + assert Filters.chat(chat_id=1)(update) + update.message.chat.id = 2 + assert Filters.chat(chat_id=[1, 2])(update) + assert not Filters.chat(chat_id=[3, 4])(update) + update.message.chat = None + assert not Filters.chat(chat_id=[3, 4])(update) def test_filters_chat_username(self, update): assert not Filters.chat(username='chat')(update) + assert not Filters.chat(username='Testchat')(update) + update.message.chat.username = 'chat@' + assert Filters.chat(username='@chat@')(update) + assert Filters.chat(username='chat@')(update) + assert Filters.chat(username=['chat1', 'chat@', 'chat2'])(update) + assert not Filters.chat(username=['@username', '@chat_2'])(update) + update.message.chat = None + assert not Filters.chat(username=['@username', '@chat_2'])(update) + + def test_filters_chat_change_id(self, update): + f = Filters.chat(chat_id=1) + update.message.chat.id = 1 + assert f(update) + update.message.chat.id = 2 + assert not f(update) + f.chat_ids = 2 + assert f(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.usernames = 'chat' + + def test_filters_chat_change_username(self, update): + f = Filters.chat(username='chat') update.message.chat.username = 'chat' - assert Filters.chat(username='@chat')(update) - assert Filters.chat(username='chat')(update) - assert Filters.chat(username=['chat1', 'chat', 'chat2'])(update) - assert not Filters.chat(username=['@chat1', 'chat_2'])(update) + assert f(update) + update.message.chat.username = 'User' + assert not f(update) + f.usernames = 'User' + assert f(update) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.chat_ids = 1 + + def test_filters_chat_add_chat_by_name(self, update): + chats = ['chat_a', 'chat_b', 'chat_c'] + f = Filters.chat() + + for chat in chats: + update.message.chat.username = chat + assert not f(update) + + f.add_usernames('chat_a') + f.add_usernames(['chat_b', 'chat_c']) + + for chat in chats: + update.message.chat.username = chat + assert f(update) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.add_chat_ids(1) + + def test_filters_chat_add_chat_by_id(self, update): + chats = [1, 2, 3] + f = Filters.chat() + + for chat in chats: + update.message.chat.id = chat + assert not f(update) + + f.add_chat_ids(1) + f.add_chat_ids([2, 3]) + + for chat in chats: + update.message.chat.username = chat + assert f(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.add_usernames('chat') + + def test_filters_chat_remove_chat_by_name(self, update): + chats = ['chat_a', 'chat_b', 'chat_c'] + f = Filters.chat(username=chats) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.remove_chat_ids(1) + + for chat in chats: + update.message.chat.username = chat + assert f(update) + + f.remove_usernames('chat_a') + f.remove_usernames(['chat_b', 'chat_c']) + + for chat in chats: + update.message.chat.username = chat + assert not f(update) + + def test_filters_chat_remove_chat_by_id(self, update): + chats = [1, 2, 3] + f = Filters.chat(chat_id=chats) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.remove_usernames('chat') + + for chat in chats: + update.message.chat.id = chat + assert f(update) + + f.remove_chat_ids(1) + f.remove_chat_ids([2, 3]) + + for chat in chats: + update.message.chat.username = chat + assert not f(update) def test_filters_invoice(self, update): assert not Filters.invoice(update)