diff options
author | Evgeny Zinoviev <me@ch1p.io> | 2022-11-06 20:40:42 +0300 |
---|---|---|
committer | Evgeny Zinoviev <me@ch1p.io> | 2022-11-06 20:53:55 +0300 |
commit | 75ee161b6eb64cf19c8a9718d15047443f3e4ebe (patch) | |
tree | ccebc9cbd2709ad13a14ec00372fdcfe9226cd9f /src/home/telegram | |
parent | 28c67c4510a3bee574b4077be35147dba257c8f7 (diff) |
inverter_bot: refactor and introduce new functions
Diffstat (limited to 'src/home/telegram')
-rw-r--r-- | src/home/telegram/__init__.py | 2 | ||||
-rw-r--r-- | src/home/telegram/_botcontext.py | 85 | ||||
-rw-r--r-- | src/home/telegram/_botdb.py | 32 | ||||
-rw-r--r-- | src/home/telegram/_botlang.py | 117 | ||||
-rw-r--r-- | src/home/telegram/_botutil.py | 47 | ||||
-rw-r--r-- | src/home/telegram/bot.py | 542 |
6 files changed, 824 insertions, 1 deletions
diff --git a/src/home/telegram/__init__.py b/src/home/telegram/__init__.py index 8565b40..a68dae1 100644 --- a/src/home/telegram/__init__.py +++ b/src/home/telegram/__init__.py @@ -1 +1 @@ -from .telegram import send_message, send_photo
\ No newline at end of file +from .telegram import send_message, send_photo diff --git a/src/home/telegram/_botcontext.py b/src/home/telegram/_botcontext.py new file mode 100644 index 0000000..f343eeb --- /dev/null +++ b/src/home/telegram/_botcontext.py @@ -0,0 +1,85 @@ +from typing import Optional, List + +from telegram import Update, ParseMode, User, CallbackQuery +from telegram.ext import CallbackContext + +from ._botdb import BotDatabase +from ._botlang import lang +from ._botutil import IgnoreMarkup, exc2text + + +class Context: + _update: Optional[Update] + _callback_context: Optional[CallbackContext] + _markup_getter: callable + db: Optional[BotDatabase] + _user_lang: Optional[str] + + def __init__(self, + update: Optional[Update], + callback_context: Optional[CallbackContext], + markup_getter: callable, + store: Optional[BotDatabase]): + self._update = update + self._callback_context = callback_context + self._markup_getter = markup_getter + self._store = store + self._user_lang = None + + def reply(self, text, markup=None): + if markup is None: + markup = self._markup_getter(self) + kwargs = dict(parse_mode=ParseMode.HTML) + if not isinstance(markup, IgnoreMarkup): + kwargs['reply_markup'] = markup + return self._update.message.reply_text(text, **kwargs) + + def reply_exc(self, e: Exception) -> None: + self.reply(exc2text(e), markup=IgnoreMarkup()) + + def answer(self, text: str = None): + self.callback_query.answer(text) + + def edit(self, text, markup=None): + kwargs = dict(parse_mode=ParseMode.HTML) + if not isinstance(markup, IgnoreMarkup): + kwargs['reply_markup'] = markup + self.callback_query.edit_message_text(text, **kwargs) + + @property + def text(self) -> str: + return self._update.message.text + + @property + def callback_query(self) -> CallbackQuery: + return self._update.callback_query + + @property + def args(self) -> Optional[List[str]]: + return self._callback_context.args + + @property + def user_id(self) -> int: + return self.user.id + + @property + def user_data(self): + return self._callback_context.user_data + + @property + def user(self) -> User: + return self._update.effective_user + + @property + def user_lang(self) -> str: + if self._user_lang is None: + self._user_lang = self._store.get_user_lang(self.user_id) + return self._user_lang + + def lang(self, key: str, *args) -> str: + return lang.get(key, self.user_lang, *args) + + def is_callback_context(self) -> bool: + return self._update.callback_query \ + and self._update.callback_query.data \ + and self._update.callback_query.data != '' diff --git a/src/home/telegram/_botdb.py b/src/home/telegram/_botdb.py new file mode 100644 index 0000000..9e9cf94 --- /dev/null +++ b/src/home/telegram/_botdb.py @@ -0,0 +1,32 @@ +from home.database.sqlite import SQLiteBase + + +class BotDatabase(SQLiteBase): + def __init__(self): + super().__init__() + + def schema_init(self, version: int) -> None: + if version < 1: + cursor = self.cursor() + cursor.execute("""CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + lang TEXT NOT NULL + )""") + self.commit() + + def get_user_lang(self, user_id: int, default: str = 'en') -> str: + cursor = self.cursor() + cursor.execute('SELECT lang FROM users WHERE id=?', (user_id,)) + row = cursor.fetchone() + + if row is None: + cursor.execute('INSERT INTO users (id, lang) VALUES (?, ?)', (user_id, default)) + self.commit() + return default + else: + return row[0] + + def set_user_lang(self, user_id: int, lang: str) -> None: + cursor = self.cursor() + cursor.execute('UPDATE users SET lang=? WHERE id=?', (lang, user_id)) + self.commit() diff --git a/src/home/telegram/_botlang.py b/src/home/telegram/_botlang.py new file mode 100644 index 0000000..318b8b0 --- /dev/null +++ b/src/home/telegram/_botlang.py @@ -0,0 +1,117 @@ +import logging + +from typing import Optional, Dict, List, Union + +_logger = logging.getLogger(__name__) + + +class LangStrings(dict): + _lang: Optional[str] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._lang = None + + def setlang(self, lang: str): + self._lang = lang + + def __missing__(self, key): + _logger.warning(f'key {key} is missing in language {self._lang}') + return '{%s}' % key + + def __setitem__(self, key, value): + raise NotImplementedError(f'setting translation strings this way is prohibited (was trying to set {key}={value})') + + +class LangPack: + strings: Dict[str, LangStrings[str, str]] + default_lang: str + + def __init__(self): + self.strings = {} + self.default_lang = 'en' + + def ru(self, **kwargs) -> None: + self.set(kwargs, 'ru') + + def en(self, **kwargs) -> None: + self.set(kwargs, 'en') + + def set(self, + strings: Union[LangStrings, dict], + lang: str) -> None: + + if isinstance(strings, dict) and not isinstance(strings, LangStrings): + strings = LangStrings(**strings) + strings.setlang(lang) + + if lang not in self.strings: + self.strings[lang] = strings + else: + self.strings[lang].update(strings) + + def all(self, key): + result = [] + for strings in self.strings.values(): + result.append(strings[key]) + return result + + @property + def languages(self) -> List[str]: + return list(self.strings.keys()) + + def get(self, key: str, lang: str, *args) -> str: + if args: + return self.strings[lang][key] % args + else: + return self.strings[lang][key] + + def __call__(self, *args, **kwargs): + return self.strings[self.default_lang][args[0]] + + def __getitem__(self, key): + return self.strings[self.default_lang][key] + + def __setitem__(self, key, value): + raise NotImplementedError('setting translation strings this way is prohibited') + + def __contains__(self, key): + return key in self.strings[self.default_lang] + + @staticmethod + def pfx(prefix: str, l: list) -> list: + return list(map(lambda s: f'{prefix}{s}', l)) + + + +languages = { + 'en': 'English', + 'ru': 'Русский' +} + + +lang = LangPack() +lang.en( + en='English', + ru='Russian', + start_message="Select command on the keyboard.", + unknown_message="Unknown message", + cancel="🚫 Cancel", + back='🔙 Back', + select_language="Select language on the keyboard.", + invalid_language="Invalid language. Please try again.", + saved='Saved.', + please_wait="⏳ Please wait..." +) +lang.ru( + en='Английский', + ru='Русский', + start_message="Выберите команду на клавиатуре.", + unknown_message="Неизвестная команда", + cancel="🚫 Отмена", + back='🔙 Назад', + select_language="Выберите язык на клавиатуре.", + invalid_language="Неверный язык. Пожалуйста, попробуйте снова", + saved="Настройки сохранены.", + please_wait="⏳ Ожидайте..." +)
\ No newline at end of file diff --git a/src/home/telegram/_botutil.py b/src/home/telegram/_botutil.py new file mode 100644 index 0000000..6d1ee8f --- /dev/null +++ b/src/home/telegram/_botutil.py @@ -0,0 +1,47 @@ +import logging +import traceback + +from html import escape +from telegram import User +from home.api import WebAPIClient as APIClient +from home.api.types import BotType +from home.api.errors import ApiResponseError + +_logger = logging.getLogger(__name__) + + +def user_any_name(user: User) -> str: + name = [user.first_name, user.last_name] + name = list(filter(lambda s: s is not None, name)) + name = ' '.join(name).strip() + + if not name: + name = user.username + + if not name: + name = str(user.id) + + return name + + +class ReportingHelper: + def __init__(self, client: APIClient, bot_type: BotType): + self.client = client + self.bot_type = bot_type + + def report(self, message, text: str = None) -> None: + if text is None: + text = message.text + try: + self.client.log_bot_request(self.bot_type, message.chat_id, text) + except ApiResponseError as error: + _logger.exception(error) + + +def exc2text(e: Exception) -> str: + tb = ''.join(traceback.format_tb(e.__traceback__)) + return f'{e.__class__.__name__}: ' + escape(str(e)) + "\n\n" + escape(tb) + + +class IgnoreMarkup: + pass diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py new file mode 100644 index 0000000..602573b --- /dev/null +++ b/src/home/telegram/bot.py @@ -0,0 +1,542 @@ +from __future__ import annotations + +import logging + +from enum import Enum, auto +from functools import wraps +from typing import Optional, Union, List, Tuple, Dict + +from telegram import ( + Update, + ParseMode, + ReplyKeyboardMarkup, + CallbackQuery, + User, + Message, +) +from telegram.ext import ( + Updater, + Filters, + BaseFilter, + Handler, + CommandHandler, + MessageHandler, + CallbackQueryHandler, + CallbackContext, + ConversationHandler +) +from telegram.error import TimedOut + +from home.config import config +from home.api import WebAPIClient +from home.api.types import BotType +from home.api.errors import ApiResponseError + +from ._botlang import lang, languages +from ._botdb import BotDatabase +from ._botutil import ReportingHelper, exc2text, IgnoreMarkup +from ._botcontext import Context + + +# LANG_STARTED, = range(1) + +user_filter: Optional[BaseFilter] = None +cancel_filter = Filters.text(lang.all('cancel')) +back_filter = Filters.text(lang.all('back')) +cancel_and_back_filter = Filters.text(lang.all('back') + lang.all('cancel')) + +db: Optional[BotDatabase] = None + +_logger = logging.getLogger(__name__) +_updater: Optional[Updater] = None +_reporting: Optional[ReportingHelper] = None +_exception_handler: Optional[callable] = None +_dispatcher = None +_markup_getter: Optional[callable] = None +_start_handler_ref: Optional[callable] = None + + +def text_filter(*args): + if not user_filter: + raise RuntimeError('user_filter is not initialized') + return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & user_filter + + +def _handler_of_handler(*args, **kwargs): + self = None + context = None + update = None + + _args = list(args) + while len(_args): + v = _args[0] + if isinstance(v, conversation): + self = v + _args.pop(0) + elif isinstance(v, Update): + update = v + _args.pop(0) + elif isinstance(v, CallbackContext): + context = v + _args.pop(0) + break + + ctx = Context(update, + callback_context=context, + markup_getter=lambda _ctx: None if not _markup_getter else _markup_getter(_ctx), + store=db) + try: + _args.insert(0, ctx) + if self: + _args.insert(0, self) + + f = kwargs['f'] + del kwargs['f'] + + if 'return_with_context' in kwargs: + return_with_context = True + del kwargs['return_with_context'] + else: + return_with_context = False + + result = f(*_args, **kwargs) + return result if not return_with_context else (result, ctx) + + except Exception as e: + if _exception_handler: + if not _exception_handler(e, ctx) and not isinstance(e, TimedOut): + _logger.exception(e) + if not ctx.is_callback_context(): + ctx.reply_exc(e) + else: + notify_user(ctx.user_id, exc2text(e)) + + +def handler(**kwargs): + def inner(f): + @wraps(f) + def _handler(*args, **kwargs): + return _handler_of_handler(f=f, *args, **kwargs) + + if 'message' in kwargs: + _updater.dispatcher.add_handler(MessageHandler(text_filter(lang.all(kwargs['message'])), _handler), group=0) + elif 'command' in kwargs: + _updater.dispatcher.add_handler(CommandHandler(kwargs['command'], _handler), group=0) + elif 'callback' in kwargs: + _updater.dispatcher.add_handler(CallbackQueryHandler(_handler), group=0) + return _handler + return inner + + +def simplehandler(f: callable): + @wraps(f) + def _handler(*args, **kwargs): + return _handler_of_handler(f=f, *args, **kwargs) + return _handler + + +def callbackhandler(f: callable): + @wraps(f) + def _handler(*args, **kwargs): + return _handler_of_handler(f=f, *args, **kwargs) + _updater.dispatcher.add_handler(CallbackQueryHandler(_handler), group=0) + return _handler + + +def exceptionhandler(f: callable): + global _exception_handler + if _exception_handler: + _logger.warning('exception handler already set, we will overwrite it') + _exception_handler = f + + +def defaultreplymarkup(f: callable): + global _markup_getter + _markup_getter = f + + +def convinput(state, is_enter=False, **kwargs): + def inner(f): + f.__dict__['_conv_data'] = dict( + orig_f=f, + enter=is_enter, + type=ConversationMethodType.ENTRY if is_enter and state == 0 else ConversationMethodType.STATE_HANDLER, + state=state, + **kwargs + ) + + @wraps(f) + def _impl(*args, **kwargs): + result, ctx = _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) + if result == conversation.END: + start(ctx) + return result + + return _impl + + return inner + + +def conventer(state, **kwargs): + return convinput(state, is_enter=True, **kwargs) + + +class ConversationMethodType(Enum): + ENTRY = auto() + STATE_HANDLER = auto() + + +class conversation: + END = ConversationHandler.END + STATE_SEQS = [] + + def __init__(self, enable_back=False): + self._logger = logging.getLogger(self.__class__.__name__) + self._user_state_cache = {} + self._back_enabled = enable_back + + def make_handlers(self, f: callable, **kwargs) -> list: + messages = {} + handlers = [] + + if 'messages' in kwargs: + if isinstance(kwargs['messages'], dict): + messages = kwargs['messages'] + else: + for m in kwargs['messages']: + messages[m] = None + + if 'message' in kwargs: + if isinstance(kwargs['message'], str): + messages[kwargs['message']] = None + else: + AttributeError('invalid message type: ' + type(kwargs['message'])) + + if messages: + for message, target_state in messages.items(): + if not target_state: + handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), f)) + else: + handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), self.make_invoker(target_state))) + + if 'regex' in kwargs: + handlers.append(MessageHandler(Filters.regex(kwargs['regex']) & user_filter, f)) + + if 'command' in kwargs: + handlers.append(CommandHandler(kwargs['command'], f, user_filter)) + + return handlers + + def make_invoker(self, state): + def _invoke(update: Update, context: CallbackContext): + ctx = Context(update, + callback_context=context, + markup_getter=lambda _ctx: None if not _markup_getter else _markup_getter(_ctx), + store=db) + return self.invoke(state, ctx) + return _invoke + + def invoke(self, state, ctx: Context): + self._logger.debug(f'invoke, state={state}') + for item in dir(self): + f = getattr(self, item) + if not callable(f) or item.startswith('_') or '_conv_data' not in f.__dict__: + continue + cd = f.__dict__['_conv_data'] + if cd['enter'] and cd['state'] == state: + return cd['orig_f'](self, ctx) + + raise RuntimeError(f'invoke: failed to find method for state {state}') + + def get_handler(self) -> ConversationHandler: + entry_points = [] + states = {} + + l_cancel_filter = cancel_filter if not self._back_enabled else cancel_and_back_filter + + for item in dir(self): + f = getattr(self, item) + if not callable(f) or item.startswith('_') or '_conv_data' not in f.__dict__: + continue + + cd = f.__dict__['_conv_data'] + + if cd['type'] == ConversationMethodType.ENTRY: + entry_points = self.make_handlers(f, **cd) + elif cd['type'] == ConversationMethodType.STATE_HANDLER: + states[cd['state']] = self.make_handlers(f, **cd) + states[cd['state']].append( + MessageHandler(user_filter & ~l_cancel_filter, conversation.invalid) + ) + + fallbacks = [MessageHandler(user_filter & cancel_filter, self.cancel)] + if self._back_enabled: + fallbacks.append(MessageHandler(user_filter & back_filter, self.back)) + + return ConversationHandler( + entry_points=entry_points, + states=states, + fallbacks=fallbacks + ) + + def get_user_state(self, user_id: int) -> Optional[int]: + if user_id not in self._user_state_cache: + return None + return self._user_state_cache[user_id] + + # TODO store in ctx.user_state + def set_user_state(self, user_id: int, state: Union[int, None]): + if not self._back_enabled: + return + if state is not None: + self._user_state_cache[user_id] = state + else: + del self._user_state_cache[user_id] + + @staticmethod + @simplehandler + def invalid(ctx: Context): + ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) + # return 0 # FIXME is this needed + + @simplehandler + def cancel(self, ctx: Context): + start(ctx) + self.set_user_state(ctx.user_id, None) + return conversation.END + + @simplehandler + def back(self, ctx: Context): + cur_state = self.get_user_state(ctx.user_id) + if cur_state is None: + start(ctx) + self.set_user_state(ctx.user_id, None) + return conversation.END + + new_state = None + for seq in self.STATE_SEQS: + if cur_state in seq: + idx = seq.index(cur_state) + if idx > 0: + return self.invoke(seq[idx-1], ctx) + + if new_state is None: + raise RuntimeError('failed to determine state to go back to') + + @classmethod + def add_cancel_button(cls, ctx: Context, buttons): + buttons.append([ctx.lang('cancel')]) + + @classmethod + def add_back_button(cls, ctx: Context, buttons): + # buttons.insert(0, [ctx.lang('back')]) + buttons.append([ctx.lang('back')]) + + def reply(self, + ctx: Context, + state: Union[int, Enum], + text: str, + buttons: Optional[list], + with_cancel=False, + with_back=False, + buttons_lang_completed=False): + + if buttons: + new_buttons = [] + if not buttons_lang_completed: + for item in buttons: + if isinstance(item, list): + item = map(lambda s: ctx.lang(s), item) + new_buttons.append(list(item)) + elif isinstance(item, str): + new_buttons.append([ctx.lang(item)]) + else: + raise ValueError('invalid type: ' + type(item)) + else: + new_buttons = list(buttons) + + buttons = None + else: + if with_cancel or with_back: + new_buttons = [] + else: + new_buttons = None + + if with_cancel: + self.add_cancel_button(ctx, new_buttons) + if with_back: + if not self._back_enabled: + raise AttributeError(f'back is not enabled for this conversation ({self.__class__.__name__})') + self.add_back_button(ctx, new_buttons) + + markup = ReplyKeyboardMarkup(new_buttons, one_time_keyboard=True) if new_buttons else IgnoreMarkup() + ctx.reply(text, markup=markup) + self.set_user_state(ctx.user_id, state) + return state + + +class LangConversation(conversation): + START, = range(1) + + @conventer(START, command='lang') + def entry(self, ctx: Context): + self._logger.debug(f'current language: {ctx.user_lang}') + + buttons = [] + for name in languages.values(): + buttons.append(name) + markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False) + + ctx.reply(ctx.lang('select_language'), markup=markup) + return self.START + + @convinput(START, messages=lang.languages) + def input(self, ctx: Context): + selected_lang = None + for key, value in languages.items(): + if value == ctx.text: + selected_lang = key + break + + if selected_lang is None: + raise ValueError('could not find the language') + + db.set_user_lang(ctx.user_id, selected_lang) + ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) + + return self.END + + +def initialize(): + global user_filter + global _updater + global _dispatcher + + # init user_filter + if 'users' in config['bot']: + _logger.info('allowed users: ' + str(config['bot']['users'])) + user_filter = Filters.user(config['bot']['users']) + else: + user_filter = Filters.all # not sure if this is correct + + # init updater + _updater = Updater(config['bot']['token'], + request_kwargs={'read_timeout': 6, 'connect_timeout': 7}) + + # transparently log all messages + _updater.dispatcher.add_handler(MessageHandler(Filters.all & user_filter, _logging_message_handler), group=10) + _updater.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) + + +def run(start_handler=None, any_handler=None): + global db + global _start_handler_ref + + if not start_handler: + start_handler = _default_start_handler + if not any_handler: + any_handler = _default_any_handler + if not db: + db = BotDatabase() + + _start_handler_ref = start_handler + + _updater.dispatcher.add_handler(LangConversation().get_handler(), group=0) + _updater.dispatcher.add_handler(CommandHandler('start', simplehandler(start_handler), user_filter)) + _updater.dispatcher.add_handler(MessageHandler(Filters.all & user_filter, any_handler)) + + _updater.start_polling() + _updater.idle() + + +def add_conversation(conv: conversation) -> None: + _updater.dispatcher.add_handler(conv.get_handler(), group=0) + + +def start(ctx: Context): + return _start_handler_ref(ctx) + + +def _default_start_handler(ctx: Context): + if 'start_message' not in lang: + return ctx.reply('Please define start_message or override start()') + ctx.reply(ctx.lang('start_message')) + + +@simplehandler +def _default_any_handler(ctx: Context): + if 'invalid_command' not in lang: + return ctx.reply('Please define invalid_command or override any()') + ctx.reply(ctx.lang('invalid_command')) + + +def _logging_message_handler(update: Update, context: CallbackContext): + if _reporting: + _reporting.report(update.message) + + +def _logging_callback_handler(update: Update, context: CallbackContext): + if _reporting: + _reporting.report(update.callback_query.message, text=update.callback_query.data) + + +def enable_logging(bot_type: BotType): + api = WebAPIClient(timeout=3) + api.enable_async() + + global _reporting + _reporting = ReportingHelper(api, bot_type) + + +def notify_all(text_getter: callable, + exclude: Tuple[int] = ()) -> None: + if 'notify_users' not in config['bot']: + _logger.error('notify_all() called but no notify_users directive found in the config') + return + + for user_id in config['bot']['notify_users']: + if user_id in exclude: + continue + + text = text_getter(db.get_user_lang(user_id)) + _updater.bot.send_message(chat_id=user_id, + text=text, + parse_mode='HTML') + + +def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None: + if isinstance(text, Exception): + text = exc2text(text) + _updater.bot.send_message(chat_id=user_id, + text=text, + parse_mode='HTML', + **kwargs) + + +def send_photo(user_id, **kwargs): + _updater.bot.send_photo(chat_id=user_id, **kwargs) + + +def send_audio(user_id, **kwargs): + _updater.bot.send_audio(chat_id=user_id, **kwargs) + + +def send_file(user_id, **kwargs): + _updater.bot.send_document(chat_id=user_id, **kwargs) + + +def edit_message_text(user_id, message_id, *args, **kwargs): + _updater.bot.edit_message_text(chat_id=user_id, + message_id=message_id, + parse_mode='HTML', + *args, **kwargs) + + +def delete_message(user_id, message_id): + _updater.bot.delete_message(chat_id=user_id, message_id=message_id) + + +def set_database(_db: BotDatabase): + global db + db = _db + |