diff options
Diffstat (limited to 'src/home/telegram/bot.py')
-rw-r--r-- | src/home/telegram/bot.py | 542 |
1 files changed, 542 insertions, 0 deletions
diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py new file mode 100644 index 0000000..602573b --- /dev/null +++ b/src/home/telegram/bot.py @@ -0,0 +1,542 @@ +from __future__ import annotations + +import logging + +from enum import Enum, auto +from functools import wraps +from typing import Optional, Union, List, Tuple, Dict + +from telegram import ( + Update, + ParseMode, + ReplyKeyboardMarkup, + CallbackQuery, + User, + Message, +) +from telegram.ext import ( + Updater, + Filters, + BaseFilter, + Handler, + CommandHandler, + MessageHandler, + CallbackQueryHandler, + CallbackContext, + ConversationHandler +) +from telegram.error import TimedOut + +from home.config import config +from home.api import WebAPIClient +from home.api.types import BotType +from home.api.errors import ApiResponseError + +from ._botlang import lang, languages +from ._botdb import BotDatabase +from ._botutil import ReportingHelper, exc2text, IgnoreMarkup +from ._botcontext import Context + + +# LANG_STARTED, = range(1) + +user_filter: Optional[BaseFilter] = None +cancel_filter = Filters.text(lang.all('cancel')) +back_filter = Filters.text(lang.all('back')) +cancel_and_back_filter = Filters.text(lang.all('back') + lang.all('cancel')) + +db: Optional[BotDatabase] = None + +_logger = logging.getLogger(__name__) +_updater: Optional[Updater] = None +_reporting: Optional[ReportingHelper] = None +_exception_handler: Optional[callable] = None +_dispatcher = None +_markup_getter: Optional[callable] = None +_start_handler_ref: Optional[callable] = None + + +def text_filter(*args): + if not user_filter: + raise RuntimeError('user_filter is not initialized') + return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & user_filter + + +def _handler_of_handler(*args, **kwargs): + self = None + context = None + update = None + + _args = list(args) + while len(_args): + v = _args[0] + if isinstance(v, conversation): + self = v + _args.pop(0) + elif isinstance(v, Update): + update = v + _args.pop(0) + elif isinstance(v, CallbackContext): + context = v + _args.pop(0) + break + + ctx = Context(update, + callback_context=context, + markup_getter=lambda _ctx: None if not _markup_getter else _markup_getter(_ctx), + store=db) + try: + _args.insert(0, ctx) + if self: + _args.insert(0, self) + + f = kwargs['f'] + del kwargs['f'] + + if 'return_with_context' in kwargs: + return_with_context = True + del kwargs['return_with_context'] + else: + return_with_context = False + + result = f(*_args, **kwargs) + return result if not return_with_context else (result, ctx) + + except Exception as e: + if _exception_handler: + if not _exception_handler(e, ctx) and not isinstance(e, TimedOut): + _logger.exception(e) + if not ctx.is_callback_context(): + ctx.reply_exc(e) + else: + notify_user(ctx.user_id, exc2text(e)) + + +def handler(**kwargs): + def inner(f): + @wraps(f) + def _handler(*args, **kwargs): + return _handler_of_handler(f=f, *args, **kwargs) + + if 'message' in kwargs: + _updater.dispatcher.add_handler(MessageHandler(text_filter(lang.all(kwargs['message'])), _handler), group=0) + elif 'command' in kwargs: + _updater.dispatcher.add_handler(CommandHandler(kwargs['command'], _handler), group=0) + elif 'callback' in kwargs: + _updater.dispatcher.add_handler(CallbackQueryHandler(_handler), group=0) + return _handler + return inner + + +def simplehandler(f: callable): + @wraps(f) + def _handler(*args, **kwargs): + return _handler_of_handler(f=f, *args, **kwargs) + return _handler + + +def callbackhandler(f: callable): + @wraps(f) + def _handler(*args, **kwargs): + return _handler_of_handler(f=f, *args, **kwargs) + _updater.dispatcher.add_handler(CallbackQueryHandler(_handler), group=0) + return _handler + + +def exceptionhandler(f: callable): + global _exception_handler + if _exception_handler: + _logger.warning('exception handler already set, we will overwrite it') + _exception_handler = f + + +def defaultreplymarkup(f: callable): + global _markup_getter + _markup_getter = f + + +def convinput(state, is_enter=False, **kwargs): + def inner(f): + f.__dict__['_conv_data'] = dict( + orig_f=f, + enter=is_enter, + type=ConversationMethodType.ENTRY if is_enter and state == 0 else ConversationMethodType.STATE_HANDLER, + state=state, + **kwargs + ) + + @wraps(f) + def _impl(*args, **kwargs): + result, ctx = _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) + if result == conversation.END: + start(ctx) + return result + + return _impl + + return inner + + +def conventer(state, **kwargs): + return convinput(state, is_enter=True, **kwargs) + + +class ConversationMethodType(Enum): + ENTRY = auto() + STATE_HANDLER = auto() + + +class conversation: + END = ConversationHandler.END + STATE_SEQS = [] + + def __init__(self, enable_back=False): + self._logger = logging.getLogger(self.__class__.__name__) + self._user_state_cache = {} + self._back_enabled = enable_back + + def make_handlers(self, f: callable, **kwargs) -> list: + messages = {} + handlers = [] + + if 'messages' in kwargs: + if isinstance(kwargs['messages'], dict): + messages = kwargs['messages'] + else: + for m in kwargs['messages']: + messages[m] = None + + if 'message' in kwargs: + if isinstance(kwargs['message'], str): + messages[kwargs['message']] = None + else: + AttributeError('invalid message type: ' + type(kwargs['message'])) + + if messages: + for message, target_state in messages.items(): + if not target_state: + handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), f)) + else: + handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), self.make_invoker(target_state))) + + if 'regex' in kwargs: + handlers.append(MessageHandler(Filters.regex(kwargs['regex']) & user_filter, f)) + + if 'command' in kwargs: + handlers.append(CommandHandler(kwargs['command'], f, user_filter)) + + return handlers + + def make_invoker(self, state): + def _invoke(update: Update, context: CallbackContext): + ctx = Context(update, + callback_context=context, + markup_getter=lambda _ctx: None if not _markup_getter else _markup_getter(_ctx), + store=db) + return self.invoke(state, ctx) + return _invoke + + def invoke(self, state, ctx: Context): + self._logger.debug(f'invoke, state={state}') + for item in dir(self): + f = getattr(self, item) + if not callable(f) or item.startswith('_') or '_conv_data' not in f.__dict__: + continue + cd = f.__dict__['_conv_data'] + if cd['enter'] and cd['state'] == state: + return cd['orig_f'](self, ctx) + + raise RuntimeError(f'invoke: failed to find method for state {state}') + + def get_handler(self) -> ConversationHandler: + entry_points = [] + states = {} + + l_cancel_filter = cancel_filter if not self._back_enabled else cancel_and_back_filter + + for item in dir(self): + f = getattr(self, item) + if not callable(f) or item.startswith('_') or '_conv_data' not in f.__dict__: + continue + + cd = f.__dict__['_conv_data'] + + if cd['type'] == ConversationMethodType.ENTRY: + entry_points = self.make_handlers(f, **cd) + elif cd['type'] == ConversationMethodType.STATE_HANDLER: + states[cd['state']] = self.make_handlers(f, **cd) + states[cd['state']].append( + MessageHandler(user_filter & ~l_cancel_filter, conversation.invalid) + ) + + fallbacks = [MessageHandler(user_filter & cancel_filter, self.cancel)] + if self._back_enabled: + fallbacks.append(MessageHandler(user_filter & back_filter, self.back)) + + return ConversationHandler( + entry_points=entry_points, + states=states, + fallbacks=fallbacks + ) + + def get_user_state(self, user_id: int) -> Optional[int]: + if user_id not in self._user_state_cache: + return None + return self._user_state_cache[user_id] + + # TODO store in ctx.user_state + def set_user_state(self, user_id: int, state: Union[int, None]): + if not self._back_enabled: + return + if state is not None: + self._user_state_cache[user_id] = state + else: + del self._user_state_cache[user_id] + + @staticmethod + @simplehandler + def invalid(ctx: Context): + ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) + # return 0 # FIXME is this needed + + @simplehandler + def cancel(self, ctx: Context): + start(ctx) + self.set_user_state(ctx.user_id, None) + return conversation.END + + @simplehandler + def back(self, ctx: Context): + cur_state = self.get_user_state(ctx.user_id) + if cur_state is None: + start(ctx) + self.set_user_state(ctx.user_id, None) + return conversation.END + + new_state = None + for seq in self.STATE_SEQS: + if cur_state in seq: + idx = seq.index(cur_state) + if idx > 0: + return self.invoke(seq[idx-1], ctx) + + if new_state is None: + raise RuntimeError('failed to determine state to go back to') + + @classmethod + def add_cancel_button(cls, ctx: Context, buttons): + buttons.append([ctx.lang('cancel')]) + + @classmethod + def add_back_button(cls, ctx: Context, buttons): + # buttons.insert(0, [ctx.lang('back')]) + buttons.append([ctx.lang('back')]) + + def reply(self, + ctx: Context, + state: Union[int, Enum], + text: str, + buttons: Optional[list], + with_cancel=False, + with_back=False, + buttons_lang_completed=False): + + if buttons: + new_buttons = [] + if not buttons_lang_completed: + for item in buttons: + if isinstance(item, list): + item = map(lambda s: ctx.lang(s), item) + new_buttons.append(list(item)) + elif isinstance(item, str): + new_buttons.append([ctx.lang(item)]) + else: + raise ValueError('invalid type: ' + type(item)) + else: + new_buttons = list(buttons) + + buttons = None + else: + if with_cancel or with_back: + new_buttons = [] + else: + new_buttons = None + + if with_cancel: + self.add_cancel_button(ctx, new_buttons) + if with_back: + if not self._back_enabled: + raise AttributeError(f'back is not enabled for this conversation ({self.__class__.__name__})') + self.add_back_button(ctx, new_buttons) + + markup = ReplyKeyboardMarkup(new_buttons, one_time_keyboard=True) if new_buttons else IgnoreMarkup() + ctx.reply(text, markup=markup) + self.set_user_state(ctx.user_id, state) + return state + + +class LangConversation(conversation): + START, = range(1) + + @conventer(START, command='lang') + def entry(self, ctx: Context): + self._logger.debug(f'current language: {ctx.user_lang}') + + buttons = [] + for name in languages.values(): + buttons.append(name) + markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False) + + ctx.reply(ctx.lang('select_language'), markup=markup) + return self.START + + @convinput(START, messages=lang.languages) + def input(self, ctx: Context): + selected_lang = None + for key, value in languages.items(): + if value == ctx.text: + selected_lang = key + break + + if selected_lang is None: + raise ValueError('could not find the language') + + db.set_user_lang(ctx.user_id, selected_lang) + ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) + + return self.END + + +def initialize(): + global user_filter + global _updater + global _dispatcher + + # init user_filter + if 'users' in config['bot']: + _logger.info('allowed users: ' + str(config['bot']['users'])) + user_filter = Filters.user(config['bot']['users']) + else: + user_filter = Filters.all # not sure if this is correct + + # init updater + _updater = Updater(config['bot']['token'], + request_kwargs={'read_timeout': 6, 'connect_timeout': 7}) + + # transparently log all messages + _updater.dispatcher.add_handler(MessageHandler(Filters.all & user_filter, _logging_message_handler), group=10) + _updater.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) + + +def run(start_handler=None, any_handler=None): + global db + global _start_handler_ref + + if not start_handler: + start_handler = _default_start_handler + if not any_handler: + any_handler = _default_any_handler + if not db: + db = BotDatabase() + + _start_handler_ref = start_handler + + _updater.dispatcher.add_handler(LangConversation().get_handler(), group=0) + _updater.dispatcher.add_handler(CommandHandler('start', simplehandler(start_handler), user_filter)) + _updater.dispatcher.add_handler(MessageHandler(Filters.all & user_filter, any_handler)) + + _updater.start_polling() + _updater.idle() + + +def add_conversation(conv: conversation) -> None: + _updater.dispatcher.add_handler(conv.get_handler(), group=0) + + +def start(ctx: Context): + return _start_handler_ref(ctx) + + +def _default_start_handler(ctx: Context): + if 'start_message' not in lang: + return ctx.reply('Please define start_message or override start()') + ctx.reply(ctx.lang('start_message')) + + +@simplehandler +def _default_any_handler(ctx: Context): + if 'invalid_command' not in lang: + return ctx.reply('Please define invalid_command or override any()') + ctx.reply(ctx.lang('invalid_command')) + + +def _logging_message_handler(update: Update, context: CallbackContext): + if _reporting: + _reporting.report(update.message) + + +def _logging_callback_handler(update: Update, context: CallbackContext): + if _reporting: + _reporting.report(update.callback_query.message, text=update.callback_query.data) + + +def enable_logging(bot_type: BotType): + api = WebAPIClient(timeout=3) + api.enable_async() + + global _reporting + _reporting = ReportingHelper(api, bot_type) + + +def notify_all(text_getter: callable, + exclude: Tuple[int] = ()) -> None: + if 'notify_users' not in config['bot']: + _logger.error('notify_all() called but no notify_users directive found in the config') + return + + for user_id in config['bot']['notify_users']: + if user_id in exclude: + continue + + text = text_getter(db.get_user_lang(user_id)) + _updater.bot.send_message(chat_id=user_id, + text=text, + parse_mode='HTML') + + +def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None: + if isinstance(text, Exception): + text = exc2text(text) + _updater.bot.send_message(chat_id=user_id, + text=text, + parse_mode='HTML', + **kwargs) + + +def send_photo(user_id, **kwargs): + _updater.bot.send_photo(chat_id=user_id, **kwargs) + + +def send_audio(user_id, **kwargs): + _updater.bot.send_audio(chat_id=user_id, **kwargs) + + +def send_file(user_id, **kwargs): + _updater.bot.send_document(chat_id=user_id, **kwargs) + + +def edit_message_text(user_id, message_id, *args, **kwargs): + _updater.bot.edit_message_text(chat_id=user_id, + message_id=message_id, + parse_mode='HTML', + *args, **kwargs) + + +def delete_message(user_id, message_id): + _updater.bot.delete_message(chat_id=user_id, message_id=message_id) + + +def set_database(_db: BotDatabase): + global db + db = _db + |