summaryrefslogtreecommitdiff
path: root/src/home/bot
diff options
context:
space:
mode:
Diffstat (limited to 'src/home/bot')
-rw-r--r--src/home/bot/__init__.py6
-rw-r--r--src/home/bot/errors.py2
-rw-r--r--src/home/bot/lang.py76
-rw-r--r--src/home/bot/reporting.py22
-rw-r--r--src/home/bot/store.py80
-rw-r--r--src/home/bot/util.py57
-rw-r--r--src/home/bot/wrapper.py339
7 files changed, 582 insertions, 0 deletions
diff --git a/src/home/bot/__init__.py b/src/home/bot/__init__.py
new file mode 100644
index 0000000..5e68af7
--- /dev/null
+++ b/src/home/bot/__init__.py
@@ -0,0 +1,6 @@
+from .reporting import ReportingHelper
+from .lang import LangPack
+from .wrapper import Wrapper, Context, text_filter
+from .store import Store
+from .errors import *
+from .util import command_usage, user_any_name \ No newline at end of file
diff --git a/src/home/bot/errors.py b/src/home/bot/errors.py
new file mode 100644
index 0000000..74eee6f
--- /dev/null
+++ b/src/home/bot/errors.py
@@ -0,0 +1,2 @@
+class StoreNotEnabledError(Exception):
+ pass \ No newline at end of file
diff --git a/src/home/bot/lang.py b/src/home/bot/lang.py
new file mode 100644
index 0000000..2f10358
--- /dev/null
+++ b/src/home/bot/lang.py
@@ -0,0 +1,76 @@
+import logging
+
+from typing import Union, Optional
+
+logger = logging.getLogger(__name__)
+
+
+class LangStrings(dict):
+ _lang: Optional[str]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._lang = None
+
+ def setlang(self, lang: str):
+ self._lang = lang
+
+ def __missing__(self, key):
+ logger.warning(f'key {key} is missing in language {self._lang}')
+ return '{%s}' % key
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError(f'setting translation strings this way is prohibited (was trying to set {key}={value})')
+
+
+class LangPack:
+ strings: dict[str, LangStrings[str, str]]
+ default_lang: str
+
+ def __init__(self):
+ self.strings = {}
+ self.default_lang = 'en'
+
+ def ru(self, **kwargs) -> None:
+ self.set(kwargs, 'ru')
+
+ def en(self, **kwargs) -> None:
+ self.set(kwargs, 'en')
+
+ def set(self,
+ strings: Union[LangStrings, dict],
+ lang: str) -> None:
+
+ if isinstance(strings, dict) and not isinstance(strings, LangStrings):
+ strings = LangStrings(**strings)
+ strings.setlang(lang)
+
+ if lang not in self.strings:
+ self.strings[lang] = strings
+ else:
+ self.strings[lang].update(strings)
+
+ def all(self, key):
+ result = []
+ for strings in self.strings.values():
+ result.append(strings[key])
+ return result
+
+ @property
+ def languages(self) -> list[str]:
+ return list(self.strings.keys())
+
+ def get(self, key: str, lang: str, *args) -> str:
+ return self.strings[lang][key] % args
+
+ def __call__(self, *args, **kwargs):
+ return self.strings[self.default_lang][args[0]]
+
+ def __getitem__(self, key):
+ return self.strings[self.default_lang][key]
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError('setting translation strings this way is prohibited')
+
+ def __contains__(self, key):
+ return key in self.strings[self.default_lang]
diff --git a/src/home/bot/reporting.py b/src/home/bot/reporting.py
new file mode 100644
index 0000000..df3da2a
--- /dev/null
+++ b/src/home/bot/reporting.py
@@ -0,0 +1,22 @@
+import logging
+
+from telegram import Message
+from ..api import WebAPIClient as APIClient
+from ..api.errors import ApiResponseError
+from ..api.types import BotType
+
+logger = logging.getLogger(__name__)
+
+
+class ReportingHelper:
+ def __init__(self, client: APIClient, bot_type: BotType):
+ self.client = client
+ self.bot_type = bot_type
+
+ def report(self, message, text: str = None) -> None:
+ if text is None:
+ text = message.text
+ try:
+ self.client.log_bot_request(self.bot_type, message.chat_id, text)
+ except ApiResponseError as error:
+ logger.exception(error)
diff --git a/src/home/bot/store.py b/src/home/bot/store.py
new file mode 100644
index 0000000..aeedc47
--- /dev/null
+++ b/src/home/bot/store.py
@@ -0,0 +1,80 @@
+import sqlite3
+import os.path
+import logging
+
+from ..config import config
+
+logger = logging.getLogger(__name__)
+
+
+def _get_database_path() -> str:
+ return os.path.join(os.environ['HOME'], '.config', config.app_name, 'bot.db')
+
+
+class Store:
+ SCHEMA_VERSION = 1
+
+ def __init__(self):
+ self.sqlite = sqlite3.connect(_get_database_path(), check_same_thread=False)
+
+ sqlite_version = self._get_sqlite_version()
+ logger.info(f'SQLite version: {sqlite_version}')
+
+ schema_version = self._get_schema_version()
+ logger.info(f'Schema version: {schema_version}')
+
+ if schema_version < 1:
+ self._database_init()
+ elif schema_version < Store.SCHEMA_VERSION:
+ self._database_upgrade(Store.SCHEMA_VERSION)
+
+ def __del__(self):
+ if self.sqlite:
+ self.sqlite.commit()
+ self.sqlite.close()
+
+ def _get_sqlite_version(self) -> str:
+ cursor = self.sqlite.cursor()
+ cursor.execute("SELECT sqlite_version()")
+
+ return cursor.fetchone()[0]
+
+ def _get_schema_version(self) -> int:
+ cursor = self.sqlite.execute('PRAGMA user_version')
+ return int(cursor.fetchone()[0])
+
+ def _set_schema_version(self, v) -> None:
+ self.sqlite.execute('PRAGMA user_version={:d}'.format(v))
+ logger.info(f'Schema set to {v}')
+
+ def _database_init(self) -> None:
+ cursor = self.sqlite.cursor()
+ cursor.execute("""CREATE TABLE IF NOT EXISTS users (
+ id INTEGER PRIMARY KEY,
+ lang TEXT NOT NULL
+ )""")
+ self.sqlite.commit()
+ self._set_schema_version(1)
+
+ def _database_upgrade(self, version: int) -> None:
+ # do the upgrade here
+
+ # self.sqlite.commit()
+ self._set_schema_version(version)
+
+ def get_user_lang(self, user_id: int, default: str = 'en') -> str:
+ cursor = self.sqlite.cursor()
+ cursor.execute('SELECT lang FROM users WHERE id=?', (user_id,))
+ row = cursor.fetchone()
+
+ if row is None:
+ cursor.execute('INSERT INTO users (id, lang) VALUES (?, ?)', (user_id, default))
+ self.sqlite.commit()
+ return default
+ else:
+ return row[0]
+
+ def set_user_lang(self, user_id: int, lang: str) -> None:
+ cursor = self.sqlite.cursor()
+ cursor.execute('UPDATE users SET lang=? WHERE id=?', (lang, user_id))
+ self.sqlite.commit() \ No newline at end of file
diff --git a/src/home/bot/util.py b/src/home/bot/util.py
new file mode 100644
index 0000000..4f80a67
--- /dev/null
+++ b/src/home/bot/util.py
@@ -0,0 +1,57 @@
+from telegram import User
+from .lang import LangStrings
+
+_strings = {
+ 'en': LangStrings(
+ usage='Usage',
+ arguments='Arguments'
+ ),
+ 'ru': LangStrings(
+ usage='Использование',
+ arguments='Аргументы'
+ )
+}
+
+
+def command_usage(command: str, arguments: dict, language='en') -> str:
+ if language not in _strings:
+ raise ValueError('unsupported language')
+
+ blocks = []
+ argument_names = []
+ argument_lines = []
+ for k, v in arguments.items():
+ argument_names.append(k)
+ argument_lines.append(
+ f'<code>{k}</code>: {v}'
+ )
+
+ command = f'/{command}'
+ if argument_names:
+ command += ' ' + ' '.join(argument_names)
+
+ blocks.append(
+ f'<b>{_strings[language]["usage"]}</b>\n'
+ f'<code>{command}</code>'
+ )
+
+ if argument_lines:
+ blocks.append(
+ f'<b>{_strings[language]["arguments"]}</b>\n' + '\n'.join(argument_lines)
+ )
+
+ return '\n\n'.join(blocks)
+
+
+def user_any_name(user: User) -> str:
+ name = [user.first_name, user.last_name]
+ name = list(filter(lambda s: s is not None, name))
+ name = ' '.join(name).strip()
+
+ if not name:
+ name = user.username
+
+ if not name:
+ name = str(user.id)
+
+ return name
diff --git a/src/home/bot/wrapper.py b/src/home/bot/wrapper.py
new file mode 100644
index 0000000..8651e90
--- /dev/null
+++ b/src/home/bot/wrapper.py
@@ -0,0 +1,339 @@
+import logging
+import traceback
+
+from html import escape
+from telegram import (
+ Update,
+ ParseMode,
+ ReplyKeyboardMarkup,
+ CallbackQuery,
+ User,
+)
+from telegram.ext import (
+ Updater,
+ Filters,
+ BaseFilter,
+ Handler,
+ CommandHandler,
+ MessageHandler,
+ CallbackQueryHandler,
+ CallbackContext,
+ ConversationHandler
+)
+from telegram.error import TimedOut
+from ..config import config
+from typing import Optional, Union
+from .store import Store
+from .lang import LangPack
+from ..api.types import BotType
+from ..api import WebAPIClient
+from .reporting import ReportingHelper
+
+logger = logging.getLogger(__name__)
+languages = {
+ 'en': 'English',
+ 'ru': 'Русский'
+}
+LANG_STARTED = range(1)
+user_filter: Optional[BaseFilter] = None
+
+
+def default_langpack() -> LangPack:
+ lang = LangPack()
+ lang.en(
+ start_message="Select command on the keyboard.",
+ unknown_message="Unknown message",
+ cancel="Cancel",
+ select_language="Select language on the keyboard.",
+ invalid_language="Invalid language. Please try again.",
+ language_saved='Saved.',
+ )
+ lang.ru(
+ start_message="Выберите команду на клавиатуре.",
+ unknown_message="Неизвестная команда",
+ cancel="Отмена",
+ select_language="Выберите язык на клавиатуре.",
+ invalid_language="Неверный язык. Пожалуйста, попробуйте снова",
+ language_saved="Настройки сохранены."
+ )
+ return lang
+
+
+def init_user_filter():
+ global user_filter
+ if user_filter is None:
+ if 'users' in config['bot']:
+ logger.info('allowed users: ' + str(config['bot']['users']))
+ user_filter = Filters.user(config['bot']['users'])
+ else:
+ user_filter = Filters.all # not sure if this is correct
+
+
+def text_filter(*args):
+ init_user_filter()
+ return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & user_filter
+
+
+def exc2text(e: Exception) -> str:
+ tb = ''.join(traceback.format_tb(e.__traceback__))
+ return f'{e.__class__.__name__}: ' + escape(str(e)) + "\n\n" + escape(tb)
+
+
+class IgnoreMarkup:
+ pass
+
+
+class Context:
+ _update: Optional[Update]
+ _callback_context: Optional[CallbackContext]
+ _markup_getter: callable
+ _lang: LangPack
+ _store: Optional[Store]
+ _user_lang: Optional[str]
+
+ def __init__(self,
+ update: Optional[Update],
+ callback_context: Optional[CallbackContext],
+ markup_getter: callable,
+ lang: LangPack,
+ store: Optional[Store]):
+ self._update = update
+ self._callback_context = callback_context
+ self._markup_getter = markup_getter
+ self._lang = lang
+ self._store = store
+ self._user_lang = None
+
+ def reply(self, text, markup=None):
+ if markup is None:
+ markup = self._markup_getter(self)
+ kwargs = dict(parse_mode=ParseMode.HTML)
+ if not isinstance(markup, IgnoreMarkup):
+ kwargs['reply_markup'] = markup
+ self._update.message.reply_text(text, **kwargs)
+
+ def reply_exc(self, e: Exception) -> None:
+ self.reply(exc2text(e))
+
+ def answer(self, text: str = None):
+ self.callback_query.answer(text)
+
+ def edit(self, text, markup=None):
+ kwargs = dict(parse_mode=ParseMode.HTML)
+ if not isinstance(markup, IgnoreMarkup):
+ kwargs['reply_markup'] = markup
+ self.callback_query.edit_message_text(text, **kwargs)
+
+ @property
+ def text(self) -> str:
+ return self._update.message.text
+
+ @property
+ def callback_query(self) -> CallbackQuery:
+ return self._update.callback_query
+
+ @property
+ def args(self) -> Optional[list[str]]:
+ return self._callback_context.args
+
+ @property
+ def user_id(self) -> int:
+ return self.user.id
+
+ @property
+ def user(self) -> User:
+ return self._update.effective_user
+
+ @property
+ def user_lang(self) -> str:
+ if self._user_lang is None:
+ self._user_lang = self._store.get_user_lang(self.user_id)
+ return self._user_lang
+
+ def lang(self, key: str, *args) -> str:
+ return self._lang.get(key, self.user_lang, *args)
+
+ def is_callback_context(self) -> bool:
+ return self._update.callback_query and self._update.callback_query.data and self._update.callback_query.data != ''
+
+
+class Wrapper:
+ store: Optional[Store]
+ updater: Updater
+ lang: LangPack
+ reporting: Optional[ReportingHelper]
+
+ def __init__(self):
+ self.updater = Updater(config['bot']['token'],
+ request_kwargs={'read_timeout': 6, 'connect_timeout': 7})
+ self.lang = default_langpack()
+ self.store = Store()
+ self.reporting = None
+
+ init_user_filter()
+
+ dispatcher = self.updater.dispatcher
+ dispatcher.add_handler(CommandHandler('start', self.wrap(self.start), user_filter))
+
+ # transparently log all messages
+ self.add_handler(MessageHandler(Filters.all & user_filter, self.logging_message_handler), group=10)
+ self.add_handler(CallbackQueryHandler(self.logging_callback_handler), group=10)
+
+ def run(self):
+ self._lang_setup()
+ self.updater.dispatcher.add_handler(
+ MessageHandler(Filters.all & user_filter, self.wrap(self.any))
+ )
+
+ # start the bot
+ self.updater.start_polling()
+
+ # run the bot until the user presses Ctrl-C or the process receives SIGINT, SIGTERM or SIGABRT
+ self.updater.idle()
+
+ def enable_logging(self, bot_type: BotType):
+ api = WebAPIClient(timeout=3)
+ api.enable_async()
+
+ self.reporting = ReportingHelper(api, bot_type)
+
+ def logging_message_handler(self, update: Update, context: CallbackContext):
+ if self.reporting is None:
+ return
+
+ self.reporting.report(update.message)
+
+ def logging_callback_handler(self, update: Update, context: CallbackContext):
+ if self.reporting is None:
+ return
+
+ self.reporting.report(update.callback_query.message, text=update.callback_query.data)
+
+ def wrap(self, f: callable):
+ def handler(update: Update, context: CallbackContext):
+ ctx = Context(update,
+ callback_context=context,
+ markup_getter=self.markup,
+ lang=self.lang,
+ store=self.store)
+
+ try:
+ return f(ctx)
+ except Exception as e:
+ if not self.exception_handler(e, ctx) and not isinstance(e, TimedOut):
+ logger.exception(e)
+ if not ctx.is_callback_context():
+ ctx.reply_exc(e)
+ else:
+ self.notify_user(ctx.user_id, exc2text(e))
+
+ return handler
+
+ def add_handler(self, handler: Handler, group=0):
+ self.updater.dispatcher.add_handler(handler, group=group)
+
+ def start(self, ctx: Context):
+ if 'start_message' not in self.lang:
+ ctx.reply('Please define start_message or override start()')
+ return
+
+ ctx.reply(ctx.lang('start_message'))
+
+ def any(self, ctx: Context):
+ if 'invalid_command' not in self.lang:
+ ctx.reply('Please define invalid_command or override any()')
+ return
+
+ ctx.reply(ctx.lang('invalid_command'))
+
+ def markup(self, ctx: Optional[Context]) -> Optional[ReplyKeyboardMarkup]:
+ return None
+
+ def exception_handler(self, e: Exception, ctx: Context) -> Optional[bool]:
+ pass
+
+ def notify_all(self, text_getter: callable, exclude: tuple[int] = ()) -> None:
+ if 'notify_users' not in config['bot']:
+ logger.error('notify_all() called but no notify_users directive found in the config')
+ return
+
+ for user_id in config['bot']['notify_users']:
+ if user_id in exclude:
+ continue
+
+ text = text_getter(self.store.get_user_lang(user_id))
+ self.updater.bot.send_message(chat_id=user_id,
+ text=text,
+ parse_mode='HTML')
+
+ def notify_user(self, user_id: int, text: Union[str, Exception]) -> None:
+ if isinstance(text, Exception):
+ text = exc2text(text)
+ self.updater.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML')
+
+ def send_audio(self, user_id, **kwargs):
+ self.updater.bot.send_audio(chat_id=user_id, **kwargs)
+
+ def send_file(self, user_id, **kwargs):
+ self.updater.bot.send_document(chat_id=user_id, **kwargs)
+
+ #
+ # Language Selection
+ #
+
+ def _lang_setup(self):
+ supported = self.lang.languages
+ if len(supported) > 1:
+ cancel_filter = Filters.text(self.lang.all('cancel'))
+
+ self.add_handler(ConversationHandler(
+ entry_points=[CommandHandler('lang', self.wrap(self._lang_command), user_filter)],
+ states={
+ LANG_STARTED: [
+ *list(map(lambda key: MessageHandler(text_filter(languages[key]),
+ self.wrap(self._lang_input)), supported)),
+ MessageHandler(user_filter & ~cancel_filter, self.wrap(self._lang_invalid_input))
+ ]
+ },
+ fallbacks=[MessageHandler(user_filter & cancel_filter, self.wrap(self._lang_cancel_input))]
+ ))
+
+ def _lang_command(self, ctx: Context):
+ logger.debug(f'current language: {ctx.user_lang}')
+
+ buttons = []
+ for name in languages.values():
+ buttons.append(name)
+ markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False)
+
+ ctx.reply(ctx.lang('select_language'), markup=markup)
+ return LANG_STARTED
+
+ def _lang_input(self, ctx: Context):
+ lang = None
+ for key, value in languages.items():
+ if value == ctx.text:
+ lang = key
+ break
+
+ if lang is None:
+ ValueError('could not find the language')
+
+ self.store.set_user_lang(ctx.user_id, lang)
+
+ ctx.reply(ctx.lang('language_saved'), markup=IgnoreMarkup())
+
+ self.start(ctx)
+ return ConversationHandler.END
+
+ def _lang_invalid_input(self, ctx: Context):
+ ctx.reply(self.lang('invalid_language'), markup=IgnoreMarkup())
+ return LANG_STARTED
+
+ def _lang_cancel_input(self, ctx: Context):
+ self.start(ctx)
+ return ConversationHandler.END
+
+ @property
+ def user_filter(self):
+ return user_filter