from __future__ import annotations

import inspect
import logging
import itertools

from enum import Enum, auto
from functools import wraps
from typing import Optional, Union, Tuple, Coroutine

from telegram import Update, ReplyKeyboardMarkup
from telegram.ext import (
    Application,
    filters,
    CommandHandler,
    MessageHandler,
    CallbackQueryHandler,
    CallbackContext,
    ConversationHandler
)
from telegram.ext.filters import BaseFilter
from telegram.error import TimedOut

from homekit.config import config
from homekit.api import WebApiClient
from homekit.api.types import BotType

from ._botlang import lang, languages
from ._botdb import BotDatabase
from ._botutil import ReportingHelper, exc2text, IgnoreMarkup, user_any_name
from ._botcontext import Context


# Module-wide state. Most of it is populated by initialize(), run() and the
# registration helpers below.
db: Optional[BotDatabase] = None

_user_filter: Optional[BaseFilter] = None
_cancel_filter = filters.Text(lang.all('cancel'))
_back_filter = filters.Text(lang.all('back'))
_cancel_and_back_filter = filters.Text(lang.all('back') + lang.all('cancel'))

_logger = logging.getLogger(__name__)
_application: Optional[Application] = None
_reporting: Optional[ReportingHelper] = None
_exception_handler: Optional[Coroutine] = None
_dispatcher = None
_markup_getter: Optional[callable] = None
_start_handler_ref: Optional[Coroutine] = None


def text_filter(*args):
    """Build a text filter restricted to the configured users.

    Accepts either a single list of strings or the strings as separate
    positional arguments.

    Raises:
        RuntimeError: if the user filter has not been set yet (see initialize()).
    """
    if not _user_filter:
        raise RuntimeError('user_filter is not initialized')
    return filters.Text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter


async def _handler_of_handler(*args, **kwargs):
    """Common trampoline used by every handler decorator in this module.

    Strips the leading ``conversation`` instance, ``Update`` and
    ``CallbackContext`` from the positional arguments, builds a ``Context``
    from them and awaits the wrapped coroutine function ``f`` with that
    context.

    Keyword arguments consumed here (everything else is forwarded to ``f``):
        f: the coroutine function to call (required).
        return_with_context: if present, the result is ``(result, ctx)``.
        argument: if equal to ``'message_key'``, the language-pack key whose
            text equals ``ctx.text`` (or ``None``) is inserted before ``ctx``.

    Exceptions raised by ``f`` are not propagated: they are passed to the
    registered exception handler (if any) and/or logged. In that case the
    result is ``None``, or ``(None, ctx)`` when ``return_with_context`` was
    requested.
    """
    self = None
    context = None
    update = None

    _args = list(args)
    while len(_args):
        v = _args[0]
        if isinstance(v, conversation):
            self = v
            _args.pop(0)
        elif isinstance(v, Update):
            update = v
            _args.pop(0)
        elif isinstance(v, CallbackContext):
            context = v
            _args.pop(0)
            break
        else:
            # Unknown leading argument: stop peeling instead of spinning
            # forever on the same element; it is forwarded to f untouched.
            break

    ctx = Context(update,
                  callback_context=context,
                  markup_getter=lambda _ctx: None if not _markup_getter else _markup_getter(_ctx),
                  store=db)

    # Resolved outside the try block so the error path below knows which
    # result shape the caller expects.
    return_with_context = 'return_with_context' in kwargs
    kwargs.pop('return_with_context', None)

    try:
        _args.insert(0, ctx)

        f = kwargs['f']
        del kwargs['f']

        if 'argument' in kwargs and kwargs['argument'] == 'message_key':
            del kwargs['argument']
            mkey = None
            for k, v in lang.get_langpack(ctx.user_lang).items():
                if ctx.text == v:
                    mkey = k
                    break
            _args.insert(0, mkey)

        if self:
            _args.insert(0, self)

        result = await f(*_args, **kwargs)
        return result if not return_with_context else (result, ctx)

    except Exception as e:
        if _exception_handler:
            # The handler is called synchronously; a truthy return value
            # means "handled", which suppresses logging and user notification.
            if not _exception_handler(e, ctx) and not isinstance(e, TimedOut):
                _logger.exception(e)
                if not ctx.is_callback_context():
                    await ctx.reply_exc(e)
                else:
                    notify_user(ctx.user_id, exc2text(e))
        else:
            _logger.exception(e)

    # Keep the (result, ctx) shape on failure so that callers which unpack
    # the pair (see convinput) do not crash with a TypeError.
    return (None, ctx) if return_with_context else None


