# SPDX-License-Identifier: GPL-3.0-or-later

# Haketilo proxy data and configuration (instantiatable HaketiloState subtype).
#
# This file is part of Hydrilla&Haketilo.
#
# Copyright (C) 2022 Wojtek Kosior
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
#
# I, Wojtek Kosior, thereby promise not to sue for violation of this
# file's license. Although I request that you do not make use of this
# code in a proprietary program, I am not going to enforce this in
# court.

"""
This module contains logic for keeping track of all settings, rules, mappings
and resources.
"""

import sqlite3
import secrets
import typing as t
import dataclasses as dc

from pathlib import Path

from ...exceptions import HaketiloException
from ...translations import smart_gettext as _
from ... import url_patterns
from ... import item_infos
from .. import state as st
from .. import policies
from .. import simple_dependency_satisfying as sds
from . import base
from . import rules
from . import items
from . import repos
from . import payloads
from . import _operations


here = Path(__file__).resolve().parent


def _add_popup_settings_columns(cursor: sqlite3.Cursor) -> None:
    """
    Extend the `general` table with default popup settings columns.

    For each of the page types 'jsallowed', 'jsblocked' and 'payloadon' two
    columns are added: `default_popup_<type>_onkeyboard` (boolean, defaults to
    TRUE) and `default_popup_<type>_style` (one of 'D'/'T', defaults to 'T').
    """
    for kind in ('jsallowed', 'jsblocked', 'payloadon'):
        # Both statements get formatted upfront and are then executed in the
        # order: keyboard trigger column first, style column second.
        statements = (
            f'''
            ALTER TABLE general ADD COLUMN
            default_popup_{kind}_onkeyboard BOOLEAN NOT NULL DEFAULT TRUE;
            ''',
            f'''
            ALTER TABLE general ADD COLUMN
            default_popup_{kind}_style CHAR(1) NOT NULL DEFAULT 'T'
            CHECK (default_popup_{kind}_style IN ('D', 'T'));
            ''',
        )

        for statement in statements:
            cursor.execute(statement)

def _add_locale_column(cursor: sqlite3.Cursor) -> None:
    """
    Extend the `general` table with a non-null `locale` text column whose
    default value is 'unknown'.
    """
    statement = '''
        ALTER TABLE general ADD COLUMN
        locale VARCHAR NOT NULL DEFAULT 'unknown';
        '''

    cursor.execute(statement)

def _add_update_waiting_column(cursor: sqlite3.Cursor) -> None:
    """
    Extend the `general` table with a non-null boolean `update_waiting` column
    whose default value is TRUE.
    """
    statement = '''
        ALTER TABLE general ADD COLUMN
        update_waiting BOOLEAN NOT NULL DEFAULT TRUE;
        '''

    cursor.execute(statement)

def _prepare_database(connection: sqlite3.Connection) -> None:
    """
    Make sure the database behind `connection` has the expected schema.

    If there is no `general` table yet, the schema from 'tables.sql' (located
    next to this file) is loaded. Otherwise the stored `haketilo_version` is
    verified to be '3.0b1'. In both cases the columns of `general` that might
    be missing (popup settings, `locale`, `update_waiting`) get added inside
    a single transaction. Finally, foreign keys support is verified and
    switched on.

    Raises HaketiloException when the schema version is not recognized or when
    `PRAGMA FOREIGN_KEYS` yields no rows.
    """
    cursor = connection.cursor()

    try:
        cursor.execute(
            '''
            SELECT
                    COUNT(name)
            FROM
                    sqlite_master
            WHERE
                    name = 'general' AND type = 'table';
            '''
        )

        # Exactly one row with exactly one value (the count) is expected.
        (general_table_count,), = cursor.fetchall()

        if not general_table_count:
            cursor.executescript((here / 'tables.sql').read_text())

        cursor.execute('BEGIN TRANSACTION;')

        try:
            if general_table_count:
                # The db had been initialized before we connected to it, so
                # its schema version needs to be verified.
                cursor.execute(
                    '''
                    SELECT
                            haketilo_version
                    FROM
                            general;
                    '''
                )

                (stored_version,) = cursor.fetchone()
                if stored_version != '3.0b1':
                    raise HaketiloException(_('err.proxy.unknown_db_schema'))

            # The second field of every `TABLE_INFO` row is the column name.
            cursor.execute("PRAGMA TABLE_INFO('general')")
            existing_columns = {row[1] for row in cursor.fetchall()}

            # Pairs of: a column whose absence means the migration was not
            # performed yet, and the function that performs that migration.
            migrations = (
                ('default_popup_jsallowed_onkeyboard',
                 _add_popup_settings_columns),
                ('locale',         _add_locale_column),
                ('update_waiting', _add_update_waiting_column),
            )

            for marker_column, migrate in migrations:
                if marker_column not in existing_columns:
                    migrate(cursor)

            cursor.execute('COMMIT TRANSACTION;')
        except BaseException:
            # Whatever went wrong, leave no half-finished migration behind and
            # let the caller see the original error.
            cursor.execute('ROLLBACK TRANSACTION;')
            raise

        # An empty result means the pragma is not recognized by this SQLite.
        cursor.execute('PRAGMA FOREIGN_KEYS;')
        if not cursor.fetchall():
            raise HaketiloException(_('err.proxy.no_sqlite_foreign_keys'))

        cursor.execute('PRAGMA FOREIGN_KEYS=ON;')
    finally:
        cursor.close()


