diff options
Diffstat (limited to 'src/home/telegram')
-rw-r--r-- | src/home/telegram/_botcontext.py | 19 | ||||
-rw-r--r-- | src/home/telegram/bot.py | 149 | ||||
-rw-r--r-- | src/home/telegram/config.py | 75 |
3 files changed, 162 insertions, 81 deletions
diff --git a/src/home/telegram/_botcontext.py b/src/home/telegram/_botcontext.py index f343eeb..a143bfe 100644 --- a/src/home/telegram/_botcontext.py +++ b/src/home/telegram/_botcontext.py @@ -1,6 +1,7 @@ from typing import Optional, List -from telegram import Update, ParseMode, User, CallbackQuery +from telegram import Update, User, CallbackQuery +from telegram.constants import ParseMode from telegram.ext import CallbackContext from ._botdb import BotDatabase @@ -26,25 +27,25 @@ class Context: self._store = store self._user_lang = None - def reply(self, text, markup=None): + async def reply(self, text, markup=None): if markup is None: markup = self._markup_getter(self) kwargs = dict(parse_mode=ParseMode.HTML) if not isinstance(markup, IgnoreMarkup): kwargs['reply_markup'] = markup - return self._update.message.reply_text(text, **kwargs) + return await self._update.message.reply_text(text, **kwargs) - def reply_exc(self, e: Exception) -> None: - self.reply(exc2text(e), markup=IgnoreMarkup()) + async def reply_exc(self, e: Exception) -> None: + await self.reply(exc2text(e), markup=IgnoreMarkup()) - def answer(self, text: str = None): - self.callback_query.answer(text) + async def answer(self, text: str = None): + await self.callback_query.answer(text) - def edit(self, text, markup=None): + async def edit(self, text, markup=None): kwargs = dict(parse_mode=ParseMode.HTML) if not isinstance(markup, IgnoreMarkup): kwargs['reply_markup'] = markup - self.callback_query.edit_message_text(text, **kwargs) + await self.callback_query.edit_message_text(text, **kwargs) @property def text(self) -> str: diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py index 10bfe06..7e22263 100644 --- a/src/home/telegram/bot.py +++ b/src/home/telegram/bot.py @@ -5,19 +5,19 @@ import itertools from enum import Enum, auto from functools import wraps -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Coroutine from telegram import Update, ReplyKeyboardMarkup from telegram.ext import ( - Updater, - Filters, - BaseFilter, + Application, + filters, CommandHandler, MessageHandler, CallbackQueryHandler, CallbackContext, ConversationHandler ) +from telegram.ext.filters import BaseFilter from telegram.error import TimedOut from home.config import config @@ -33,26 +33,26 @@ from ._botcontext import Context db: Optional[BotDatabase] = None _user_filter: Optional[BaseFilter] = None -_cancel_filter = Filters.text(lang.all('cancel')) -_back_filter = Filters.text(lang.all('back')) -_cancel_and_back_filter = Filters.text(lang.all('back') + lang.all('cancel')) +_cancel_filter = filters.Text(lang.all('cancel')) +_back_filter = filters.Text(lang.all('back')) +_cancel_and_back_filter = filters.Text(lang.all('back') + lang.all('cancel')) _logger = logging.getLogger(__name__) -_updater: Optional[Updater] = None +_application: Optional[Application] = None _reporting: Optional[ReportingHelper] = None -_exception_handler: Optional[callable] = None +_exception_handler: Optional[Coroutine] = None _dispatcher = None _markup_getter: Optional[callable] = None -_start_handler_ref: Optional[callable] = None +_start_handler_ref: Optional[Coroutine] = None def text_filter(*args): if not _user_filter: raise RuntimeError('user_filter is not initialized') - return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter + return filters.Text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter -def _handler_of_handler(*args, **kwargs): +async def _handler_of_handler(*args, **kwargs): self = None context = None update = None @@ -99,7 +99,7 @@ def _handler_of_handler(*args, **kwargs): if self: _args.insert(0, self) - result = f(*_args, **kwargs) + result = await f(*_args, **kwargs) return result if not return_with_context else (result, ctx) except Exception as e: @@ -107,7 +107,7 @@ def _handler_of_handler(*args, **kwargs): if not _exception_handler(e, ctx) and not isinstance(e, TimedOut): _logger.exception(e) if not ctx.is_callback_context(): - ctx.reply_exc(e) + await ctx.reply_exc(e) else: notify_user(ctx.user_id, exc2text(e)) else: @@ -117,10 +117,10 @@ def _handler_of_handler(*args, **kwargs): def handler(**kwargs): def inner(f): @wraps(f) - def _handler(*args, **inner_kwargs): + async def _handler(*args, **inner_kwargs): if 'argument' in kwargs and kwargs['argument'] == 'message_key': inner_kwargs['argument'] = 'message_key' - return _handler_of_handler(f=f, *args, **inner_kwargs) + return await _handler_of_handler(f=f, *args, **inner_kwargs) messages = [] texts = [] @@ -139,43 +139,43 @@ def handler(**kwargs): new_messages = list(itertools.chain.from_iterable([lang.all(m) for m in messages])) texts += new_messages texts = list(set(texts)) - _updater.dispatcher.add_handler( + _application.add_handler( MessageHandler(text_filter(*texts), _handler), group=0 ) if 'command' in kwargs: - _updater.dispatcher.add_handler(CommandHandler(kwargs['command'], _handler), group=0) + _application.add_handler(CommandHandler(kwargs['command'], _handler), group=0) if 'callback' in kwargs: - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) + _application.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) return _handler return inner -def simplehandler(f: callable): +def simplehandler(f: Coroutine): @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) + async def _handler(*args, **kwargs): + return await _handler_of_handler(f=f, *args, **kwargs) return _handler def callbackhandler(*args, **kwargs): def inner(f): @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) + async def _handler(*args, **kwargs): + return await _handler_of_handler(f=f, *args, **kwargs) pattern_kwargs = {} if kwargs['callback'] != '*': pattern_kwargs['pattern'] = kwargs['callback'] - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) + _application.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) return _handler return inner -def exceptionhandler(f: callable): +async def exceptionhandler(f: callable): global _exception_handler if _exception_handler: _logger.warning('exception handler already set, we will overwrite it') @@ -198,10 +198,10 @@ def convinput(state, is_enter=False, **kwargs): ) @wraps(f) - def _impl(*args, **kwargs): - result, ctx = _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) + async def _impl(*args, **kwargs): + result, ctx = await _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) if result == conversation.END: - start(ctx) + await start(ctx) return result return _impl @@ -252,7 +252,7 @@ class conversation: handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), self.make_invoker(target_state))) if 'regex' in kwargs: - handlers.append(MessageHandler(Filters.regex(kwargs['regex']) & _user_filter, f)) + handlers.append(MessageHandler(filters.Regex(kwargs['regex']) & _user_filter, f)) if 'command' in kwargs: handlers.append(CommandHandler(kwargs['command'], f, _user_filter)) @@ -327,21 +327,21 @@ class conversation: @staticmethod @simplehandler - def invalid(ctx: Context): - ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) + async def invalid(ctx: Context): + await ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) # return 0 # FIXME is this needed @simplehandler - def cancel(self, ctx: Context): - start(ctx) + async def cancel(self, ctx: Context): + await start(ctx) self.set_user_state(ctx.user_id, None) return conversation.END @simplehandler - def back(self, ctx: Context): + async def back(self, ctx: Context): cur_state = self.get_user_state(ctx.user_id) if cur_state is None: - start(ctx) + await start(ctx) self.set_user_state(ctx.user_id, None) return conversation.END @@ -411,7 +411,7 @@ class LangConversation(conversation): START, = range(1) @conventer(START, command='lang') - def entry(self, ctx: Context): + async def entry(self, ctx: Context): self._logger.debug(f'current language: {ctx.user_lang}') buttons = [] @@ -419,11 +419,11 @@ class LangConversation(conversation): buttons.append(name) markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False) - ctx.reply(ctx.lang('select_language'), markup=markup) + await ctx.reply(ctx.lang('select_language'), markup=markup) return self.START @convinput(START, messages=lang.languages) - def input(self, ctx: Context): + async def input(self, ctx: Context): selected_lang = None for key, value in languages.items(): if value == ctx.text: @@ -434,30 +434,34 @@ class LangConversation(conversation): raise ValueError('could not find the language') db.set_user_lang(ctx.user_id, selected_lang) - ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) + await ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) return self.END def initialize(): global _user_filter - global _updater + global _application + # global _updater global _dispatcher # init user_filter - if 'users' in config['bot']: - _logger.info('allowed users: ' + str(config['bot']['users'])) - _user_filter = Filters.user(config['bot']['users']) + _user_ids = config.app_config.get_user_ids() + if len(_user_ids) > 0: + _logger.info('allowed users: ' + str(_user_ids)) + _user_filter = filters.User(_user_ids) else: - _user_filter = Filters.all # not sure if this is correct + _user_filter = filters.ALL # not sure if this is correct - # init updater - _updater = Updater(config['bot']['token'], - request_kwargs={'read_timeout': 6, 'connect_timeout': 7}) + _application = Application.builder()\ + .token(config.app_config.get('bot.token'))\ + .connect_timeout(7)\ + .read_timeout(6)\ + .build() # transparently log all messages - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, _logging_message_handler), group=10) - _updater.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) + # _application.dispatcher.add_handler(MessageHandler(filters.ALL & _user_filter, _logging_message_handler), group=10) + # _application.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) def run(start_handler=None, any_handler=None): @@ -473,37 +477,38 @@ def run(start_handler=None, any_handler=None): _start_handler_ref = start_handler - _updater.dispatcher.add_handler(LangConversation().get_handler(), group=0) - _updater.dispatcher.add_handler(CommandHandler('start', simplehandler(start_handler), _user_filter)) - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, any_handler)) + _application.add_handler(LangConversation().get_handler(), group=0) + _application.add_handler(CommandHandler('start', + callback=simplehandler(start_handler), + filters=_user_filter)) + _application.add_handler(MessageHandler(filters.ALL & _user_filter, any_handler)) - _updater.start_polling() - _updater.idle() + _application.run_polling() def add_conversation(conv: conversation) -> None: - _updater.dispatcher.add_handler(conv.get_handler(), group=0) + _application.add_handler(conv.get_handler(), group=0) def add_handler(h): - _updater.dispatcher.add_handler(h, group=0) + _application.add_handler(h, group=0) -def start(ctx: Context): - return _start_handler_ref(ctx) +async def start(ctx: Context): + return await _start_handler_ref(ctx) -def _default_start_handler(ctx: Context): +async def _default_start_handler(ctx: Context): if 'start_message' not in lang: - return ctx.reply('Please define start_message or override start()') - ctx.reply(ctx.lang('start_message')) + return await ctx.reply('Please define start_message or override start()') + await ctx.reply(ctx.lang('start_message')) @simplehandler -def _default_any_handler(ctx: Context): +async def _default_any_handler(ctx: Context): if 'invalid_command' not in lang: - return ctx.reply('Please define invalid_command or override any()') - ctx.reply(ctx.lang('invalid_command')) + return await ctx.reply('Please define invalid_command or override any()') + await ctx.reply(ctx.lang('invalid_command')) def _logging_message_handler(update: Update, context: CallbackContext): @@ -535,7 +540,7 @@ def notify_all(text_getter: callable, continue text = text_getter(db.get_user_lang(user_id)) - _updater.bot.send_message(chat_id=user_id, + _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML') @@ -543,33 +548,33 @@ def notify_all(text_getter: callable, def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None: if isinstance(text, Exception): text = exc2text(text) - _updater.bot.send_message(chat_id=user_id, + _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML', **kwargs) def send_photo(user_id, **kwargs): - _updater.bot.send_photo(chat_id=user_id, **kwargs) + _application.bot.send_photo(chat_id=user_id, **kwargs) def send_audio(user_id, **kwargs): - _updater.bot.send_audio(chat_id=user_id, **kwargs) + _application.bot.send_audio(chat_id=user_id, **kwargs) def send_file(user_id, **kwargs): - _updater.bot.send_document(chat_id=user_id, **kwargs) + _application.bot.send_document(chat_id=user_id, **kwargs) def edit_message_text(user_id, message_id, *args, **kwargs): - _updater.bot.edit_message_text(chat_id=user_id, + _application.bot.edit_message_text(chat_id=user_id, message_id=message_id, parse_mode='HTML', *args, **kwargs) def delete_message(user_id, message_id): - _updater.bot.delete_message(chat_id=user_id, message_id=message_id) + _application.bot.delete_message(chat_id=user_id, message_id=message_id) def set_database(_db: BotDatabase): diff --git a/src/home/telegram/config.py b/src/home/telegram/config.py new file mode 100644 index 0000000..7a46087 --- /dev/null +++ b/src/home/telegram/config.py @@ -0,0 +1,75 @@ +from ..config import ConfigUnit +from typing import Optional, Union +from abc import ABC +from enum import Enum + + +class TelegramUserListType(Enum): + USERS = 'users' + NOTIFY = 'notify_users' + + +class TelegramUserIdsConfig(ConfigUnit): + NAME = 'telegram_user_ids' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'roottype': 'dict', + 'type': 'integer' + } + + +_user_ids_config = TelegramUserIdsConfig() + + +def _user_id_mapper(user: Union[str, int]) -> int: + if isinstance(user, int): + return user + return _user_ids_config[user] + + +class TelegramChatsConfig(ConfigUnit): + NAME = 'telegram_chats' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'type': 'dict', + 'schema': { + 'id': {'type': 'string', 'required': True}, + 'token': {'type': 'string', 'required': True}, + } + } + + +class TelegramBotConfig(ConfigUnit, ABC): + @staticmethod + def schema() -> Optional[dict]: + return { + 'bot': { + 'type': 'dict', + 'schema': { + 'token': {'type': 'string', 'required': True}, + TelegramUserListType.USERS: {**TelegramBotConfig._userlist_schema(), 'required': True}, + TelegramUserListType.NOTIFY: TelegramBotConfig._userlist_schema(), + } + } + } + + @staticmethod + def _userlist_schema() -> dict: + return {'type': 'list', 'schema': {'type': ['string', 'int']}} + + @staticmethod + def custom_validator(data): + for ult in TelegramUserListType: + users = data['bot'][ult.value] + for user in users: + if isinstance(user, str): + if user not in _user_ids_config: + raise ValueError(f'user {user} not found in {TelegramUserIdsConfig.NAME}') + + def get_user_ids(self, + ult: TelegramUserListType = TelegramUserListType.USERS) -> list[int]: + return list(map(_user_id_mapper, self['bot'][ult.value]))
\ No newline at end of file |