def handler(**kwargs):
    """Decorator that registers a coroutine function as a bot handler.

    Recognized keyword arguments:
        message / messages: language keys; all their translations are matched.
        text / texts: literal texts to match.
        command: registers a CommandHandler for that command.
        callback: registers a CallbackQueryHandler with that pattern.
        argument: ``'message_key'`` asks for the matched language key to be
            passed to the function (see _handler_of_handler).

    The handlers are added to the application (group 0) at decoration time.
    """
    def inner(f):
        @wraps(f)
        async def _handler(*args, **inner_kwargs):
            if 'argument' in kwargs and kwargs['argument'] == 'message_key':
                inner_kwargs['argument'] = 'message_key'
            return await _handler_of_handler(f=f, *args, **inner_kwargs)

        messages = []
        texts = []

        if 'messages' in kwargs:
            messages += kwargs['messages']
        if 'message' in kwargs:
            messages.append(kwargs['message'])

        if 'text' in kwargs:
            texts.append(kwargs['text'])
        if 'texts' in kwargs:
            texts += kwargs['texts']

        if messages or texts:
            # Expand every language key into all of its translations and
            # de-duplicate the resulting texts.
            new_messages = list(itertools.chain.from_iterable([lang.all(m) for m in messages]))
            texts += new_messages
            texts = list(set(texts))
            _application.add_handler(
                MessageHandler(text_filter(*texts), _handler),
                group=0
            )

        if 'command' in kwargs:
            _application.add_handler(CommandHandler(kwargs['command'], _handler), group=0)

        if 'callback' in kwargs:
            _application.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0)

        return _handler

    return inner


def simplehandler(f: Coroutine):
    """Wrap ``f`` so it is called through _handler_of_handler, without registering it."""
    @wraps(f)
    async def _handler(*args, **kwargs):
        return await _handler_of_handler(f=f, *args, **kwargs)
    return _handler


def callbackhandler(*args, **kwargs):
    """Decorator that registers a callback-query handler.

    Requires a ``callback`` keyword argument; the value ``'*'`` registers the
    handler without a pattern, any other value is used as the pattern.
    """
    def inner(f):
        @wraps(f)
        async def _handler(*args, **kwargs):
            return await _handler_of_handler(f=f, *args, **kwargs)

        pattern_kwargs = {}
        if kwargs['callback'] != '*':
            pattern_kwargs['pattern'] = kwargs['callback']
        _application.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0)
        return _handler

    return inner


async def exceptionhandler(f: callable):
    """Install ``f`` as the module-wide exception handler.

    ``f`` is later called as ``f(exception, ctx)`` by _handler_of_handler.

    NOTE(review): this is declared ``async``, so merely calling it (for
    example using it as a decorator) only creates a coroutine; the handler is
    installed once that coroutine is awaited. The sibling
    ``defaultreplymarkup`` is a plain function - confirm that ``async`` is
    intended here.
    """
    global _exception_handler
    if _exception_handler:
        _logger.warning('exception handler already set, we will overwrite it')
    _exception_handler = f


def defaultreplymarkup(f: callable):
    """Install ``f`` as the markup getter passed to every ``Context``."""
    global _markup_getter
    _markup_getter = f


