summaryrefslogtreecommitdiff
path: root/src/home/telegram/bot.py
diff options
context:
space:
mode:
authorEvgeny Zinoviev <me@ch1p.io>2023-05-31 09:22:00 +0300
committerEvgeny Zinoviev <me@ch1p.io>2023-06-10 02:07:23 +0300
commitf29e139cbb7e4a4d539cba6e894ef4a6acd312d6 (patch)
tree6246f126325c5c36fb573134a05f2771cd747966 /src/home/telegram/bot.py
parent3e3753d726f8a02d98368f20f77dd9fa739e3d80 (diff)
WIP: big refactoring
Diffstat (limited to 'src/home/telegram/bot.py')
-rw-r--r--src/home/telegram/bot.py149
1 files changed, 77 insertions, 72 deletions
diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py
index 10bfe06..7e22263 100644
--- a/src/home/telegram/bot.py
+++ b/src/home/telegram/bot.py
@@ -5,19 +5,19 @@ import itertools
from enum import Enum, auto
from functools import wraps
-from typing import Optional, Union, Tuple
+from typing import Optional, Union, Tuple, Coroutine
from telegram import Update, ReplyKeyboardMarkup
from telegram.ext import (
- Updater,
- Filters,
- BaseFilter,
+ Application,
+ filters,
CommandHandler,
MessageHandler,
CallbackQueryHandler,
CallbackContext,
ConversationHandler
)
+from telegram.ext.filters import BaseFilter
from telegram.error import TimedOut
from home.config import config
@@ -33,26 +33,26 @@ from ._botcontext import Context
db: Optional[BotDatabase] = None
_user_filter: Optional[BaseFilter] = None
-_cancel_filter = Filters.text(lang.all('cancel'))
-_back_filter = Filters.text(lang.all('back'))
-_cancel_and_back_filter = Filters.text(lang.all('back') + lang.all('cancel'))
+_cancel_filter = filters.Text(lang.all('cancel'))
+_back_filter = filters.Text(lang.all('back'))
+_cancel_and_back_filter = filters.Text(lang.all('back') + lang.all('cancel'))
_logger = logging.getLogger(__name__)
-_updater: Optional[Updater] = None
+_application: Optional[Application] = None
_reporting: Optional[ReportingHelper] = None
-_exception_handler: Optional[callable] = None
+_exception_handler: Optional[Coroutine] = None
_dispatcher = None
_markup_getter: Optional[callable] = None
-_start_handler_ref: Optional[callable] = None
+_start_handler_ref: Optional[Coroutine] = None
def text_filter(*args):
if not _user_filter:
raise RuntimeError('user_filter is not initialized')
- return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter
+ return filters.Text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter
-def _handler_of_handler(*args, **kwargs):
+async def _handler_of_handler(*args, **kwargs):
self = None
context = None
update = None
@@ -99,7 +99,7 @@ def _handler_of_handler(*args, **kwargs):
if self:
_args.insert(0, self)
- result = f(*_args, **kwargs)
+ result = await f(*_args, **kwargs)
return result if not return_with_context else (result, ctx)
except Exception as e:
@@ -107,7 +107,7 @@ def _handler_of_handler(*args, **kwargs):
if not _exception_handler(e, ctx) and not isinstance(e, TimedOut):
_logger.exception(e)
if not ctx.is_callback_context():
- ctx.reply_exc(e)
+ await ctx.reply_exc(e)
else:
notify_user(ctx.user_id, exc2text(e))
else:
@@ -117,10 +117,10 @@ def _handler_of_handler(*args, **kwargs):
def handler(**kwargs):
def inner(f):
@wraps(f)
- def _handler(*args, **inner_kwargs):
+ async def _handler(*args, **inner_kwargs):
if 'argument' in kwargs and kwargs['argument'] == 'message_key':
inner_kwargs['argument'] = 'message_key'
- return _handler_of_handler(f=f, *args, **inner_kwargs)
+ return await _handler_of_handler(f=f, *args, **inner_kwargs)
messages = []
texts = []
@@ -139,43 +139,43 @@ def handler(**kwargs):
new_messages = list(itertools.chain.from_iterable([lang.all(m) for m in messages]))
texts += new_messages
texts = list(set(texts))
- _updater.dispatcher.add_handler(
+ _application.add_handler(
MessageHandler(text_filter(*texts), _handler),
group=0
)
if 'command' in kwargs:
- _updater.dispatcher.add_handler(CommandHandler(kwargs['command'], _handler), group=0)
+ _application.add_handler(CommandHandler(kwargs['command'], _handler), group=0)
if 'callback' in kwargs:
- _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0)
+ _application.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0)
return _handler
return inner
-def simplehandler(f: callable):
+def simplehandler(f: Coroutine):
@wraps(f)
- def _handler(*args, **kwargs):
- return _handler_of_handler(f=f, *args, **kwargs)
+ async def _handler(*args, **kwargs):
+ return await _handler_of_handler(f=f, *args, **kwargs)
return _handler
def callbackhandler(*args, **kwargs):
def inner(f):
@wraps(f)
- def _handler(*args, **kwargs):
- return _handler_of_handler(f=f, *args, **kwargs)
+ async def _handler(*args, **kwargs):
+ return await _handler_of_handler(f=f, *args, **kwargs)
pattern_kwargs = {}
if kwargs['callback'] != '*':
pattern_kwargs['pattern'] = kwargs['callback']
- _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0)
+ _application.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0)
return _handler
return inner
-def exceptionhandler(f: callable):
+async def exceptionhandler(f: callable):
global _exception_handler
if _exception_handler:
_logger.warning('exception handler already set, we will overwrite it')
@@ -198,10 +198,10 @@ def convinput(state, is_enter=False, **kwargs):
)
@wraps(f)
- def _impl(*args, **kwargs):
- result, ctx = _handler_of_handler(f=f, *args, **kwargs, return_with_context=True)
+ async def _impl(*args, **kwargs):
+ result, ctx = await _handler_of_handler(f=f, *args, **kwargs, return_with_context=True)
if result == conversation.END:
- start(ctx)
+ await start(ctx)
return result
return _impl
@@ -252,7 +252,7 @@ class conversation:
handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), self.make_invoker(target_state)))
if 'regex' in kwargs:
- handlers.append(MessageHandler(Filters.regex(kwargs['regex']) & _user_filter, f))
+ handlers.append(MessageHandler(filters.Regex(kwargs['regex']) & _user_filter, f))
if 'command' in kwargs:
handlers.append(CommandHandler(kwargs['command'], f, _user_filter))
@@ -327,21 +327,21 @@ class conversation:
@staticmethod
@simplehandler
- def invalid(ctx: Context):
- ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup())
+ async def invalid(ctx: Context):
+ await ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup())
# return 0 # FIXME is this needed
@simplehandler
- def cancel(self, ctx: Context):
- start(ctx)
+ async def cancel(self, ctx: Context):
+ await start(ctx)
self.set_user_state(ctx.user_id, None)
return conversation.END
@simplehandler
- def back(self, ctx: Context):
+ async def back(self, ctx: Context):
cur_state = self.get_user_state(ctx.user_id)
if cur_state is None:
- start(ctx)
+ await start(ctx)
self.set_user_state(ctx.user_id, None)
return conversation.END
@@ -411,7 +411,7 @@ class LangConversation(conversation):
START, = range(1)
@conventer(START, command='lang')
- def entry(self, ctx: Context):
+ async def entry(self, ctx: Context):
self._logger.debug(f'current language: {ctx.user_lang}')
buttons = []
@@ -419,11 +419,11 @@ class LangConversation(conversation):
buttons.append(name)
markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False)
- ctx.reply(ctx.lang('select_language'), markup=markup)
+ await ctx.reply(ctx.lang('select_language'), markup=markup)
return self.START
@convinput(START, messages=lang.languages)
- def input(self, ctx: Context):
+ async def input(self, ctx: Context):
selected_lang = None
for key, value in languages.items():
if value == ctx.text:
@@ -434,30 +434,34 @@ class LangConversation(conversation):
raise ValueError('could not find the language')
db.set_user_lang(ctx.user_id, selected_lang)
- ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup())
+ await ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup())
return self.END
def initialize():
global _user_filter
- global _updater
+ global _application
+ # global _updater
global _dispatcher
# init user_filter
- if 'users' in config['bot']:
- _logger.info('allowed users: ' + str(config['bot']['users']))
- _user_filter = Filters.user(config['bot']['users'])
+ _user_ids = config.app_config.get_user_ids()
+ if len(_user_ids) > 0:
+ _logger.info('allowed users: ' + str(_user_ids))
+ _user_filter = filters.User(_user_ids)
else:
- _user_filter = Filters.all # not sure if this is correct
+ _user_filter = filters.ALL # not sure if this is correct
- # init updater
- _updater = Updater(config['bot']['token'],
- request_kwargs={'read_timeout': 6, 'connect_timeout': 7})
+ _application = Application.builder()\
+ .token(config.app_config.get('bot.token'))\
+ .connect_timeout(7)\
+ .read_timeout(6)\
+ .build()
# transparently log all messages
- _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, _logging_message_handler), group=10)
- _updater.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10)
+ # _application.dispatcher.add_handler(MessageHandler(filters.ALL & _user_filter, _logging_message_handler), group=10)
+ # _application.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10)
def run(start_handler=None, any_handler=None):
@@ -473,37 +477,38 @@ def run(start_handler=None, any_handler=None):
_start_handler_ref = start_handler
- _updater.dispatcher.add_handler(LangConversation().get_handler(), group=0)
- _updater.dispatcher.add_handler(CommandHandler('start', simplehandler(start_handler), _user_filter))
- _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, any_handler))
+ _application.add_handler(LangConversation().get_handler(), group=0)
+ _application.add_handler(CommandHandler('start',
+ callback=simplehandler(start_handler),
+ filters=_user_filter))
+ _application.add_handler(MessageHandler(filters.ALL & _user_filter, any_handler))
- _updater.start_polling()
- _updater.idle()
+ _application.run_polling()
def add_conversation(conv: conversation) -> None:
- _updater.dispatcher.add_handler(conv.get_handler(), group=0)
+ _application.add_handler(conv.get_handler(), group=0)
def add_handler(h):
- _updater.dispatcher.add_handler(h, group=0)
+ _application.add_handler(h, group=0)
-def start(ctx: Context):
- return _start_handler_ref(ctx)
+async def start(ctx: Context):
+ return await _start_handler_ref(ctx)
-def _default_start_handler(ctx: Context):
+async def _default_start_handler(ctx: Context):
if 'start_message' not in lang:
- return ctx.reply('Please define start_message or override start()')
- ctx.reply(ctx.lang('start_message'))
+ return await ctx.reply('Please define start_message or override start()')
+ await ctx.reply(ctx.lang('start_message'))
@simplehandler
-def _default_any_handler(ctx: Context):
+async def _default_any_handler(ctx: Context):
if 'invalid_command' not in lang:
- return ctx.reply('Please define invalid_command or override any()')
- ctx.reply(ctx.lang('invalid_command'))
+ return await ctx.reply('Please define invalid_command or override any()')
+ await ctx.reply(ctx.lang('invalid_command'))
def _logging_message_handler(update: Update, context: CallbackContext):
@@ -535,7 +540,7 @@ def notify_all(text_getter: callable,
continue
text = text_getter(db.get_user_lang(user_id))
- _updater.bot.send_message(chat_id=user_id,
+ _application.bot.send_message(chat_id=user_id,
text=text,
parse_mode='HTML')
@@ -543,33 +548,33 @@ def notify_all(text_getter: callable,
def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None:
if isinstance(text, Exception):
text = exc2text(text)
- _updater.bot.send_message(chat_id=user_id,
+ _application.bot.send_message(chat_id=user_id,
text=text,
parse_mode='HTML',
**kwargs)
def send_photo(user_id, **kwargs):
- _updater.bot.send_photo(chat_id=user_id, **kwargs)
+ _application.bot.send_photo(chat_id=user_id, **kwargs)
def send_audio(user_id, **kwargs):
- _updater.bot.send_audio(chat_id=user_id, **kwargs)
+ _application.bot.send_audio(chat_id=user_id, **kwargs)
def send_file(user_id, **kwargs):
- _updater.bot.send_document(chat_id=user_id, **kwargs)
+ _application.bot.send_document(chat_id=user_id, **kwargs)
def edit_message_text(user_id, message_id, *args, **kwargs):
- _updater.bot.edit_message_text(chat_id=user_id,
+ _application.bot.edit_message_text(chat_id=user_id,
message_id=message_id,
parse_mode='HTML',
*args, **kwargs)
def delete_message(user_id, message_id):
- _updater.bot.delete_message(chat_id=user_id, message_id=message_id)
+ _application.bot.delete_message(chat_id=user_id, message_id=message_id)
def set_database(_db: BotDatabase):