diff options
Diffstat (limited to 'src/home/telegram/bot.py')
-rw-r--r-- | src/home/telegram/bot.py | 578 |
1 files changed, 0 insertions, 578 deletions
diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py deleted file mode 100644 index 10bfe06..0000000 --- a/src/home/telegram/bot.py +++ /dev/null @@ -1,578 +0,0 @@ -from __future__ import annotations - -import logging -import itertools - -from enum import Enum, auto -from functools import wraps -from typing import Optional, Union, Tuple - -from telegram import Update, ReplyKeyboardMarkup -from telegram.ext import ( - Updater, - Filters, - BaseFilter, - CommandHandler, - MessageHandler, - CallbackQueryHandler, - CallbackContext, - ConversationHandler -) -from telegram.error import TimedOut - -from home.config import config -from home.api import WebAPIClient -from home.api.types import BotType - -from ._botlang import lang, languages -from ._botdb import BotDatabase -from ._botutil import ReportingHelper, exc2text, IgnoreMarkup, user_any_name -from ._botcontext import Context - - -db: Optional[BotDatabase] = None - -_user_filter: Optional[BaseFilter] = None -_cancel_filter = Filters.text(lang.all('cancel')) -_back_filter = Filters.text(lang.all('back')) -_cancel_and_back_filter = Filters.text(lang.all('back') + lang.all('cancel')) - -_logger = logging.getLogger(__name__) -_updater: Optional[Updater] = None -_reporting: Optional[ReportingHelper] = None -_exception_handler: Optional[callable] = None -_dispatcher = None -_markup_getter: Optional[callable] = None -_start_handler_ref: Optional[callable] = None - - -def text_filter(*args): - if not _user_filter: - raise RuntimeError('user_filter is not initialized') - return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter - - -def _handler_of_handler(*args, **kwargs): - self = None - context = None - update = None - - _args = list(args) - while len(_args): - v = _args[0] - if isinstance(v, conversation): - self = v - _args.pop(0) - elif isinstance(v, Update): - update = v - _args.pop(0) - elif isinstance(v, CallbackContext): - context = v - _args.pop(0) - break - - ctx = Context(update, - callback_context=context, - markup_getter=lambda _ctx: None if not _markup_getter else _markup_getter(_ctx), - store=db) - try: - _args.insert(0, ctx) - - f = kwargs['f'] - del kwargs['f'] - - if 'return_with_context' in kwargs: - return_with_context = True - del kwargs['return_with_context'] - else: - return_with_context = False - - if 'argument' in kwargs and kwargs['argument'] == 'message_key': - del kwargs['argument'] - mkey = None - for k, v in lang.get_langpack(ctx.user_lang).items(): - if ctx.text == v: - mkey = k - break - _args.insert(0, mkey) - - if self: - _args.insert(0, self) - - result = f(*_args, **kwargs) - return result if not return_with_context else (result, ctx) - - except Exception as e: - if _exception_handler: - if not _exception_handler(e, ctx) and not isinstance(e, TimedOut): - _logger.exception(e) - if not ctx.is_callback_context(): - ctx.reply_exc(e) - else: - notify_user(ctx.user_id, exc2text(e)) - else: - _logger.exception(e) - - -def handler(**kwargs): - def inner(f): - @wraps(f) - def _handler(*args, **inner_kwargs): - if 'argument' in kwargs and kwargs['argument'] == 'message_key': - inner_kwargs['argument'] = 'message_key' - return _handler_of_handler(f=f, *args, **inner_kwargs) - - messages = [] - texts = [] - - if 'messages' in kwargs: - messages += kwargs['messages'] - if 'message' in kwargs: - messages.append(kwargs['message']) - - if 'text' in kwargs: - texts.append(kwargs['text']) - if 'texts' in kwargs: - texts += kwargs['texts'] - - if messages or texts: - new_messages = list(itertools.chain.from_iterable([lang.all(m) for m in messages])) - texts += new_messages - texts = list(set(texts)) - _updater.dispatcher.add_handler( - MessageHandler(text_filter(*texts), _handler), - group=0 - ) - - if 'command' in kwargs: - _updater.dispatcher.add_handler(CommandHandler(kwargs['command'], _handler), group=0) - - if 'callback' in kwargs: - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) - - return _handler - - return inner - - -def simplehandler(f: callable): - @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) - return _handler - - -def callbackhandler(*args, **kwargs): - def inner(f): - @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) - pattern_kwargs = {} - if kwargs['callback'] != '*': - pattern_kwargs['pattern'] = kwargs['callback'] - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) - return _handler - return inner - - -def exceptionhandler(f: callable): - global _exception_handler - if _exception_handler: - _logger.warning('exception handler already set, we will overwrite it') - _exception_handler = f - - -def defaultreplymarkup(f: callable): - global _markup_getter - _markup_getter = f - - -def convinput(state, is_enter=False, **kwargs): - def inner(f): - f.__dict__['_conv_data'] = dict( - orig_f=f, - enter=is_enter, - type=ConversationMethodType.ENTRY if is_enter and state == 0 else ConversationMethodType.STATE_HANDLER, - state=state, - **kwargs - ) - - @wraps(f) - def _impl(*args, **kwargs): - result, ctx = _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) - if result == conversation.END: - start(ctx) - return result - - return _impl - - return inner - - -def conventer(state, **kwargs): - return convinput(state, is_enter=True, **kwargs) - - -class ConversationMethodType(Enum): - ENTRY = auto() - STATE_HANDLER = auto() - - -class conversation: - END = ConversationHandler.END - STATE_SEQS = [] - - def __init__(self, enable_back=False): - self._logger = logging.getLogger(self.__class__.__name__) - self._user_state_cache = {} - self._back_enabled = enable_back - - def make_handlers(self, f: callable, **kwargs) -> list: - messages = {} - handlers = [] - - if 'messages' in kwargs: - if isinstance(kwargs['messages'], dict): - messages = kwargs['messages'] - else: - for m in kwargs['messages']: - messages[m] = None - - if 'message' in kwargs: - if isinstance(kwargs['message'], str): - messages[kwargs['message']] = None - else: - AttributeError('invalid message type: ' + type(kwargs['message'])) - - if messages: - for message, target_state in messages.items(): - if not target_state: - handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), f)) - else: - handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), self.make_invoker(target_state))) - - if 'regex' in kwargs: - handlers.append(MessageHandler(Filters.regex(kwargs['regex']) & _user_filter, f)) - - if 'command' in kwargs: - handlers.append(CommandHandler(kwargs['command'], f, _user_filter)) - - return handlers - - def make_invoker(self, state): - def _invoke(update: Update, context: CallbackContext): - ctx = Context(update, - callback_context=context, - markup_getter=lambda _ctx: None if not _markup_getter else _markup_getter(_ctx), - store=db) - return self.invoke(state, ctx) - return _invoke - - def invoke(self, state, ctx: Context): - self._logger.debug(f'invoke, state={state}') - for item in dir(self): - f = getattr(self, item) - if not callable(f) or item.startswith('_') or '_conv_data' not in f.__dict__: - continue - cd = f.__dict__['_conv_data'] - if cd['enter'] and cd['state'] == state: - return cd['orig_f'](self, ctx) - - raise RuntimeError(f'invoke: failed to find method for state {state}') - - def get_handler(self) -> ConversationHandler: - entry_points = [] - states = {} - - l_cancel_filter = _cancel_filter if not self._back_enabled else _cancel_and_back_filter - - for item in dir(self): - f = getattr(self, item) - if not callable(f) or item.startswith('_') or '_conv_data' not in f.__dict__: - continue - - cd = f.__dict__['_conv_data'] - - if cd['type'] == ConversationMethodType.ENTRY: - entry_points = self.make_handlers(f, **cd) - elif cd['type'] == ConversationMethodType.STATE_HANDLER: - states[cd['state']] = self.make_handlers(f, **cd) - states[cd['state']].append( - MessageHandler(_user_filter & ~l_cancel_filter, conversation.invalid) - ) - - fallbacks = [MessageHandler(_user_filter & _cancel_filter, self.cancel)] - if self._back_enabled: - fallbacks.append(MessageHandler(_user_filter & _back_filter, self.back)) - - return ConversationHandler( - entry_points=entry_points, - states=states, - fallbacks=fallbacks - ) - - def get_user_state(self, user_id: int) -> Optional[int]: - if user_id not in self._user_state_cache: - return None - return self._user_state_cache[user_id] - - # TODO store in ctx.user_state - def set_user_state(self, user_id: int, state: Union[int, None]): - if not self._back_enabled: - return - if state is not None: - self._user_state_cache[user_id] = state - else: - del self._user_state_cache[user_id] - - @staticmethod - @simplehandler - def invalid(ctx: Context): - ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) - # return 0 # FIXME is this needed - - @simplehandler - def cancel(self, ctx: Context): - start(ctx) - self.set_user_state(ctx.user_id, None) - return conversation.END - - @simplehandler - def back(self, ctx: Context): - cur_state = self.get_user_state(ctx.user_id) - if cur_state is None: - start(ctx) - self.set_user_state(ctx.user_id, None) - return conversation.END - - new_state = None - for seq in self.STATE_SEQS: - if cur_state in seq: - idx = seq.index(cur_state) - if idx > 0: - return self.invoke(seq[idx-1], ctx) - - if new_state is None: - raise RuntimeError('failed to determine state to go back to') - - @classmethod - def add_cancel_button(cls, ctx: Context, buttons): - buttons.append([ctx.lang('cancel')]) - - @classmethod - def add_back_button(cls, ctx: Context, buttons): - # buttons.insert(0, [ctx.lang('back')]) - buttons.append([ctx.lang('back')]) - - def reply(self, - ctx: Context, - state: Union[int, Enum], - text: str, - buttons: Optional[list], - with_cancel=False, - with_back=False, - buttons_lang_completed=False): - - if buttons: - new_buttons = [] - if not buttons_lang_completed: - for item in buttons: - if isinstance(item, list): - item = map(lambda s: ctx.lang(s), item) - new_buttons.append(list(item)) - elif isinstance(item, str): - new_buttons.append([ctx.lang(item)]) - else: - raise ValueError('invalid type: ' + type(item)) - else: - new_buttons = list(buttons) - - buttons = None - else: - if with_cancel or with_back: - new_buttons = [] - else: - new_buttons = None - - if with_cancel: - self.add_cancel_button(ctx, new_buttons) - if with_back: - if not self._back_enabled: - raise AttributeError(f'back is not enabled for this conversation ({self.__class__.__name__})') - self.add_back_button(ctx, new_buttons) - - markup = ReplyKeyboardMarkup(new_buttons, one_time_keyboard=True) if new_buttons else IgnoreMarkup() - ctx.reply(text, markup=markup) - self.set_user_state(ctx.user_id, state) - return state - - -class LangConversation(conversation): - START, = range(1) - - @conventer(START, command='lang') - def entry(self, ctx: Context): - self._logger.debug(f'current language: {ctx.user_lang}') - - buttons = [] - for name in languages.values(): - buttons.append(name) - markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False) - - ctx.reply(ctx.lang('select_language'), markup=markup) - return self.START - - @convinput(START, messages=lang.languages) - def input(self, ctx: Context): - selected_lang = None - for key, value in languages.items(): - if value == ctx.text: - selected_lang = key - break - - if selected_lang is None: - raise ValueError('could not find the language') - - db.set_user_lang(ctx.user_id, selected_lang) - ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) - - return self.END - - -def initialize(): - global _user_filter - global _updater - global _dispatcher - - # init user_filter - if 'users' in config['bot']: - _logger.info('allowed users: ' + str(config['bot']['users'])) - _user_filter = Filters.user(config['bot']['users']) - else: - _user_filter = Filters.all # not sure if this is correct - - # init updater - _updater = Updater(config['bot']['token'], - request_kwargs={'read_timeout': 6, 'connect_timeout': 7}) - - # transparently log all messages - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, _logging_message_handler), group=10) - _updater.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) - - -def run(start_handler=None, any_handler=None): - global db - global _start_handler_ref - - if not start_handler: - start_handler = _default_start_handler - if not any_handler: - any_handler = _default_any_handler - if not db: - db = BotDatabase() - - _start_handler_ref = start_handler - - _updater.dispatcher.add_handler(LangConversation().get_handler(), group=0) - _updater.dispatcher.add_handler(CommandHandler('start', simplehandler(start_handler), _user_filter)) - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, any_handler)) - - _updater.start_polling() - _updater.idle() - - -def add_conversation(conv: conversation) -> None: - _updater.dispatcher.add_handler(conv.get_handler(), group=0) - - -def add_handler(h): - _updater.dispatcher.add_handler(h, group=0) - - -def start(ctx: Context): - return _start_handler_ref(ctx) - - -def _default_start_handler(ctx: Context): - if 'start_message' not in lang: - return ctx.reply('Please define start_message or override start()') - ctx.reply(ctx.lang('start_message')) - - -@simplehandler -def _default_any_handler(ctx: Context): - if 'invalid_command' not in lang: - return ctx.reply('Please define invalid_command or override any()') - ctx.reply(ctx.lang('invalid_command')) - - -def _logging_message_handler(update: Update, context: CallbackContext): - if _reporting: - _reporting.report(update.message) - - -def _logging_callback_handler(update: Update, context: CallbackContext): - if _reporting: - _reporting.report(update.callback_query.message, text=update.callback_query.data) - - -def enable_logging(bot_type: BotType): - api = WebAPIClient(timeout=3) - api.enable_async() - - global _reporting - _reporting = ReportingHelper(api, bot_type) - - -def notify_all(text_getter: callable, - exclude: Tuple[int] = ()) -> None: - if 'notify_users' not in config['bot']: - _logger.error('notify_all() called but no notify_users directive found in the config') - return - - for user_id in config['bot']['notify_users']: - if user_id in exclude: - continue - - text = text_getter(db.get_user_lang(user_id)) - _updater.bot.send_message(chat_id=user_id, - text=text, - parse_mode='HTML') - - -def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None: - if isinstance(text, Exception): - text = exc2text(text) - _updater.bot.send_message(chat_id=user_id, - text=text, - parse_mode='HTML', - **kwargs) - - -def send_photo(user_id, **kwargs): - _updater.bot.send_photo(chat_id=user_id, **kwargs) - - -def send_audio(user_id, **kwargs): - _updater.bot.send_audio(chat_id=user_id, **kwargs) - - -def send_file(user_id, **kwargs): - _updater.bot.send_document(chat_id=user_id, **kwargs) - - -def edit_message_text(user_id, message_id, *args, **kwargs): - _updater.bot.edit_message_text(chat_id=user_id, - message_id=message_id, - parse_mode='HTML', - *args, **kwargs) - - -def delete_message(user_id, message_id): - _updater.bot.delete_message(chat_id=user_id, message_id=message_id) - - -def set_database(_db: BotDatabase): - global db - db = _db - |