def convinput(state, is_enter=False, **kwargs):
    """Decorator marking a ``conversation`` method as the handler of ``state``.

    The metadata is stored on the function as ``_conv_data`` and is read by
    ``conversation.get_handler()`` / ``conversation.invoke()``. A method is an
    entry point only when ``is_enter`` is true and ``state == 0``. Extra
    keyword arguments are stored as well and later forwarded to
    ``conversation.make_handlers()``.

    When the wrapped method returns ``conversation.END`` the start handler is
    invoked.
    """
    def inner(f):
        f.__dict__['_conv_data'] = dict(
            orig_f=f,
            enter=is_enter,
            type=ConversationMethodType.ENTRY if is_enter and state == 0 else ConversationMethodType.STATE_HANDLER,
            state=state,
            **kwargs
        )

        @wraps(f)
        async def _impl(*args, **kwargs):
            result, ctx = await _handler_of_handler(f=f, *args, **kwargs, return_with_context=True)
            if result == conversation.END:
                await start(ctx)
            return result

        return _impl

    return inner


def conventer(state, **kwargs):
    """Shortcut for ``convinput(state, is_enter=True, ...)``."""
    return convinput(state, is_enter=True, **kwargs)


class ConversationMethodType(Enum):
    """Role of a method decorated with convinput / conventer."""
    ENTRY = auto()
    STATE_HANDLER = auto()


