aboutsummaryrefslogtreecommitdiff
path: root/include/py/homekit/database/sqlite.py
blob: 8b0c44cebf2be80b241cc5de7e672d8a5c50d0d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import sqlite3
import os.path
import logging

from ._base import get_data_root_directory
from ..config import config, is_development_mode


def _get_database_path(name: str) -> str:
    """Build the filesystem path of the SQLite file for database *name*.

    The file is placed directly inside the project's data root directory and
    is named after *name* with a ``.db`` suffix appended.
    """
    root_dir = get_data_root_directory()
    file_name = f'{name}.db'
    return os.path.join(root_dir, file_name)


class SQLiteBase:
    """Base class for a SQLite-backed store with a versioned schema.

    Construction opens (creating if necessary) the database file, reads the
    stored schema version from ``PRAGMA user_version``, passes it to
    :meth:`schema_init` (which subclasses must override), and finally records
    :attr:`SCHEMA` as the stored version.
    """

    # Version written to ``PRAGMA user_version`` once schema_init() returns;
    # presumably bumped by subclasses as their schema evolves.
    SCHEMA = 1

    def __init__(self, name=None, path=None, check_same_thread=False):
        """Open the database and run the subclass schema initialisation.

        :param name: database name; the file is ``<data root>/<name>.db``.
            Falls back to ``config.app_config['database_name']``. Ignored
            when *path* is given.
        :param path: explicit database file path; takes precedence over *name*.
        :param check_same_thread: forwarded to :func:`sqlite3.connect`.
        """
        if not path:
            if not name:
                name = config.app_config['database_name']
            database_path = _get_database_path(name)
        else:
            database_path = path
        self._ensure_parent_directory(database_path)

        self.logger = logging.getLogger(self.__class__.__name__)
        self.sqlite = sqlite3.connect(database_path, check_same_thread=check_same_thread)

        if is_development_mode():
            # NOTE(review): this is the same logger object as self.logger (same
            # name), so setting its level to 'TRACE' affects both. The 'TRACE'
            # level and Logger.trace are not stdlib logging features; they are
            # presumably registered elsewhere in the project — confirm.
            self.sql_logger = logging.getLogger(self.__class__.__name__)
            self.sql_logger.setLevel('TRACE')
            # Each SQL statement run on this connection is passed to the callback.
            self.sqlite.set_trace_callback(self.sql_logger.trace)

        sqlite_version = self._get_sqlite_version()
        self.logger.debug(f'SQLite version: {sqlite_version}')

        schema_version = self.schema_get_version()
        self.logger.debug(f'Schema version: {schema_version}')

        self.schema_init(schema_version)
        # Written unconditionally, even if the stored version was already higher.
        self.schema_set_version(self.SCHEMA)

    @staticmethod
    def _ensure_parent_directory(database_path: str) -> None:
        """Create the directory that will contain *database_path*, if missing.

        A bare filename (or ``':memory:'``) has no directory component, and
        ``os.makedirs('')`` would raise ``FileNotFoundError``, so nothing is
        created in that case. ``exist_ok=True`` avoids the race between an
        existence check and the creation.
        """
        directory = os.path.dirname(database_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def __del__(self):
        """Commit pending changes and close the connection on finalisation.

        ``__init__`` may have raised before ``self.sqlite`` was assigned, so the
        attribute is looked up defensively. The attribute is cleared after
        closing so a repeated call does not touch a closed connection.
        """
        connection = getattr(self, 'sqlite', None)
        if connection is not None:
            connection.commit()
            connection.close()
            self.sqlite = None

    def _get_sqlite_version(self) -> str:
        """Return the version string reported by the SQLite library."""
        cursor = self.sqlite.cursor()
        cursor.execute("SELECT sqlite_version()")
        return cursor.fetchone()[0]

    def schema_get_version(self) -> int:
        """Return the schema version stored in ``PRAGMA user_version``."""
        cursor = self.sqlite.execute('PRAGMA user_version')
        return int(cursor.fetchone()[0])

    def schema_set_version(self, v) -> None:
        """Store *v* as the schema version in ``PRAGMA user_version``.

        PRAGMA values cannot be bound as parameters, so the value is formatted
        into the statement; ``{:d}`` rejects non-integer values.
        """
        self.sqlite.execute('PRAGMA user_version={:d}'.format(v))
        self.logger.info(f'Schema set to {v}')

    def cursor(self) -> sqlite3.Cursor:
        """Return a new cursor on the underlying connection."""
        return self.sqlite.cursor()

    def commit(self) -> None:
        """Commit the current transaction on the underlying connection."""
        return self.sqlite.commit()

    def schema_init(self, version: int) -> None:
        """Create or migrate the schema; subclasses must override this.

        :param version: the schema version currently stored in the database.
        :raises ValueError: always, in this base implementation.
        """
        raise ValueError(f'{self.__class__.__name__}: must override schema_init')