def load_settings(cursor: sqlite3.Cursor) -> st.HaketiloGlobalSettings:
    """
    Read global settings from the `general` table using `cursor`.

    The table is expected to hold exactly one row; a ValueError from tuple
    unpacking results otherwise.

    Default popup settings are queried separately for each of the page types
    'jsallowed', 'jsblocked' and 'payloadon'. If such a query fails at the
    database level (e.g. because the columns are missing) or does not yield
    exactly one row, the fallback of keyboard trigger enabled and style 'T' is
    used for that page type.

    Returns a `HaketiloGlobalSettings` object built from the values read.
    """
    cursor.execute(
        '''
        SELECT
                default_allow_scripts,
                advanced_user,
                repo_refresh_seconds,
                mapping_use_mode,
                locale,
                update_waiting
        FROM
                general;
        '''
    )

    (default_allow_scripts, advanced_user, repo_refresh_seconds,
     mapping_use_mode, locale, update_waiting), = cursor.fetchall()

    popup_settings_dict = {}

    for page_type in ('jsallowed', 'jsblocked', 'payloadon'):
        try:
            cursor.execute(
                f'''
                SELECT
                    default_popup_{page_type}_onkeyboard,
                    default_popup_{page_type}_style
                FROM
                    general;
                '''
            )

            (onkeyboard, style), = cursor.fetchall()
        except (sqlite3.Error, ValueError):
            # Best-effort fallback. Only database errors and an unexpected
            # number of rows are handled here - a bare `except` would also
            # swallow KeyboardInterrupt, SystemExit and genuine programming
            # errors.
            onkeyboard, style = True, 'T'

        popup_settings_dict[f'default_popup_{page_type}'] = st.PopupSettings(
            keyboard_trigger = onkeyboard,
            style            = st.PopupStyle(style)
        )

    return st.HaketiloGlobalSettings(
        default_allow_scripts = default_allow_scripts,
        advanced_user         = advanced_user,
        repo_refresh_seconds  = repo_refresh_seconds,
        mapping_use_mode      = st.MappingUseMode(mapping_use_mode),
        locale                = locale,
        update_waiting        = update_waiting,

        **popup_settings_dict
    )

