diff options
author | Evgeny Zinoviev <me@ch1p.io> | 2023-05-31 09:22:00 +0300 |
---|---|---|
committer | Evgeny Zinoviev <me@ch1p.io> | 2023-06-10 02:07:23 +0300 |
commit | f29e139cbb7e4a4d539cba6e894ef4a6acd312d6 (patch) | |
tree | 6246f126325c5c36fb573134a05f2771cd747966 /src/home/telegram/bot.py | |
parent | 3e3753d726f8a02d98368f20f77dd9fa739e3d80 (diff) |
WIP: big refactoring
Diffstat (limited to 'src/home/telegram/bot.py')
-rw-r--r-- | src/home/telegram/bot.py | 149 |
1 files changed, 77 insertions, 72 deletions
diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py index 10bfe06..7e22263 100644 --- a/src/home/telegram/bot.py +++ b/src/home/telegram/bot.py @@ -5,19 +5,19 @@ import itertools from enum import Enum, auto from functools import wraps -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Coroutine from telegram import Update, ReplyKeyboardMarkup from telegram.ext import ( - Updater, - Filters, - BaseFilter, + Application, + filters, CommandHandler, MessageHandler, CallbackQueryHandler, CallbackContext, ConversationHandler ) +from telegram.ext.filters import BaseFilter from telegram.error import TimedOut from home.config import config @@ -33,26 +33,26 @@ from ._botcontext import Context db: Optional[BotDatabase] = None _user_filter: Optional[BaseFilter] = None -_cancel_filter = Filters.text(lang.all('cancel')) -_back_filter = Filters.text(lang.all('back')) -_cancel_and_back_filter = Filters.text(lang.all('back') + lang.all('cancel')) +_cancel_filter = filters.Text(lang.all('cancel')) +_back_filter = filters.Text(lang.all('back')) +_cancel_and_back_filter = filters.Text(lang.all('back') + lang.all('cancel')) _logger = logging.getLogger(__name__) -_updater: Optional[Updater] = None +_application: Optional[Application] = None _reporting: Optional[ReportingHelper] = None -_exception_handler: Optional[callable] = None +_exception_handler: Optional[Coroutine] = None _dispatcher = None _markup_getter: Optional[callable] = None -_start_handler_ref: Optional[callable] = None +_start_handler_ref: Optional[Coroutine] = None def text_filter(*args): if not _user_filter: raise RuntimeError('user_filter is not initialized') - return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter + return filters.Text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter -def _handler_of_handler(*args, **kwargs): +async def _handler_of_handler(*args, **kwargs): self = None context = None update = None @@ -99,7 +99,7 @@ def _handler_of_handler(*args, **kwargs): if self: _args.insert(0, self) - result = f(*_args, **kwargs) + result = await f(*_args, **kwargs) return result if not return_with_context else (result, ctx) except Exception as e: @@ -107,7 +107,7 @@ def _handler_of_handler(*args, **kwargs): if not _exception_handler(e, ctx) and not isinstance(e, TimedOut): _logger.exception(e) if not ctx.is_callback_context(): - ctx.reply_exc(e) + await ctx.reply_exc(e) else: notify_user(ctx.user_id, exc2text(e)) else: @@ -117,10 +117,10 @@ def _handler_of_handler(*args, **kwargs): def handler(**kwargs): def inner(f): @wraps(f) - def _handler(*args, **inner_kwargs): + async def _handler(*args, **inner_kwargs): if 'argument' in kwargs and kwargs['argument'] == 'message_key': inner_kwargs['argument'] = 'message_key' - return _handler_of_handler(f=f, *args, **inner_kwargs) + return await _handler_of_handler(f=f, *args, **inner_kwargs) messages = [] texts = [] @@ -139,43 +139,43 @@ def handler(**kwargs): new_messages = list(itertools.chain.from_iterable([lang.all(m) for m in messages])) texts += new_messages texts = list(set(texts)) - _updater.dispatcher.add_handler( + _application.add_handler( MessageHandler(text_filter(*texts), _handler), group=0 ) if 'command' in kwargs: - _updater.dispatcher.add_handler(CommandHandler(kwargs['command'], _handler), group=0) + _application.add_handler(CommandHandler(kwargs['command'], _handler), group=0) if 'callback' in kwargs: - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) + _application.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) return _handler return inner -def simplehandler(f: callable): +def simplehandler(f: Coroutine): @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) + async def _handler(*args, **kwargs): + return await _handler_of_handler(f=f, *args, **kwargs) return _handler def callbackhandler(*args, **kwargs): def inner(f): @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) + async def _handler(*args, **kwargs): + return await _handler_of_handler(f=f, *args, **kwargs) pattern_kwargs = {} if kwargs['callback'] != '*': pattern_kwargs['pattern'] = kwargs['callback'] - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) + _application.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) return _handler return inner -def exceptionhandler(f: callable): +async def exceptionhandler(f: callable): global _exception_handler if _exception_handler: _logger.warning('exception handler already set, we will overwrite it') @@ -198,10 +198,10 @@ def convinput(state, is_enter=False, **kwargs): ) @wraps(f) - def _impl(*args, **kwargs): - result, ctx = _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) + async def _impl(*args, **kwargs): + result, ctx = await _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) if result == conversation.END: - start(ctx) + await start(ctx) return result return _impl @@ -252,7 +252,7 @@ class conversation: handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), self.make_invoker(target_state))) if 'regex' in kwargs: - handlers.append(MessageHandler(Filters.regex(kwargs['regex']) & _user_filter, f)) + handlers.append(MessageHandler(filters.Regex(kwargs['regex']) & _user_filter, f)) if 'command' in kwargs: handlers.append(CommandHandler(kwargs['command'], f, _user_filter)) @@ -327,21 +327,21 @@ class conversation: @staticmethod @simplehandler - def invalid(ctx: Context): - ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) + async def invalid(ctx: Context): + await ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) # return 0 # FIXME is this needed @simplehandler - def cancel(self, ctx: Context): - start(ctx) + async def cancel(self, ctx: Context): + await start(ctx) self.set_user_state(ctx.user_id, None) return conversation.END @simplehandler - def back(self, ctx: Context): + async def back(self, ctx: Context): cur_state = self.get_user_state(ctx.user_id) if cur_state is None: - start(ctx) + await start(ctx) self.set_user_state(ctx.user_id, None) return conversation.END @@ -411,7 +411,7 @@ class LangConversation(conversation): START, = range(1) @conventer(START, command='lang') - def entry(self, ctx: Context): + async def entry(self, ctx: Context): self._logger.debug(f'current language: {ctx.user_lang}') buttons = [] @@ -419,11 +419,11 @@ class LangConversation(conversation): buttons.append(name) markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False) - ctx.reply(ctx.lang('select_language'), markup=markup) + await ctx.reply(ctx.lang('select_language'), markup=markup) return self.START @convinput(START, messages=lang.languages) - def input(self, ctx: Context): + async def input(self, ctx: Context): selected_lang = None for key, value in languages.items(): if value == ctx.text: @@ -434,30 +434,34 @@ class LangConversation(conversation): raise ValueError('could not find the language') db.set_user_lang(ctx.user_id, selected_lang) - ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) + await ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) return self.END def initialize(): global _user_filter - global _updater + global _application + # global _updater global _dispatcher # init user_filter - if 'users' in config['bot']: - _logger.info('allowed users: ' + str(config['bot']['users'])) - _user_filter = Filters.user(config['bot']['users']) + _user_ids = config.app_config.get_user_ids() + if len(_user_ids) > 0: + _logger.info('allowed users: ' + str(_user_ids)) + _user_filter = filters.User(_user_ids) else: - _user_filter = Filters.all # not sure if this is correct + _user_filter = filters.ALL # not sure if this is correct - # init updater - _updater = Updater(config['bot']['token'], - request_kwargs={'read_timeout': 6, 'connect_timeout': 7}) + _application = Application.builder()\ + .token(config.app_config.get('bot.token'))\ + .connect_timeout(7)\ + .read_timeout(6)\ + .build() # transparently log all messages - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, _logging_message_handler), group=10) - _updater.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) + # _application.dispatcher.add_handler(MessageHandler(filters.ALL & _user_filter, _logging_message_handler), group=10) + # _application.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) def run(start_handler=None, any_handler=None): @@ -473,37 +477,38 @@ def run(start_handler=None, any_handler=None): _start_handler_ref = start_handler - _updater.dispatcher.add_handler(LangConversation().get_handler(), group=0) - _updater.dispatcher.add_handler(CommandHandler('start', simplehandler(start_handler), _user_filter)) - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, any_handler)) + _application.add_handler(LangConversation().get_handler(), group=0) + _application.add_handler(CommandHandler('start', + callback=simplehandler(start_handler), + filters=_user_filter)) + _application.add_handler(MessageHandler(filters.ALL & _user_filter, any_handler)) - _updater.start_polling() - _updater.idle() + _application.run_polling() def add_conversation(conv: conversation) -> None: - _updater.dispatcher.add_handler(conv.get_handler(), group=0) + _application.add_handler(conv.get_handler(), group=0) def add_handler(h): - _updater.dispatcher.add_handler(h, group=0) + _application.add_handler(h, group=0) -def start(ctx: Context): - return _start_handler_ref(ctx) +async def start(ctx: Context): + return await _start_handler_ref(ctx) -def _default_start_handler(ctx: Context): +async def _default_start_handler(ctx: Context): if 'start_message' not in lang: - return ctx.reply('Please define start_message or override start()') - ctx.reply(ctx.lang('start_message')) + return await ctx.reply('Please define start_message or override start()') + await ctx.reply(ctx.lang('start_message')) @simplehandler -def _default_any_handler(ctx: Context): +async def _default_any_handler(ctx: Context): if 'invalid_command' not in lang: - return ctx.reply('Please define invalid_command or override any()') - ctx.reply(ctx.lang('invalid_command')) + return await ctx.reply('Please define invalid_command or override any()') + await ctx.reply(ctx.lang('invalid_command')) def _logging_message_handler(update: Update, context: CallbackContext): @@ -535,7 +540,7 @@ def notify_all(text_getter: callable, continue text = text_getter(db.get_user_lang(user_id)) - _updater.bot.send_message(chat_id=user_id, + _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML') @@ -543,33 +548,33 @@ def notify_all(text_getter: callable, def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None: if isinstance(text, Exception): text = exc2text(text) - _updater.bot.send_message(chat_id=user_id, + _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML', **kwargs) def send_photo(user_id, **kwargs): - _updater.bot.send_photo(chat_id=user_id, **kwargs) + _application.bot.send_photo(chat_id=user_id, **kwargs) def send_audio(user_id, **kwargs): - _updater.bot.send_audio(chat_id=user_id, **kwargs) + _application.bot.send_audio(chat_id=user_id, **kwargs) def send_file(user_id, **kwargs): - _updater.bot.send_document(chat_id=user_id, **kwargs) + _application.bot.send_document(chat_id=user_id, **kwargs) def edit_message_text(user_id, message_id, *args, **kwargs): - _updater.bot.edit_message_text(chat_id=user_id, + _application.bot.edit_message_text(chat_id=user_id, message_id=message_id, parse_mode='HTML', *args, **kwargs) def delete_message(user_id, message_id): - _updater.bot.delete_message(chat_id=user_id, message_id=message_id) + _application.bot.delete_message(chat_id=user_id, message_id=message_id) def set_database(_db: BotDatabase): |