class conversation:
    """Base class for multi-step conversations built on ConversationHandler.

    Subclasses decorate their step methods with ``conventer`` / ``convinput``
    and may list ordered state sequences in ``STATE_SEQS`` to support the
    "back" button.
    """
    END = ConversationHandler.END
    STATE_SEQS = []

    def __init__(self, enable_back=False):
        self._logger = logging.getLogger(self.__class__.__name__)
        # user_id -> last state set through set_user_state()
        self._user_state_cache = {}
        self._back_enabled = enable_back

    def make_handlers(self, f: callable, **kwargs) -> list:
        """Build the list of telegram handlers for one decorated method.

        Recognized keyword arguments:
            messages: iterable of language keys, or a dict mapping a key to a
                target state (a truthy target routes to that state's entry
                method instead of ``f``).
            message: a single language key (must be a ``str``).
            messages_lang_completed: if present, messages are used verbatim
                instead of being expanded through ``lang.all``.
            regex: pattern for a regex message handler.
            command: command name for a CommandHandler.

        Raises:
            AttributeError: if ``message`` is not a string.
        """
        messages = {}
        handlers = []

        if 'messages' in kwargs:
            if isinstance(kwargs['messages'], dict):
                # Copy, so that adding 'message' below does not mutate the
                # dict stored in the method's _conv_data.
                messages = dict(kwargs['messages'])
            else:
                for m in kwargs['messages']:
                    messages[m] = None

        if 'message' in kwargs:
            if isinstance(kwargs['message'], str):
                messages[kwargs['message']] = None
            else:
                raise AttributeError(f"invalid message type: {type(kwargs['message'])}")

        if messages:
            for message, target_state in messages.items():
                filter_arg = lang.all(message) if 'messages_lang_completed' not in kwargs else message
                callback = f if not target_state else self.make_invoker(target_state)
                handlers.append(MessageHandler(text_filter(filter_arg), callback))

        if 'regex' in kwargs:
            handlers.append(MessageHandler(filters.Regex(kwargs['regex']) & _user_filter, f))

        if 'command' in kwargs:
            handlers.append(CommandHandler(kwargs['command'], f, _user_filter))

        return handlers

    def make_invoker(self, state):
        """Return a handler callback that jumps to the entry method of ``state``."""
        def _invoke(update: Update, context: CallbackContext):
            ctx = Context(update,
                          callback_context=context,
                          markup_getter=lambda _ctx: None if not _markup_getter else _markup_getter(_ctx),
                          store=db)
            return self.invoke(state, ctx)
        return _invoke

    def invoke(self, state, ctx: Context):
        """Call the undecorated ``enter`` method registered for ``state``.

        Returns whatever that method returns (for an ``async def`` method this
        is a coroutine the caller has to await).

        Raises:
            RuntimeError: if no public method is registered as enter for ``state``.
        """
        self._logger.debug(f'invoke, state={state}')
        for item in dir(self):
            f = getattr(self, item)
            if not callable(f) or item.startswith('_') or '_conv_data' not in f.__dict__:
                continue
            cd = f.__dict__['_conv_data']
            if cd['enter'] and cd['state'] == state:
                return cd['orig_f'](self, ctx)

        raise RuntimeError(f'invoke: failed to find method for state {state}')

    def get_handler(self) -> ConversationHandler:
        """Assemble the ConversationHandler from the decorated public methods."""
        entry_points = []
        states = {}

        l_cancel_filter = _cancel_filter if not self._back_enabled else _cancel_and_back_filter

        for item in dir(self):
            f = getattr(self, item)
            if not callable(f) or item.startswith('_') or '_conv_data' not in f.__dict__:
                continue

            cd = f.__dict__['_conv_data']

            if cd['type'] == ConversationMethodType.ENTRY:
                entry_points = self.make_handlers(f, **cd)
            elif cd['type'] == ConversationMethodType.STATE_HANDLER:
                states[cd['state']] = self.make_handlers(f, **cd)
                # Anything else (except cancel/back, handled by the
                # fallbacks) is reported as invalid input.
                states[cd['state']].append(
                    MessageHandler(_user_filter & ~l_cancel_filter, conversation.invalid)
                )

        fallbacks = [MessageHandler(_user_filter & _cancel_filter, self.cancel)]
        if self._back_enabled:
            fallbacks.append(MessageHandler(_user_filter & _back_filter, self.back))

        return ConversationHandler(
            entry_points=entry_points,
            states=states,
            fallbacks=fallbacks
        )

    def get_user_state(self, user_id: int) -> Optional[int]:
        """Return the cached state of ``user_id`` or ``None`` if there is none."""
        return self._user_state_cache.get(user_id)

    # TODO store in ctx.user_state
    def set_user_state(self, user_id: int, state: Union[int, None]):
        """Remember (or, with ``None``, forget) the state of ``user_id``.

        Does nothing unless "back" support is enabled. Forgetting a user that
        has no cached state is a no-op.
        """
        if not self._back_enabled:
            return
        if state is not None:
            self._user_state_cache[user_id] = state
        else:
            self._user_state_cache.pop(user_id, None)

    @staticmethod
    @simplehandler
    async def invalid(ctx: Context):
        """Tell the user that the input was not understood."""
        await ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup())
        # return 0  # FIXME is this needed

    @simplehandler
    async def cancel(self, ctx: Context):
        """Abort the conversation: run the start handler and end."""
        await start(ctx)
        self.set_user_state(ctx.user_id, None)
        return conversation.END

    @simplehandler
    async def back(self, ctx: Context):
        """Go one step back within the state sequence containing the current state.

        Without a cached state this behaves like cancel().

        Raises:
            RuntimeError: if the current state has no predecessor in any
                sequence of ``STATE_SEQS``.
        """
        cur_state = self.get_user_state(ctx.user_id)
        if cur_state is None:
            await start(ctx)
            self.set_user_state(ctx.user_id, None)
            return conversation.END

        for seq in self.STATE_SEQS:
            if cur_state in seq:
                idx = seq.index(cur_state)
                if idx > 0:
                    result = self.invoke(seq[idx - 1], ctx)
                    # invoke() calls the undecorated step method; when that
                    # is a coroutine function the state is only known after
                    # awaiting it.
                    if inspect.isawaitable(result):
                        result = await result
                    return result

        raise RuntimeError('failed to determine state to go back to')

    @classmethod
    def add_cancel_button(cls, ctx: Context, buttons):
        """Append a row with the localized "cancel" button to ``buttons``."""
        buttons.append([ctx.lang('cancel')])

    @classmethod
    def add_back_button(cls, ctx: Context, buttons):
        """Append a row with the localized "back" button to ``buttons``."""
        # buttons.insert(0, [ctx.lang('back')])
        buttons.append([ctx.lang('back')])

    def reply(self,
              ctx: Context,
              state: Union[int, Enum],
              text: str,
              buttons: Optional[list],
              with_cancel=False,
              with_back=False,
              buttons_lang_completed=False):
        """Reply with ``text`` and a keyboard, remember ``state`` and return it.

        Args:
            buttons: rows of the keyboard; each item is a language key or a
                list of language keys (used verbatim when
                ``buttons_lang_completed`` is true).
            with_cancel: append a "cancel" row.
            with_back: append a "back" row.

        Raises:
            ValueError: if an item of ``buttons`` is neither ``str`` nor ``list``.
            AttributeError: if ``with_back`` is requested but back support is
                disabled for this conversation.
        """
        if buttons:
            new_buttons = []
            if not buttons_lang_completed:
                for item in buttons:
                    if isinstance(item, list):
                        item = map(lambda s: ctx.lang(s), item)
                        new_buttons.append(list(item))
                    elif isinstance(item, str):
                        new_buttons.append([ctx.lang(item)])
                    else:
                        raise ValueError(f'invalid type: {type(item)}')
            else:
                new_buttons = list(buttons)

            buttons = None
        else:
            if with_cancel or with_back:
                new_buttons = []
            else:
                new_buttons = None

        if with_cancel:
            self.add_cancel_button(ctx, new_buttons)
        if with_back:
            if not self._back_enabled:
                raise AttributeError(f'back is not enabled for this conversation ({self.__class__.__name__})')
            self.add_back_button(ctx, new_buttons)

        markup = ReplyKeyboardMarkup(new_buttons, one_time_keyboard=True) if new_buttons else IgnoreMarkup()
        # NOTE(review): ctx.reply() is awaited everywhere else in this module,
        # but this method is synchronous and discards the return value. If
        # ctx.reply is a coroutine function, nothing is sent here - confirm
        # and consider making this method async together with its callers.
        ctx.reply(text, markup=markup)
        self.set_user_state(ctx.user_id, state)
        return state