@dc.dataclass
class ConcreteHaketiloState(base.HaketiloStateWithFields):
    """
    Instantiatable Haketilo state that keeps its data in an SQLite database.

    Besides talking to the database, the object maintains in-memory
    structures (`policy_tree`, `payloads_data`) and a cached copy of global
    settings (`settings`), all of which get (re)built from database contents.
    """

    def __post_init__(self) -> None:
        """Build the in-memory structures as soon as the fields are set."""
        self.rebuild_structures()

    def import_items(self, malcontent_path: Path, repo_id: int = 1) -> None:
        """
        Load packages from `malcontent_path` into the database as items of
        repository `repo_id`, mark that an update is waiting and rebuild the
        in-memory structures.

        The variant without `repo_id` is part of the state API. Hence, a call
        with `repo_id` equal to 1 (import of local semirepo packages) creates
        a new transaction. Calls with other values of `repo_id` are assumed
        to come from within the state implementation code and to have
        a transaction open already.
        """
        opens_own_transaction = repo_id == 1

        with self.cursor(transaction=opens_own_transaction) as cursor:
            # Either way, a transaction has to be in progress at this point.
            assert self.connection.in_transaction

            _operations._load_packages_no_state_update(
                cursor          = cursor,
                malcontent_path = malcontent_path,
                repo_id         = repo_id
            )

            # Keep the database flag and the cached settings in sync.
            cursor.execute('UPDATE general SET update_waiting = TRUE;')
            self.settings = dc.replace(self.settings, update_waiting=True)

            self.rebuild_structures(rules=False)

    def count_orphan_items(self) -> st.OrphanItemsStats:
        """
        Count mapping and resource versions that belong to orphan repo
        iterations and whose `active` status is different from 'R'.
        """
        with self.cursor() as cursor:
            cursor.execute(
                '''
                SELECT
                        COALESCE(SUM(i.type = 'M'), 0),
                        COALESCE(SUM(i.type = 'R'), 0)
                FROM
                             item_versions     AS iv
                        JOIN items             AS i  USING (item_id)
                        JOIN orphan_iterations AS oi USING (repo_iteration_id)
                WHERE
                        iv.active != 'R';
                '''
            )

            # The aggregate query yields exactly one row.
            (counts_row,) = cursor.fetchall()
            mappings_count, resources_count = counts_row

        return st.OrphanItemsStats(mappings_count, resources_count)

    def prune_orphan_items(self) -> None:
        """
        Aggressively prune orphans in a new transaction and recompute
        dependencies afterwards.
        """
        with self.cursor(transaction=True) as cursor:
            _operations.prune_orphans(cursor, aggressive=True)

            self.recompute_dependencies()

    def soft_prune_orphan_items(self) -> None:
        """
        Prune orphans non-aggressively. A transaction is required to be open
        already when this method gets called.
        """
        with self.cursor() as cursor:
            assert self.connection.in_transaction

            _operations.prune_orphans(cursor)

    def recompute_dependencies(
            self,
            # NOTE(review): mutable default; it is only passed on and compared
            # here, never modified - confirm the callee does not modify it.
            unlocked_required_mappings: base.NoLockArg = []
    ) -> None:
        """
        Recompute dependencies in the database and rebuild the in-memory
        structures. A transaction is required to be open already.

        When `unlocked_required_mappings` equals 'all_mappings_unlocked', the
        `update_waiting` flag gets cleared, both in the database and in the
        cached settings.
        """
        with self.cursor() as cursor:
            assert self.connection.in_transaction

            _operations._recompute_dependencies_no_state_update(
                cursor                     = cursor,
                unlocked_required_mappings = unlocked_required_mappings
            )

            everything_unlocked = \
                unlocked_required_mappings == 'all_mappings_unlocked'

            if everything_unlocked:
                cursor.execute('UPDATE general SET update_waiting = FALSE;')
                self.settings = dc.replace(self.settings, update_waiting=False)

            self.rebuild_structures(rules=False)

    def upate_all_items(self) -> None:
        """
        Recompute dependencies with all mappings unlocked, in a new
        transaction.
        """
        with self.cursor(transaction=True):
            self.recompute_dependencies('all_mappings_unlocked')

    def pull_missing_files(self) -> None:
        """
        Delegate to `_operations.pull_missing_files()`. A transaction is
        required to be open already.
        """
        with self.cursor() as cursor:
            assert self.connection.in_transaction

            _operations.pull_missing_files(cursor)

    def _rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
        """
        Construct a new policy tree and new payloads data from what is in the
        database and store them in `self.policy_tree` and
        `self.payloads_data`.

        The tree receives, in this order: builtin policies (main web UI,
        landing page, mitm.it page), script allowing/blocking rules and
        payloads. Payloads already present in `self.payloads_data` keep their
        unique tokens; other payloads get a newly generated one.
        """
        tree = base.PolicyTree()

        # Builtin: main web UI page.
        main_ui_factory = policies.WebUIMainPolicyFactory(builtin=True)

        for parsed in url_patterns.parse_pattern('http*://hkt.mitm.it/***'):
            tree = tree.register(
                parsed_pattern = parsed,
                item           = main_ui_factory
            )

        # Builtin: landing page under the address the proxy listens at. That
        # address might not form a valid pattern, which only causes a warning.
        landing_pattern = f'{self.efective_listen_addr}/***'
        landing_factory = policies.WebUILandingPolicyFactory(builtin=True)

        try:
            (parsed,) = url_patterns.parse_pattern(landing_pattern)
        except url_patterns.HaketiloURLException:
            fmt = _('warn.proxy.failed_to_register_landing_page_at_{}')
            self.logger.warn(fmt.format(landing_pattern))
        else:
            tree = tree.register(
                parsed_pattern = parsed,
                item           = landing_factory
            )

        # Builtin: mitm.it page.
        mitm_it_factory = policies.MitmItPagePolicyFactory()

        (parsed,) = url_patterns.parse_pattern('http://mitm.it/***')
        tree = tree.register(
            parsed_pattern = parsed,
            item           = mitm_it_factory
        )

        # Script blocking/allowing rules.
        cursor.execute('SELECT pattern, allow_scripts FROM rules;')

        for pattern, allow_scripts in cursor.fetchall():
            factory_class = (
                policies.RuleAllowPolicyFactory if allow_scripts
                else policies.RuleBlockPolicyFactory
            )

            for parsed in url_patterns.parse_pattern(pattern):
                rule_factory: policies.PolicyFactory = factory_class(
                    builtin = False,
                    pattern = parsed
                )

                tree = tree.register(
                    parsed_pattern = parsed,
                    item           = rule_factory
                )

        # Script payloads.
        cursor.execute(
            '''
            SELECT
                    p.payload_id,
                    p.pattern,
                    p.eval_allowed,
                    p.cors_bypass_allowed,
                    ms.enabled,
                    i.identifier
            FROM
                         payloads         AS p
                    JOIN item_versions    AS iv
                            ON p.mapping_item_id = iv.item_version_id
                    JOIN items            AS i
                            USING (item_id)
                    JOIN mapping_statuses AS ms
                            USING (item_id);
            '''
        )

        fresh_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}

        for (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
             enabled_status, identifier) in cursor.fetchall():
            payload_ref = payloads.ConcretePayloadRef(str(payload_id_int), self)

            # Reuse the token of a payload we already know.
            known_data = self.payloads_data.get(payload_ref)
            token = (
                secrets.token_urlsafe(8) if known_data is None
                else known_data.unique_token
            )

            payload_key = st.PayloadKey(payload_ref, identifier)

            for parsed in url_patterns.parse_pattern(pattern):
                tree = tree.register_payload(parsed, payload_key, token)

                # NOTE(review): only assigned inside this loop; presumably
                # parse_pattern() always yields at least one item - otherwise
                # a value from a previous payload would be used. Confirm.
                pattern_path_segments = parsed.path_segments

            fresh_payloads_data[payload_ref] = st.PayloadData(
                ref                   = payload_ref,
                explicitly_enabled    = enabled_status == 'E',
                unique_token          = token,
                mapping_identifier    = identifier,
                pattern               = pattern,
                pattern_path_segments = pattern_path_segments,
                eval_allowed          = eval_allowed,
                cors_bypass_allowed   = cors_bypass_allowed,
                global_secret         = self.secret
            )

        # Swap in the new structures only once both are complete.
        self.policy_tree   = tree
        self.payloads_data = fresh_payloads_data

    def rebuild_structures(
            self,
            *,
            payloads: bool = True,
            rules:    bool = True
    ) -> None:
        """
        Rebuild the in-memory structures from database contents.

        The `payloads` and `rules` arguments are meant for a future
        optimization; at present they are ignored and everything gets rebuilt.
        """
        with self.cursor() as cursor:
            self._rebuild_structures(cursor)

    def rule_store(self) -> st.RuleStore:
        """Make a rule store operating on this state."""
        return rules.ConcreteRuleStore(self)

    def repo_store(self) -> st.RepoStore:
        """Make a repo store operating on this state."""
        return repos.ConcreteRepoStore(self)

    def mapping_store(self) -> st.MappingStore:
        """Make a mapping store operating on this state."""
        return items.ConcreteMappingStore(self)

    def mapping_version_store(self) -> st.MappingVersionStore:
        """Make a mapping version store operating on this state."""
        return items.ConcreteMappingVersionStore(self)

    def resource_store(self) -> st.ResourceStore:
        """Make a resource store operating on this state."""
        return items.ConcreteResourceStore(self)

    def resource_version_store(self) -> st.ResourceVersionStore:
        """Make a resource version store operating on this state."""
        return items.ConcreteResourceVersionStore(self)

    def payload_store(self) -> st.PayloadStore:
        """Make a payload store operating on this state."""
        return payloads.ConcretePayloadStore(self)

    def get_secret(self) -> bytes:
        """Give the value of the `secret` attribute."""
        return self.secret

    def get_settings(self) -> st.HaketiloGlobalSettings:
        """Give the cached global settings, read while holding the lock."""
        with self.lock:
            return self.settings

    def update_settings(
            self,
            *,
            mapping_use_mode:        t.Optional[st.MappingUseMode]    = None,
            default_allow_scripts:   t.Optional[bool]                 = None,
            advanced_user:           t.Optional[bool]                 = None,
            repo_refresh_seconds:    t.Optional[int]                  = None,
            locale:                  t.Optional[str]                  = None,
            default_popup_settings:  t.Mapping[str, st.PopupSettings] = {}
    ) -> None:
        """
        Write the given settings to the `general` table in a new transaction
        and reload the cached settings from the database.

        Settings passed as None stay untouched. `default_popup_settings` is
        consulted for the keys 'jsallowed', 'jsblocked' and 'payloadon'; for
        each key present, both the keyboard trigger and the style column of
        that page type get written.
        """
        with self.cursor(transaction=True) as cursor:
            def store_column(
                    column: str,
                    value:  t.Union[bool, int, str]
            ) -> None:
                # Column names only ever come from the literals below; values
                # are passed as SQL parameters.
                cursor.execute(f'UPDATE general SET {column} = ?;', (value,))

            if mapping_use_mode is not None:
                store_column('mapping_use_mode', mapping_use_mode.value)

            plain_settings = (
                ('default_allow_scripts', default_allow_scripts),
                ('advanced_user',         advanced_user),
                ('repo_refresh_seconds',  repo_refresh_seconds),
                ('locale',                locale),
            )

            for column, new_value in plain_settings:
                if new_value is not None:
                    store_column(column, new_value)

            for page_type in ('jsallowed', 'jsblocked', 'payloadon'):
                popup_settings = default_popup_settings.get(page_type)
                if popup_settings is None:
                    continue

                column_prefix = f'default_popup_{page_type}'

                store_column(
                    f'{column_prefix}_onkeyboard',
                    popup_settings.keyboard_trigger
                )
                store_column(
                    f'{column_prefix}_style',
                    popup_settings.style.value
                )

            self.settings = load_settings(cursor)

    @staticmethod
    def make(
            store_dir:   Path,
            listen_host: str,
            listen_port: int,
            logger:      st.Logger
    ) -> 'ConcreteHaketiloState':
        """
        Create the state object that uses the 'sqlite3.db' database file
        inside `store_dir`.

        The directory gets created if missing. The database gets prepared
        (schema loaded or verified, missing columns added) and global settings
        get loaded from it before the object is instantiated.
        """
        store_dir.mkdir(parents=True, exist_ok=True)

        db_path = store_dir / 'sqlite3.db'

        # `isolation_level = None` makes the sqlite3 module refrain from
        # opening transactions implicitly; this code issues BEGIN/COMMIT
        # explicitly instead.
        connection = sqlite3.connect(
            str(db_path),
            isolation_level   = None,
            check_same_thread = False
        )

        _prepare_database(connection)

        initial_settings = load_settings(connection.cursor())

        return ConcreteHaketiloState(
            store_dir    = store_dir,
            _logger      = logger,
            _listen_host = listen_host,
            _listen_port = listen_port,
            connection   = connection,
            settings     = initial_settings
        )