diff options
Diffstat (limited to 'src/home/bot')
-rw-r--r-- | src/home/bot/__init__.py | 6 | ||||
-rw-r--r-- | src/home/bot/errors.py | 2 | ||||
-rw-r--r-- | src/home/bot/lang.py | 76 | ||||
-rw-r--r-- | src/home/bot/reporting.py | 22 | ||||
-rw-r--r-- | src/home/bot/store.py | 80 | ||||
-rw-r--r-- | src/home/bot/util.py | 57 | ||||
-rw-r--r-- | src/home/bot/wrapper.py | 339 |
7 files changed, 582 insertions, 0 deletions
diff --git a/src/home/bot/__init__.py b/src/home/bot/__init__.py new file mode 100644 index 0000000..5e68af7 --- /dev/null +++ b/src/home/bot/__init__.py @@ -0,0 +1,6 @@ +from .reporting import ReportingHelper +from .lang import LangPack +from .wrapper import Wrapper, Context, text_filter +from .store import Store +from .errors import * +from .util import command_usage, user_any_name
\ No newline at end of file diff --git a/src/home/bot/errors.py b/src/home/bot/errors.py new file mode 100644 index 0000000..74eee6f --- /dev/null +++ b/src/home/bot/errors.py @@ -0,0 +1,2 @@ +class StoreNotEnabledError(Exception): + pass
\ No newline at end of file diff --git a/src/home/bot/lang.py b/src/home/bot/lang.py new file mode 100644 index 0000000..2f10358 --- /dev/null +++ b/src/home/bot/lang.py @@ -0,0 +1,76 @@ +import logging + +from typing import Union, Optional + +logger = logging.getLogger(__name__) + + +class LangStrings(dict): + _lang: Optional[str] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._lang = None + + def setlang(self, lang: str): + self._lang = lang + + def __missing__(self, key): + logger.warning(f'key {key} is missing in language {self._lang}') + return '{%s}' % key + + def __setitem__(self, key, value): + raise NotImplementedError(f'setting translation strings this way is prohibited (was trying to set {key}={value})') + + +class LangPack: + strings: dict[str, LangStrings[str, str]] + default_lang: str + + def __init__(self): + self.strings = {} + self.default_lang = 'en' + + def ru(self, **kwargs) -> None: + self.set(kwargs, 'ru') + + def en(self, **kwargs) -> None: + self.set(kwargs, 'en') + + def set(self, + strings: Union[LangStrings, dict], + lang: str) -> None: + + if isinstance(strings, dict) and not isinstance(strings, LangStrings): + strings = LangStrings(**strings) + strings.setlang(lang) + + if lang not in self.strings: + self.strings[lang] = strings + else: + self.strings[lang].update(strings) + + def all(self, key): + result = [] + for strings in self.strings.values(): + result.append(strings[key]) + return result + + @property + def languages(self) -> list[str]: + return list(self.strings.keys()) + + def get(self, key: str, lang: str, *args) -> str: + return self.strings[lang][key] % args + + def __call__(self, *args, **kwargs): + return self.strings[self.default_lang][args[0]] + + def __getitem__(self, key): + return self.strings[self.default_lang][key] + + def __setitem__(self, key, value): + raise NotImplementedError('setting translation strings this way is prohibited') + + def __contains__(self, key): + return key in self.strings[self.default_lang] diff --git a/src/home/bot/reporting.py b/src/home/bot/reporting.py new file mode 100644 index 0000000..df3da2a --- /dev/null +++ b/src/home/bot/reporting.py @@ -0,0 +1,22 @@ +import logging + +from telegram import Message +from ..api import WebAPIClient as APIClient +from ..api.errors import ApiResponseError +from ..api.types import BotType + +logger = logging.getLogger(__name__) + + +class ReportingHelper: + def __init__(self, client: APIClient, bot_type: BotType): + self.client = client + self.bot_type = bot_type + + def report(self, message, text: str = None) -> None: + if text is None: + text = message.text + try: + self.client.log_bot_request(self.bot_type, message.chat_id, text) + except ApiResponseError as error: + logger.exception(error) diff --git a/src/home/bot/store.py b/src/home/bot/store.py new file mode 100644 index 0000000..aeedc47 --- /dev/null +++ b/src/home/bot/store.py @@ -0,0 +1,80 @@ +import sqlite3 +import os.path +import logging + +from ..config import config + +logger = logging.getLogger(__name__) + + +def _get_database_path() -> str: + return os.path.join(os.environ['HOME'], '.config', config.app_name, 'bot.db') + + +class Store: + SCHEMA_VERSION = 1 + + def __init__(self): + self.sqlite = sqlite3.connect(_get_database_path(), check_same_thread=False) + + sqlite_version = self._get_sqlite_version() + logger.info(f'SQLite version: {sqlite_version}') + + schema_version = self._get_schema_version() + logger.info(f'Schema version: {schema_version}') + + if schema_version < 1: + self._database_init() + elif schema_version < Store.SCHEMA_VERSION: + self._database_upgrade(Store.SCHEMA_VERSION) + + def __del__(self): + if self.sqlite: + self.sqlite.commit() + self.sqlite.close() + + def _get_sqlite_version(self) -> str: + cursor = self.sqlite.cursor() + cursor.execute("SELECT sqlite_version()") + + return cursor.fetchone()[0] + + def _get_schema_version(self) -> int: + cursor = self.sqlite.execute('PRAGMA user_version') + return int(cursor.fetchone()[0]) + + def _set_schema_version(self, v) -> None: + self.sqlite.execute('PRAGMA user_version={:d}'.format(v)) + logger.info(f'Schema set to {v}') + + def _database_init(self) -> None: + cursor = self.sqlite.cursor() + cursor.execute("""CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + lang TEXT NOT NULL + )""") + self.sqlite.commit() + self._set_schema_version(1) + + def _database_upgrade(self, version: int) -> None: + # do the upgrade here + + # self.sqlite.commit() + self._set_schema_version(version) + + def get_user_lang(self, user_id: int, default: str = 'en') -> str: + cursor = self.sqlite.cursor() + cursor.execute('SELECT lang FROM users WHERE id=?', (user_id,)) + row = cursor.fetchone() + + if row is None: + cursor.execute('INSERT INTO users (id, lang) VALUES (?, ?)', (user_id, default)) + self.sqlite.commit() + return default + else: + return row[0] + + def set_user_lang(self, user_id: int, lang: str) -> None: + cursor = self.sqlite.cursor() + cursor.execute('UPDATE users SET lang=? WHERE id=?', (lang, user_id)) + self.sqlite.commit()
\ No newline at end of file diff --git a/src/home/bot/util.py b/src/home/bot/util.py new file mode 100644 index 0000000..4f80a67 --- /dev/null +++ b/src/home/bot/util.py @@ -0,0 +1,57 @@ +from telegram import User +from .lang import LangStrings + +_strings = { + 'en': LangStrings( + usage='Usage', + arguments='Arguments' + ), + 'ru': LangStrings( + usage='Использование', + arguments='Аргументы' + ) +} + + +def command_usage(command: str, arguments: dict, language='en') -> str: + if language not in _strings: + raise ValueError('unsupported language') + + blocks = [] + argument_names = [] + argument_lines = [] + for k, v in arguments.items(): + argument_names.append(k) + argument_lines.append( + f'<code>{k}</code>: {v}' + ) + + command = f'/{command}' + if argument_names: + command += ' ' + ' '.join(argument_names) + + blocks.append( + f'<b>{_strings[language]["usage"]}</b>\n' + f'<code>{command}</code>' + ) + + if argument_lines: + blocks.append( + f'<b>{_strings[language]["arguments"]}</b>\n' + '\n'.join(argument_lines) + ) + + return '\n\n'.join(blocks) + + +def user_any_name(user: User) -> str: + name = [user.first_name, user.last_name] + name = list(filter(lambda s: s is not None, name)) + name = ' '.join(name).strip() + + if not name: + name = user.username + + if not name: + name = str(user.id) + + return name diff --git a/src/home/bot/wrapper.py b/src/home/bot/wrapper.py new file mode 100644 index 0000000..8651e90 --- /dev/null +++ b/src/home/bot/wrapper.py @@ -0,0 +1,339 @@ +import logging +import traceback + +from html import escape +from telegram import ( + Update, + ParseMode, + ReplyKeyboardMarkup, + CallbackQuery, + User, +) +from telegram.ext import ( + Updater, + Filters, + BaseFilter, + Handler, + CommandHandler, + MessageHandler, + CallbackQueryHandler, + CallbackContext, + ConversationHandler +) +from telegram.error import TimedOut +from ..config import config +from typing import Optional, Union +from .store import Store +from .lang import LangPack +from ..api.types import BotType +from ..api import WebAPIClient +from .reporting import ReportingHelper + +logger = logging.getLogger(__name__) +languages = { + 'en': 'English', + 'ru': 'Русский' +} +LANG_STARTED = range(1) +user_filter: Optional[BaseFilter] = None + + +def default_langpack() -> LangPack: + lang = LangPack() + lang.en( + start_message="Select command on the keyboard.", + unknown_message="Unknown message", + cancel="Cancel", + select_language="Select language on the keyboard.", + invalid_language="Invalid language. Please try again.", + language_saved='Saved.', + ) + lang.ru( + start_message="Выберите команду на клавиатуре.", + unknown_message="Неизвестная команда", + cancel="Отмена", + select_language="Выберите язык на клавиатуре.", + invalid_language="Неверный язык. Пожалуйста, попробуйте снова", + language_saved="Настройки сохранены." + ) + return lang + + +def init_user_filter(): + global user_filter + if user_filter is None: + if 'users' in config['bot']: + logger.info('allowed users: ' + str(config['bot']['users'])) + user_filter = Filters.user(config['bot']['users']) + else: + user_filter = Filters.all # not sure if this is correct + + +def text_filter(*args): + init_user_filter() + return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & user_filter + + +def exc2text(e: Exception) -> str: + tb = ''.join(traceback.format_tb(e.__traceback__)) + return f'{e.__class__.__name__}: ' + escape(str(e)) + "\n\n" + escape(tb) + + +class IgnoreMarkup: + pass + + +class Context: + _update: Optional[Update] + _callback_context: Optional[CallbackContext] + _markup_getter: callable + _lang: LangPack + _store: Optional[Store] + _user_lang: Optional[str] + + def __init__(self, + update: Optional[Update], + callback_context: Optional[CallbackContext], + markup_getter: callable, + lang: LangPack, + store: Optional[Store]): + self._update = update + self._callback_context = callback_context + self._markup_getter = markup_getter + self._lang = lang + self._store = store + self._user_lang = None + + def reply(self, text, markup=None): + if markup is None: + markup = self._markup_getter(self) + kwargs = dict(parse_mode=ParseMode.HTML) + if not isinstance(markup, IgnoreMarkup): + kwargs['reply_markup'] = markup + self._update.message.reply_text(text, **kwargs) + + def reply_exc(self, e: Exception) -> None: + self.reply(exc2text(e)) + + def answer(self, text: str = None): + self.callback_query.answer(text) + + def edit(self, text, markup=None): + kwargs = dict(parse_mode=ParseMode.HTML) + if not isinstance(markup, IgnoreMarkup): + kwargs['reply_markup'] = markup + self.callback_query.edit_message_text(text, **kwargs) + + @property + def text(self) -> str: + return self._update.message.text + + @property + def callback_query(self) -> CallbackQuery: + return self._update.callback_query + + @property + def args(self) -> Optional[list[str]]: + return self._callback_context.args + + @property + def user_id(self) -> int: + return self.user.id + + @property + def user(self) -> User: + return self._update.effective_user + + @property + def user_lang(self) -> str: + if self._user_lang is None: + self._user_lang = self._store.get_user_lang(self.user_id) + return self._user_lang + + def lang(self, key: str, *args) -> str: + return self._lang.get(key, self.user_lang, *args) + + def is_callback_context(self) -> bool: + return self._update.callback_query and self._update.callback_query.data and self._update.callback_query.data != '' + + +class Wrapper: + store: Optional[Store] + updater: Updater + lang: LangPack + reporting: Optional[ReportingHelper] + + def __init__(self): + self.updater = Updater(config['bot']['token'], + request_kwargs={'read_timeout': 6, 'connect_timeout': 7}) + self.lang = default_langpack() + self.store = Store() + self.reporting = None + + init_user_filter() + + dispatcher = self.updater.dispatcher + dispatcher.add_handler(CommandHandler('start', self.wrap(self.start), user_filter)) + + # transparently log all messages + self.add_handler(MessageHandler(Filters.all & user_filter, self.logging_message_handler), group=10) + self.add_handler(CallbackQueryHandler(self.logging_callback_handler), group=10) + + def run(self): + self._lang_setup() + self.updater.dispatcher.add_handler( + MessageHandler(Filters.all & user_filter, self.wrap(self.any)) + ) + + # start the bot + self.updater.start_polling() + + # run the bot until the user presses Ctrl-C or the process receives SIGINT, SIGTERM or SIGABRT + self.updater.idle() + + def enable_logging(self, bot_type: BotType): + api = WebAPIClient(timeout=3) + api.enable_async() + + self.reporting = ReportingHelper(api, bot_type) + + def logging_message_handler(self, update: Update, context: CallbackContext): + if self.reporting is None: + return + + self.reporting.report(update.message) + + def logging_callback_handler(self, update: Update, context: CallbackContext): + if self.reporting is None: + return + + self.reporting.report(update.callback_query.message, text=update.callback_query.data) + + def wrap(self, f: callable): + def handler(update: Update, context: CallbackContext): + ctx = Context(update, + callback_context=context, + markup_getter=self.markup, + lang=self.lang, + store=self.store) + + try: + return f(ctx) + except Exception as e: + if not self.exception_handler(e, ctx) and not isinstance(e, TimedOut): + logger.exception(e) + if not ctx.is_callback_context(): + ctx.reply_exc(e) + else: + self.notify_user(ctx.user_id, exc2text(e)) + + return handler + + def add_handler(self, handler: Handler, group=0): + self.updater.dispatcher.add_handler(handler, group=group) + + def start(self, ctx: Context): + if 'start_message' not in self.lang: + ctx.reply('Please define start_message or override start()') + return + + ctx.reply(ctx.lang('start_message')) + + def any(self, ctx: Context): + if 'invalid_command' not in self.lang: + ctx.reply('Please define invalid_command or override any()') + return + + ctx.reply(ctx.lang('invalid_command')) + + def markup(self, ctx: Optional[Context]) -> Optional[ReplyKeyboardMarkup]: + return None + + def exception_handler(self, e: Exception, ctx: Context) -> Optional[bool]: + pass + + def notify_all(self, text_getter: callable, exclude: tuple[int] = ()) -> None: + if 'notify_users' not in config['bot']: + logger.error('notify_all() called but no notify_users directive found in the config') + return + + for user_id in config['bot']['notify_users']: + if user_id in exclude: + continue + + text = text_getter(self.store.get_user_lang(user_id)) + self.updater.bot.send_message(chat_id=user_id, + text=text, + parse_mode='HTML') + + def notify_user(self, user_id: int, text: Union[str, Exception]) -> None: + if isinstance(text, Exception): + text = exc2text(text) + self.updater.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML') + + def send_audio(self, user_id, **kwargs): + self.updater.bot.send_audio(chat_id=user_id, **kwargs) + + def send_file(self, user_id, **kwargs): + self.updater.bot.send_document(chat_id=user_id, **kwargs) + + # + # Language Selection + # + + def _lang_setup(self): + supported = self.lang.languages + if len(supported) > 1: + cancel_filter = Filters.text(self.lang.all('cancel')) + + self.add_handler(ConversationHandler( + entry_points=[CommandHandler('lang', self.wrap(self._lang_command), user_filter)], + states={ + LANG_STARTED: [ + *list(map(lambda key: MessageHandler(text_filter(languages[key]), + self.wrap(self._lang_input)), supported)), + MessageHandler(user_filter & ~cancel_filter, self.wrap(self._lang_invalid_input)) + ] + }, + fallbacks=[MessageHandler(user_filter & cancel_filter, self.wrap(self._lang_cancel_input))] + )) + + def _lang_command(self, ctx: Context): + logger.debug(f'current language: {ctx.user_lang}') + + buttons = [] + for name in languages.values(): + buttons.append(name) + markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False) + + ctx.reply(ctx.lang('select_language'), markup=markup) + return LANG_STARTED + + def _lang_input(self, ctx: Context): + lang = None + for key, value in languages.items(): + if value == ctx.text: + lang = key + break + + if lang is None: + ValueError('could not find the language') + + self.store.set_user_lang(ctx.user_id, lang) + + ctx.reply(ctx.lang('language_saved'), markup=IgnoreMarkup()) + + self.start(ctx) + return ConversationHandler.END + + def _lang_invalid_input(self, ctx: Context): + ctx.reply(self.lang('invalid_language'), markup=IgnoreMarkup()) + return LANG_STARTED + + def _lang_cancel_input(self, ctx: Context): + self.start(ctx) + return ConversationHandler.END + + @property + def user_filter(self): + return user_filter |