class LangConversation(conversation):
    """Built-in ``/lang`` conversation that lets the user pick a language."""
    START, = range(1)

    @conventer(START, command='lang')
    async def entry(self, ctx: Context):
        """Show the keyboard with the available languages."""
        self._logger.debug(f'current language: {ctx.user_lang}')

        buttons = list(languages.values())
        markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False)

        await ctx.reply(ctx.lang('select_language'), markup=markup)
        return self.START

    @convinput(START, messages=lang.languages)
    async def input(self, ctx: Context):
        """Store the language whose display name equals the received text.

        Raises:
            ValueError: if the text does not match any known language name.
        """
        selected_lang = None
        for key, value in languages.items():
            if value == ctx.text:
                selected_lang = key
                break

        if selected_lang is None:
            raise ValueError('could not find the language')

        db.set_user_lang(ctx.user_id, selected_lang)
        await ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup())

        return self.END


def initialize():
    """Create the user filter and the telegram Application from the config."""
    global _user_filter
    global _application
    # global _updater
    global _dispatcher

    # init user_filter
    _user_ids = config.app_config.get_user_ids()
    if len(_user_ids) > 0:
        _logger.info('allowed users: ' + str(_user_ids))
        _user_filter = filters.User(_user_ids)
    else:
        _user_filter = filters.ALL  # not sure if this is correct

    _application = Application.builder()\
        .token(config.app_config.get('bot.token'))\
        .connect_timeout(7)\
        .read_timeout(6)\
        .build()

    # transparently log all messages
    # _application.dispatcher.add_handler(MessageHandler(filters.ALL & _user_filter, _logging_message_handler), group=10)
    # _application.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10)


def run(start_handler=None, any_handler=None):
    """Register the built-in handlers and start polling.

    Args:
        start_handler: coroutine function called with a ``Context`` for
            ``/start`` (and whenever a conversation ends); defaults to
            ``_default_start_handler``.
        any_handler: catch-all message handler; defaults to
            ``_default_any_handler``.

    A ``BotDatabase`` is created if none was provided via set_database().
    """
    global db
    global _start_handler_ref

    if not start_handler:
        start_handler = _default_start_handler
    if not any_handler:
        any_handler = _default_any_handler
    if not db:
        db = BotDatabase()

    _start_handler_ref = start_handler

    _application.add_handler(LangConversation().get_handler(), group=0)
    _application.add_handler(CommandHandler('start', callback=simplehandler(start_handler), filters=_user_filter))
    _application.add_handler(MessageHandler(filters.ALL & _user_filter, any_handler))

    _application.run_polling()


