From 79fcf7758899d4e443b3b6368b58d5753a282c5d Mon Sep 17 00:00:00 2001 From: Hinrich Mahler Date: Mon, 27 Jul 2020 21:44:32 +0200 Subject: [PATCH 1/2] Refactor handling of message vs update filters --- telegram/ext/__init__.py | 14 +-- telegram/ext/filters.py | 206 ++++++++++++++++++++++----------------- telegram/files/venue.py | 2 +- tests/conftest.py | 4 +- tests/test_filters.py | 12 +-- 5 files changed, 132 insertions(+), 106 deletions(-) diff --git a/telegram/ext/__init__.py b/telegram/ext/__init__.py index e77b5567334..a39b067e9b1 100644 --- a/telegram/ext/__init__.py +++ b/telegram/ext/__init__.py @@ -29,7 +29,7 @@ from .callbackqueryhandler import CallbackQueryHandler from .choseninlineresulthandler import ChosenInlineResultHandler from .inlinequeryhandler import InlineQueryHandler -from .filters import BaseFilter, Filters +from .filters import BaseFilter, MessageFilter, UpdateFilter, Filters from .messagehandler import MessageHandler from .commandhandler import CommandHandler, PrefixHandler from .regexhandler import RegexHandler @@ -47,9 +47,9 @@ __all__ = ('Dispatcher', 'JobQueue', 'Job', 'Updater', 'CallbackQueryHandler', 'ChosenInlineResultHandler', 'CommandHandler', 'Handler', 'InlineQueryHandler', - 'MessageHandler', 'BaseFilter', 'Filters', 'RegexHandler', 'StringCommandHandler', - 'StringRegexHandler', 'TypeHandler', 'ConversationHandler', - 'PreCheckoutQueryHandler', 'ShippingQueryHandler', 'MessageQueue', 'DelayQueue', - 'DispatcherHandlerStop', 'run_async', 'CallbackContext', 'BasePersistence', - 'PicklePersistence', 'DictPersistence', 'PrefixHandler', 'PollAnswerHandler', - 'PollHandler', 'Defaults') + 'MessageHandler', 'BaseFilter', 'MessageFilter', 'UpdateFilter', 'Filters', + 'RegexHandler', 'StringCommandHandler', 'StringRegexHandler', 'TypeHandler', + 'ConversationHandler', 'PreCheckoutQueryHandler', 'ShippingQueryHandler', + 'MessageQueue', 'DelayQueue', 'DispatcherHandlerStop', 'run_async', 'CallbackContext', + 'BasePersistence', 'PicklePersistence', 'DictPersistence', 'PrefixHandler', + 'PollAnswerHandler', 'PollHandler', 'Defaults') diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index de1d85771d9..44051bd22b3 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -25,11 +25,12 @@ from telegram import Chat, Update, MessageEntity -__all__ = ['Filters', 'BaseFilter', 'InvertedFilter', 'MergedFilter'] +__all__ = ['Filters', 'BaseFilter', 'MessageFilter', 'UpdateFilter', 'InvertedFilter', + 'MergedFilter'] class BaseFilter(ABC): - """Base class for all Message Filters. + """Base class for all Filters. Subclassing from this class filters to be combined using bitwise operators: @@ -56,14 +57,15 @@ class BaseFilter(ABC): >>> Filters.regex(r'(a?x)') | Filters.regex(r'(b?x)') - With a message.text of `x`, will only ever return the matches for the first filter, + With ``message.text == x``, will only ever return the matches for the first filter, since the second one is never evaluated. - If you want to create your own filters create a class inheriting from this class and implement - a `filter` method that returns a boolean: `True` if the message should be handled, `False` - otherwise. Note that the filters work only as class instances, not actual class objects - (so remember to initialize your filter classes). + If you want to create your own filters create a class inheriting from either + :class:`MessageFilter` or :class:`UpdateFilter` and implement a ``filter`` method that + returns a boolean: :obj:`True` if the message should be handled, :obj:`False` otherwise. + Note that the filters work only as class instances, not actual class objects (so remember to + initialize your filter classes). By default the filters name (what will get printed when converted to a string for display) will be the class name. If you want to overwrite this assign a better name to the `name` @@ -71,8 +73,6 @@ class variable. Attributes: name (:obj:`str`): Name for this filter. Defaults to the type of filter. - update_filter (:obj:`bool`): Whether this filter should work on update. If ``False`` it - will run the filter on :attr:`update.effective_message``. Default is ``False``. data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should return a dict with lists. The dict will be merged with :class:`telegram.ext.CallbackContext`'s internal dict in most cases @@ -80,14 +80,11 @@ class variable. """ name = None - update_filter = False data_filter = False + @abstractmethod def __call__(self, update): - if self.update_filter: - return self.filter(update) - else: - return self.filter(update.effective_message) + ... def __and__(self, other): return MergedFilter(self, and_filter=other) @@ -104,13 +101,54 @@ def __repr__(self): self.name = self.__class__.__name__ return self.name + +class MessageFilter(BaseFilter, ABC): + """Base class for all Message Filters. In contrast to :class:`UpdateFilter`, the object passed + to :meth:`filter` is ``update.effective_message``. + + Attributes: + name (:obj:`str`): Name for this filter. Defaults to the type of filter. + data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should + return a dict with lists. The dict will be merged with + :class:`telegram.ext.CallbackContext`'s internal dict in most cases + (depends on the handler). + + """ + def __call__(self, update): + return self.filter(update.effective_message) + @abstractmethod - def filter(self, update): + def filter(self, message): """This method must be overwritten. - Note: - If :attr:`update_filter` is false then the first argument is `message` and of - type :class:`telegram.Message`. + Args: + message (:class:`telegram.Message`): The message that is tested. + + Returns: + :obj:`dict` or :obj:`bool` + + """ + + +class UpdateFilter(BaseFilter, ABC): + """Base class for all Update Filters. In contrast to :class:`UpdateFilter`, the object + passed to :meth:`filter` is ``update``, which allows to create filters like + :attr:`Filters.update.edited_message`. + + Attributes: + name (:obj:`str`): Name for this filter. Defaults to the type of filter. + data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should + return a dict with lists. The dict will be merged with + :class:`telegram.ext.CallbackContext`'s internal dict in most cases + (depends on the handler). + + """ + def __call__(self, update): + return self.filter(update) + + @abstractmethod + def filter(self, update): + """This method must be overwritten. Args: update (:class:`telegram.Update`): The update that is tested. @@ -121,15 +159,13 @@ def filter(self, update): """ -class InvertedFilter(BaseFilter): +class InvertedFilter(UpdateFilter): """Represents a filter that has been inverted. Args: f: The filter to invert. """ - update_filter = True - def __init__(self, f): self.f = f @@ -140,7 +176,7 @@ def __repr__(self): return "".format(self.f) -class MergedFilter(BaseFilter): +class MergedFilter(UpdateFilter): """Represents a filter consisting of two other filters. Args: @@ -149,8 +185,6 @@ class MergedFilter(BaseFilter): or_filter: Optional filter to "or" with base_filter. Mutually exclusive with and_filter. """ - update_filter = True - def __init__(self, base_filter, and_filter=None, or_filter=None): self.base_filter = base_filter if self.base_filter.data_filter: @@ -215,13 +249,13 @@ def __repr__(self): self.and_filter or self.or_filter) -class _DiceEmoji(BaseFilter): +class _DiceEmoji(MessageFilter): def __init__(self, emoji=None, name=None): self.name = 'Filters.dice.{}'.format(name) if name else 'Filters.dice' self.emoji = emoji - class _DiceValues(BaseFilter): + class _DiceValues(MessageFilter): def __init__(self, values, name, emoji=None): self.values = [values] if isinstance(values, int) else values @@ -248,7 +282,8 @@ def filter(self, message): class Filters: - """Predefined filters for use as the `filter` argument of :class:`telegram.ext.MessageHandler`. + """Predefined filters for use as the ``filter`` argument of + :class:`telegram.ext.MessageHandler`. Examples: Use ``MessageHandler(Filters.video, callback_method)`` to filter all video @@ -256,7 +291,7 @@ class Filters: """ - class _All(BaseFilter): + class _All(MessageFilter): name = 'Filters.all' def filter(self, message): @@ -265,10 +300,10 @@ def filter(self, message): all = _All() """All Messages.""" - class _Text(BaseFilter): + class _Text(MessageFilter): name = 'Filters.text' - class _TextStrings(BaseFilter): + class _TextStrings(MessageFilter): def __init__(self, strings): self.strings = strings @@ -316,10 +351,10 @@ def filter(self, message): exact matches are allowed. If not specified, will allow any text message. """ - class _Caption(BaseFilter): + class _Caption(MessageFilter): name = 'Filters.caption' - class _CaptionStrings(BaseFilter): + class _CaptionStrings(MessageFilter): def __init__(self, strings): self.strings = strings @@ -351,10 +386,10 @@ def filter(self, message): exact matches are allowed. If not specified, will allow any message with a caption. """ - class _Command(BaseFilter): + class _Command(MessageFilter): name = 'Filters.command' - class _CommandOnlyStart(BaseFilter): + class _CommandOnlyStart(MessageFilter): def __init__(self, only_start): self.only_start = only_start @@ -393,7 +428,7 @@ def filter(self, message): command. Defaults to ``True``. """ - class regex(BaseFilter): + class regex(MessageFilter): """ Filters updates by searching for an occurrence of ``pattern`` in the message text. The ``re.search`` function is used to determine whether an update should be filtered. @@ -438,7 +473,7 @@ def filter(self, message): return {'matches': [match]} return {} - class _Reply(BaseFilter): + class _Reply(MessageFilter): name = 'Filters.reply' def filter(self, message): @@ -447,7 +482,7 @@ def filter(self, message): reply = _Reply() """Messages that are a reply to another message.""" - class _Audio(BaseFilter): + class _Audio(MessageFilter): name = 'Filters.audio' def filter(self, message): @@ -456,10 +491,10 @@ def filter(self, message): audio = _Audio() """Messages that contain :class:`telegram.Audio`.""" - class _Document(BaseFilter): + class _Document(MessageFilter): name = 'Filters.document' - class category(BaseFilter): + class category(MessageFilter): """This Filter filters documents by their category in the mime-type attribute Note: @@ -469,7 +504,7 @@ class category(BaseFilter): send media with wrong types that don't fit to this handler. Example: - Filters.documents.category('audio/') returns `True` for all types + Filters.documents.category('audio/') returns :obj:`True` for all types of audio sent as file, for example 'audio/mpeg' or 'audio/x-wav' """ @@ -492,7 +527,7 @@ def filter(self, message): video = category('video/') text = category('text/') - class mime_type(BaseFilter): + class mime_type(MessageFilter): """This Filter filters documents by their mime-type attribute Note: @@ -592,7 +627,7 @@ def filter(self, message): zip: Same as ``Filters.document.mime_type("application/zip")``- """ - class _Animation(BaseFilter): + class _Animation(MessageFilter): name = 'Filters.animation' def filter(self, message): @@ -601,7 +636,7 @@ def filter(self, message): animation = _Animation() """Messages that contain :class:`telegram.Animation`.""" - class _Photo(BaseFilter): + class _Photo(MessageFilter): name = 'Filters.photo' def filter(self, message): @@ -610,7 +645,7 @@ def filter(self, message): photo = _Photo() """Messages that contain :class:`telegram.PhotoSize`.""" - class _Sticker(BaseFilter): + class _Sticker(MessageFilter): name = 'Filters.sticker' def filter(self, message): @@ -619,7 +654,7 @@ def filter(self, message): sticker = _Sticker() """Messages that contain :class:`telegram.Sticker`.""" - class _Video(BaseFilter): + class _Video(MessageFilter): name = 'Filters.video' def filter(self, message): @@ -628,7 +663,7 @@ def filter(self, message): video = _Video() """Messages that contain :class:`telegram.Video`.""" - class _Voice(BaseFilter): + class _Voice(MessageFilter): name = 'Filters.voice' def filter(self, message): @@ -637,7 +672,7 @@ def filter(self, message): voice = _Voice() """Messages that contain :class:`telegram.Voice`.""" - class _VideoNote(BaseFilter): + class _VideoNote(MessageFilter): name = 'Filters.video_note' def filter(self, message): @@ -646,7 +681,7 @@ def filter(self, message): video_note = _VideoNote() """Messages that contain :class:`telegram.VideoNote`.""" - class _Contact(BaseFilter): + class _Contact(MessageFilter): name = 'Filters.contact' def filter(self, message): @@ -655,7 +690,7 @@ def filter(self, message): contact = _Contact() """Messages that contain :class:`telegram.Contact`.""" - class _Location(BaseFilter): + class _Location(MessageFilter): name = 'Filters.location' def filter(self, message): @@ -664,7 +699,7 @@ def filter(self, message): location = _Location() """Messages that contain :class:`telegram.Location`.""" - class _Venue(BaseFilter): + class _Venue(MessageFilter): name = 'Filters.venue' def filter(self, message): @@ -673,7 +708,7 @@ def filter(self, message): venue = _Venue() """Messages that contain :class:`telegram.Venue`.""" - class _StatusUpdate(BaseFilter): + class _StatusUpdate(UpdateFilter): """Subset for messages containing a status update. Examples: @@ -681,9 +716,7 @@ class _StatusUpdate(BaseFilter): ``Filters.status_update`` for all status update messages. """ - update_filter = True - - class _NewChatMembers(BaseFilter): + class _NewChatMembers(MessageFilter): name = 'Filters.status_update.new_chat_members' def filter(self, message): @@ -692,7 +725,7 @@ def filter(self, message): new_chat_members = _NewChatMembers() """Messages that contain :attr:`telegram.Message.new_chat_members`.""" - class _LeftChatMember(BaseFilter): + class _LeftChatMember(MessageFilter): name = 'Filters.status_update.left_chat_member' def filter(self, message): @@ -701,7 +734,7 @@ def filter(self, message): left_chat_member = _LeftChatMember() """Messages that contain :attr:`telegram.Message.left_chat_member`.""" - class _NewChatTitle(BaseFilter): + class _NewChatTitle(MessageFilter): name = 'Filters.status_update.new_chat_title' def filter(self, message): @@ -710,7 +743,7 @@ def filter(self, message): new_chat_title = _NewChatTitle() """Messages that contain :attr:`telegram.Message.new_chat_title`.""" - class _NewChatPhoto(BaseFilter): + class _NewChatPhoto(MessageFilter): name = 'Filters.status_update.new_chat_photo' def filter(self, message): @@ -719,7 +752,7 @@ def filter(self, message): new_chat_photo = _NewChatPhoto() """Messages that contain :attr:`telegram.Message.new_chat_photo`.""" - class _DeleteChatPhoto(BaseFilter): + class _DeleteChatPhoto(MessageFilter): name = 'Filters.status_update.delete_chat_photo' def filter(self, message): @@ -728,7 +761,7 @@ def filter(self, message): delete_chat_photo = _DeleteChatPhoto() """Messages that contain :attr:`telegram.Message.delete_chat_photo`.""" - class _ChatCreated(BaseFilter): + class _ChatCreated(MessageFilter): name = 'Filters.status_update.chat_created' def filter(self, message): @@ -740,7 +773,7 @@ def filter(self, message): :attr: `telegram.Message.supergroup_chat_created` or :attr: `telegram.Message.channel_chat_created`.""" - class _Migrate(BaseFilter): + class _Migrate(MessageFilter): name = 'Filters.status_update.migrate' def filter(self, message): @@ -750,7 +783,7 @@ def filter(self, message): """Messages that contain :attr:`telegram.Message.migrate_from_chat_id` or :attr: `telegram.Message.migrate_to_chat_id`.""" - class _PinnedMessage(BaseFilter): + class _PinnedMessage(MessageFilter): name = 'Filters.status_update.pinned_message' def filter(self, message): @@ -759,7 +792,7 @@ def filter(self, message): pinned_message = _PinnedMessage() """Messages that contain :attr:`telegram.Message.pinned_message`.""" - class _ConnectedWebsite(BaseFilter): + class _ConnectedWebsite(MessageFilter): name = 'Filters.status_update.connected_website' def filter(self, message): @@ -806,7 +839,7 @@ def filter(self, message): :attr:`telegram.Message.pinned_message`. """ - class _Forwarded(BaseFilter): + class _Forwarded(MessageFilter): name = 'Filters.forwarded' def filter(self, message): @@ -815,7 +848,7 @@ def filter(self, message): forwarded = _Forwarded() """Messages that are forwarded.""" - class _Game(BaseFilter): + class _Game(MessageFilter): name = 'Filters.game' def filter(self, message): @@ -824,7 +857,7 @@ def filter(self, message): game = _Game() """Messages that contain :class:`telegram.Game`.""" - class entity(BaseFilter): + class entity(MessageFilter): """ Filters messages to only allow those which have a :class:`telegram.MessageEntity` where their `type` matches `entity_type`. @@ -846,7 +879,7 @@ def filter(self, message): """""" # remove method from docs return any(entity.type == self.entity_type for entity in message.entities) - class caption_entity(BaseFilter): + class caption_entity(MessageFilter): """ Filters media messages to only allow those which have a :class:`telegram.MessageEntity` where their `type` matches `entity_type`. @@ -868,7 +901,7 @@ def filter(self, message): """""" # remove method from docs return any(entity.type == self.entity_type for entity in message.caption_entities) - class _Private(BaseFilter): + class _Private(MessageFilter): name = 'Filters.private' def filter(self, message): @@ -877,7 +910,7 @@ def filter(self, message): private = _Private() """Messages sent in a private chat.""" - class _Group(BaseFilter): + class _Group(MessageFilter): name = 'Filters.group' def filter(self, message): @@ -886,7 +919,7 @@ def filter(self, message): group = _Group() """Messages sent in a group chat.""" - class user(BaseFilter): + class user(MessageFilter): """Filters messages to allow only those which are from specified user ID(s) or username(s). @@ -1053,7 +1086,7 @@ def filter(self, message): return self.allow_empty return False - class via_bot(BaseFilter): + class via_bot(MessageFilter): """Filters messages to allow only those which are from specified via_bot ID(s) or username(s). @@ -1216,7 +1249,7 @@ def filter(self, message): return self.allow_empty return False - class chat(BaseFilter): + class chat(MessageFilter): """Filters messages to allow only those which are from a specified chat ID or username. Examples: @@ -1383,7 +1416,7 @@ def filter(self, message): return self.allow_empty return False - class _Invoice(BaseFilter): + class _Invoice(MessageFilter): name = 'Filters.invoice' def filter(self, message): @@ -1392,7 +1425,7 @@ def filter(self, message): invoice = _Invoice() """Messages that contain :class:`telegram.Invoice`.""" - class _SuccessfulPayment(BaseFilter): + class _SuccessfulPayment(MessageFilter): name = 'Filters.successful_payment' def filter(self, message): @@ -1401,7 +1434,7 @@ def filter(self, message): successful_payment = _SuccessfulPayment() """Messages that confirm a :class:`telegram.SuccessfulPayment`.""" - class _PassportData(BaseFilter): + class _PassportData(MessageFilter): name = 'Filters.passport_data' def filter(self, message): @@ -1410,7 +1443,7 @@ def filter(self, message): passport_data = _PassportData() """Messages that contain a :class:`telegram.PassportData`""" - class _Poll(BaseFilter): + class _Poll(MessageFilter): name = 'Filters.poll' def filter(self, message): @@ -1453,7 +1486,7 @@ class _Dice(_DiceEmoji): as for :attr:`Filters.dice`. """ - class language(BaseFilter): + class language(MessageFilter): """Filters messages to only allow those which are from users with a certain language code. Note: @@ -1482,48 +1515,42 @@ def filter(self, message): return message.from_user.language_code and any( [message.from_user.language_code.startswith(x) for x in self.lang]) - class _UpdateType(BaseFilter): - update_filter = True + class _UpdateType(UpdateFilter): name = 'Filters.update' - class _Message(BaseFilter): + class _Message(UpdateFilter): name = 'Filters.update.message' - update_filter = True def filter(self, update): return update.message is not None message = _Message() - class _EditedMessage(BaseFilter): + class _EditedMessage(UpdateFilter): name = 'Filters.update.edited_message' - update_filter = True def filter(self, update): return update.edited_message is not None edited_message = _EditedMessage() - class _Messages(BaseFilter): + class _Messages(UpdateFilter): name = 'Filters.update.messages' - update_filter = True def filter(self, update): return update.message is not None or update.edited_message is not None messages = _Messages() - class _ChannelPost(BaseFilter): + class _ChannelPost(UpdateFilter): name = 'Filters.update.channel_post' - update_filter = True def filter(self, update): return update.channel_post is not None channel_post = _ChannelPost() - class _EditedChannelPost(BaseFilter): - update_filter = True + class _EditedChannelPost(UpdateFilter): name = 'Filters.update.edited_channel_post' def filter(self, update): @@ -1531,8 +1558,7 @@ def filter(self, update): edited_channel_post = _EditedChannelPost() - class _ChannelPosts(BaseFilter): - update_filter = True + class _ChannelPosts(UpdateFilter): name = 'Filters.update.channel_posts' def filter(self, update): diff --git a/telegram/files/venue.py b/telegram/files/venue.py index a54d7978553..142a0e9bfd8 100644 --- a/telegram/files/venue.py +++ b/telegram/files/venue.py @@ -25,7 +25,7 @@ class Venue(TelegramObject): """This object represents a venue. Objects of this class are comparable in terms of equality. Two objects of this class are - considered equal, if their :attr:`location` and :attr:`title`are equal. + considered equal, if their :attr:`location` and :attr:`title` are equal. Attributes: location (:class:`telegram.Location`): Venue location. diff --git a/tests/conftest.py b/tests/conftest.py index b4ecd2dd626..b2795740dbf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ from telegram import (Bot, Message, User, Chat, MessageEntity, Update, InlineQuery, CallbackQuery, ShippingQuery, PreCheckoutQuery, ChosenInlineResult) -from telegram.ext import Dispatcher, JobQueue, Updater, BaseFilter, Defaults +from telegram.ext import Dispatcher, JobQueue, Updater, MessageFilter, Defaults from telegram.error import BadRequest from tests.bots import get_bot @@ -241,7 +241,7 @@ def make_command_update(message, edited=False, **kwargs): @pytest.fixture(scope='function') def mock_filter(): - class MockFilter(BaseFilter): + class MockFilter(MessageFilter): def __init__(self): self.tested = False diff --git a/tests/test_filters.py b/tests/test_filters.py index 03847413d4c..d53373dad72 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -21,7 +21,7 @@ import pytest from telegram import Message, User, Chat, MessageEntity, Document, Update, Dice -from telegram.ext import Filters, BaseFilter +from telegram.ext import Filters, BaseFilter, MessageFilter import re @@ -963,7 +963,7 @@ class _CustomFilter(BaseFilter): _CustomFilter() def test_custom_unnamed_filter(self, update): - class Unnamed(BaseFilter): + class Unnamed(MessageFilter): def filter(self, mes): return True @@ -1016,7 +1016,7 @@ def test_merged_short_circuit_and(self, update): class TestException(Exception): pass - class RaisingFilter(BaseFilter): + class RaisingFilter(MessageFilter): def filter(self, _): raise TestException @@ -1035,7 +1035,7 @@ def test_merged_short_circuit_or(self, update): class TestException(Exception): pass - class RaisingFilter(BaseFilter): + class RaisingFilter(MessageFilter): def filter(self, _): raise TestException @@ -1052,7 +1052,7 @@ def test_merged_data_merging_and(self, update): update.message.text = '/test' update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] - class DataFilter(BaseFilter): + class DataFilter(MessageFilter): data_filter = True def __init__(self, data): @@ -1075,7 +1075,7 @@ def filter(self, _): def test_merged_data_merging_or(self, update): update.message.text = '/test' - class DataFilter(BaseFilter): + class DataFilter(MessageFilter): data_filter = True def __init__(self, data): From b5daceb3548bb23c39d2f0a7f6bbf8285f8c7c18 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler Date: Mon, 27 Jul 2020 22:39:28 +0200 Subject: [PATCH 2/2] address review --- telegram/ext/filters.py | 8 ++++++-- tests/conftest.py | 15 ++++++++++----- tests/test_filters.py | 32 +++++++++++++++++++++----------- 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index 44051bd22b3..3172b397630 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -32,7 +32,7 @@ class BaseFilter(ABC): """Base class for all Filters. - Subclassing from this class filters to be combined using bitwise operators: + Filters subclassing from this class can combined using bitwise operators: And: @@ -84,7 +84,7 @@ class variable. @abstractmethod def __call__(self, update): - ... + pass def __and__(self, other): return MergedFilter(self, and_filter=other) @@ -106,6 +106,8 @@ class MessageFilter(BaseFilter, ABC): """Base class for all Message Filters. In contrast to :class:`UpdateFilter`, the object passed to :meth:`filter` is ``update.effective_message``. + Please see :class:`telegram.ext.BaseFilter` for details on how to create custom filters. + Attributes: name (:obj:`str`): Name for this filter. Defaults to the type of filter. data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should @@ -135,6 +137,8 @@ class UpdateFilter(BaseFilter, ABC): passed to :meth:`filter` is ``update``, which allows to create filters like :attr:`Filters.update.edited_message`. + Please see :class:`telegram.ext.BaseFilter` for details on how to create custom filters. + Attributes: name (:obj:`str`): Name for this filter. Defaults to the type of filter. data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should diff --git a/tests/conftest.py b/tests/conftest.py index b2795740dbf..d957d0d04f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ from telegram import (Bot, Message, User, Chat, MessageEntity, Update, InlineQuery, CallbackQuery, ShippingQuery, PreCheckoutQuery, ChosenInlineResult) -from telegram.ext import Dispatcher, JobQueue, Updater, MessageFilter, Defaults +from telegram.ext import Dispatcher, JobQueue, Updater, MessageFilter, Defaults, UpdateFilter from telegram.error import BadRequest from tests.bots import get_bot @@ -239,13 +239,18 @@ def make_command_update(message, edited=False, **kwargs): return make_message_update(message, make_command_message, edited, **kwargs) -@pytest.fixture(scope='function') -def mock_filter(): - class MockFilter(MessageFilter): +@pytest.fixture(scope='class', + params=[ + {'class': MessageFilter}, + {'class': UpdateFilter} + ], + ids=['MessageFilter', 'UpdateFilter']) +def mock_filter(request): + class MockFilter(request.param['class']): def __init__(self): self.tested = False - def filter(self, message): + def filter(self, _): self.tested = True return MockFilter() diff --git a/tests/test_filters.py b/tests/test_filters.py index d53373dad72..fad30709d3f 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -21,7 +21,7 @@ import pytest from telegram import Message, User, Chat, MessageEntity, Document, Update, Dice -from telegram.ext import Filters, BaseFilter, MessageFilter +from telegram.ext import Filters, BaseFilter, MessageFilter, UpdateFilter import re @@ -37,6 +37,16 @@ def message_entity(request): return MessageEntity(request.param, 0, 0, url='', user='') +@pytest.fixture(scope='class', + params=[ + {'class': MessageFilter}, + {'class': UpdateFilter} + ], + ids=['MessageFilter', 'UpdateFilter']) +def base_class(request): + return request.param['class'] + + class TestFilters: def test_filters_all(self, update): assert Filters.all(update) @@ -962,8 +972,8 @@ class _CustomFilter(BaseFilter): with pytest.raises(TypeError, match='Can\'t instantiate abstract class _CustomFilter'): _CustomFilter() - def test_custom_unnamed_filter(self, update): - class Unnamed(MessageFilter): + def test_custom_unnamed_filter(self, update, base_class): + class Unnamed(base_class): def filter(self, mes): return True @@ -1009,14 +1019,14 @@ def test_update_type_edited_channel_post(self, update): assert Filters.update.channel_posts(update) assert Filters.update(update) - def test_merged_short_circuit_and(self, update): + def test_merged_short_circuit_and(self, update, base_class): update.message.text = '/test' update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] class TestException(Exception): pass - class RaisingFilter(MessageFilter): + class RaisingFilter(base_class): def filter(self, _): raise TestException @@ -1029,13 +1039,13 @@ def filter(self, _): update.message.entities = [] (Filters.command & raising_filter)(update) - def test_merged_short_circuit_or(self, update): + def test_merged_short_circuit_or(self, update, base_class): update.message.text = 'test' class TestException(Exception): pass - class RaisingFilter(MessageFilter): + class RaisingFilter(base_class): def filter(self, _): raise TestException @@ -1048,11 +1058,11 @@ def filter(self, _): update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] (Filters.command | raising_filter)(update) - def test_merged_data_merging_and(self, update): + def test_merged_data_merging_and(self, update, base_class): update.message.text = '/test' update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] - class DataFilter(MessageFilter): + class DataFilter(base_class): data_filter = True def __init__(self, data): @@ -1072,10 +1082,10 @@ def filter(self, _): result = (Filters.command & DataFilter('blah'))(update) assert not result - def test_merged_data_merging_or(self, update): + def test_merged_data_merging_or(self, update, base_class): update.message.text = '/test' - class DataFilter(MessageFilter): + class DataFilter(base_class): data_filter = True def __init__(self, data):