def add_conversation(conv: conversation) -> None:
    """Register the ConversationHandler of ``conv`` (group 0)."""
    _application.add_handler(conv.get_handler(), group=0)


def add_handler(h):
    """Register an arbitrary telegram handler (group 0)."""
    _application.add_handler(h, group=0)


async def start(ctx: Context):
    """Invoke the start handler selected in run()."""
    return await _start_handler_ref(ctx)


async def _default_start_handler(ctx: Context):
    """Reply with ``start_message``, or with a hint if it is not defined."""
    if 'start_message' not in lang:
        return await ctx.reply('Please define start_message or override start()')
    await ctx.reply(ctx.lang('start_message'))


@simplehandler
async def _default_any_handler(ctx: Context):
    """Reply with ``invalid_command``, or with a hint if it is not defined."""
    if 'invalid_command' not in lang:
        return await ctx.reply('Please define invalid_command or override any()')
    await ctx.reply(ctx.lang('invalid_command'))


def _logging_message_handler(update: Update, context: CallbackContext):
    """Report an incoming message if reporting is enabled."""
    if _reporting:
        _reporting.report(update.message)


def _logging_callback_handler(update: Update, context: CallbackContext):
    """Report an incoming callback query if reporting is enabled."""
    if _reporting:
        _reporting.report(update.callback_query.message, text=update.callback_query.data)


def enable_logging(bot_type: BotType):
    """Enable message reporting through the web API for the given bot type."""
    api = WebApiClient(timeout=3)
    api.enable_async()

    global _reporting
    _reporting = ReportingHelper(api, bot_type)


# NOTE(review): the helpers below call ``_application.bot.*`` methods and
# discard the return value. If these bot methods are coroutine functions (as
# the async handler style used in this module suggests), the returned
# coroutines are never awaited and nothing is sent - confirm against the
# installed python-telegram-bot version before relying on them.

def notify_all(text_getter: callable, exclude: Tuple[int] = ()) -> None:
    """Send a message to every user listed in ``config['bot']['notify_users']``.

    Args:
        text_getter: called with the user's language, returns the text.
        exclude: user ids to skip.
    """
    if 'notify_users' not in config['bot']:
        _logger.error('notify_all() called but no notify_users directive found in the config')
        return

    for user_id in config['bot']['notify_users']:
        if user_id in exclude:
            continue

        text = text_getter(db.get_user_lang(user_id))
        _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML')


def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None:
    """Send ``text`` (or the textual form of an exception) to one user as HTML."""
    if isinstance(text, Exception):
        text = exc2text(text)
    _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML', **kwargs)


def send_photo(user_id, **kwargs):
    """Forward to ``bot.send_photo`` for the given user."""
    _application.bot.send_photo(chat_id=user_id, **kwargs)


def send_audio(user_id, **kwargs):
    """Forward to ``bot.send_audio`` for the given user."""
    _application.bot.send_audio(chat_id=user_id, **kwargs)


def send_file(user_id, **kwargs):
    """Forward to ``bot.send_document`` for the given user."""
    _application.bot.send_document(chat_id=user_id, **kwargs)


def edit_message_text(user_id, message_id, *args, **kwargs):
    """Forward to ``bot.edit_message_text`` with HTML parse mode."""
    _application.bot.edit_message_text(chat_id=user_id, message_id=message_id, parse_mode='HTML', *args, **kwargs)


def delete_message(user_id, message_id):
    """Forward to ``bot.delete_message``."""
    _application.bot.delete_message(chat_id=user_id, message_id=message_id)


def set_database(_db: BotDatabase):
    """Use ``_db`` as the module-wide bot database."""
    global db
    db = _db