diff options
Diffstat (limited to 'src/hydrilla/proxy/state_impl/concrete_state.py')
-rw-r--r-- | src/hydrilla/proxy/state_impl/concrete_state.py | 523 |
1 files changed, 523 insertions, 0 deletions
# SPDX-License-Identifier: GPL-3.0-or-later

# Haketilo proxy data and configuration (instantiatable HaketiloState subtype).
#
# This file is part of Hydrilla&Haketilo.
#
# Copyright (C) 2022 Wojtek Kosior
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
#
# I, Wojtek Kosior, thereby promise not to sue for violation of this
# file's license. Although I request that you do not make use of this
# code in a proprietary program, I am not going to enforce this in
# court.

"""
This module contains logic for keeping track of all settings, rules, mappings
and resources.
"""

import sqlite3
import secrets
import typing as t
import dataclasses as dc

from pathlib import Path

from ...exceptions import HaketiloException
from ...translations import smart_gettext as _
from ... import url_patterns
from ... import item_infos
from .. import state as st
from .. import policies
from .. import simple_dependency_satisfying as sds
from . import base
from . import rules
from . import items
from . import repos
from . import payloads
from . import _operations


here = Path(__file__).resolve().parent

# Page types for which per-type popup settings columns exist in the `general`
# table (as `default_popup_<page_type>_onkeyboard` and
# `default_popup_<page_type>_style`).
_POPUP_PAGE_TYPES = ('jsallowed', 'jsblocked', 'payloadon')


def _add_popup_settings_columns(cursor: sqlite3.Cursor) -> None:
    """
    Add the `default_popup_<page_type>_onkeyboard` and
    `default_popup_<page_type>_style` columns to the `general` table, one pair
    for every page type.
    """
    for page_type in _POPUP_PAGE_TYPES:
        cursor.execute(
            f'''
            ALTER TABLE general ADD COLUMN
                default_popup_{page_type}_onkeyboard BOOLEAN NOT NULL DEFAULT TRUE;
            '''
        )
        cursor.execute(
            f'''
            ALTER TABLE general ADD COLUMN
                default_popup_{page_type}_style CHAR(1) NOT NULL DEFAULT 'T'
                CHECK (default_popup_{page_type}_style IN ('D', 'T'));
            '''
        )

def _add_locale_column(cursor: sqlite3.Cursor) -> None:
    """Add the `locale` column (default 'unknown') to the `general` table."""
    cursor.execute(
        '''
        ALTER TABLE general ADD COLUMN
            locale VARCHAR NOT NULL DEFAULT 'unknown';
        '''
    )

def _add_update_waiting_column(cursor: sqlite3.Cursor) -> None:
    """Add the `update_waiting` column (default TRUE) to the `general` table."""
    cursor.execute(
        '''
        ALTER TABLE general ADD COLUMN
            update_waiting BOOLEAN NOT NULL DEFAULT TRUE;
        '''
    )

def _prepare_database(connection: sqlite3.Connection) -> None:
    """
    Make sure the database behind `connection` has a usable schema.

    If there is no `general` table yet, the schema is created from the
    `tables.sql` file located next to this module. Otherwise the stored
    `haketilo_version` must equal '3.0b1'. In both cases columns that are
    missing from the `general` table get added. Finally, foreign key
    enforcement is switched on.

    Raises HaketiloException when the schema version is not recognized or when
    this SQLite build gives no answer to `PRAGMA FOREIGN_KEYS`.
    """
    cursor = connection.cursor()

    try:
        cursor.execute(
            '''
            SELECT
                    COUNT(name)
            FROM
                    sqlite_master
            WHERE
                    name = 'general' AND type = 'table';
            '''
        )

        (db_initialized,), = cursor.fetchall()

        if not db_initialized:
            cursor.executescript((here / 'tables.sql').read_text())

        cursor.execute('BEGIN TRANSACTION;')

        try:
            if db_initialized:
                # If db was initialized before we connected to it, we must check
                # what its schema version is.
                cursor.execute(
                    '''
                    SELECT
                            haketilo_version
                    FROM
                            general;
                    '''
                )

                version_row = cursor.fetchone()
                # A `general` table without any row cannot be recognized
                # either; report it the same way as an unknown version instead
                # of failing on tuple-unpacking of None.
                if version_row is None or version_row[0] != '3.0b1':
                    raise HaketiloException(_('err.proxy.unknown_db_schema'))

            # Each row of TABLE_INFO is (cid, name, type, notnull, dflt_value,
            # pk); we only need the column names.
            cursor.execute("PRAGMA TABLE_INFO('general')")
            existing_columns = {row[1] for row in cursor.fetchall()}

            # Pairs of (column whose absence triggers the migration, function
            # performing the migration), applied in this order.
            migrations = (
                ('default_popup_jsallowed_onkeyboard',
                 _add_popup_settings_columns),
                ('locale', _add_locale_column),
                ('update_waiting', _add_update_waiting_column)
            )

            for marker_column, migrate in migrations:
                if marker_column not in existing_columns:
                    migrate(cursor)

            cursor.execute('COMMIT TRANSACTION;')
        except BaseException:
            # Deliberately broad: whatever interrupted us, the transaction must
            # not be left open. The exception is always re-raised.
            cursor.execute('ROLLBACK TRANSACTION;')
            raise

        # The pragma yields no rows at all when foreign keys support is absent.
        cursor.execute('PRAGMA FOREIGN_KEYS;')
        if not cursor.fetchall():
            raise HaketiloException(_('err.proxy.no_sqlite_foreign_keys'))

        cursor.execute('PRAGMA FOREIGN_KEYS=ON;')
    finally:
        cursor.close()


def load_settings(cursor: sqlite3.Cursor) -> st.HaketiloGlobalSettings:
    """
    Read global settings from the `general` table and return them as a
    HaketiloGlobalSettings object.

    Popup settings of each page type are read with a separate query. If such
    query fails, the defaults (keyboard trigger enabled, style 'T') are used
    for that page type.
    """
    cursor.execute(
        '''
        SELECT
                default_allow_scripts,
                advanced_user,
                repo_refresh_seconds,
                mapping_use_mode,
                locale,
                update_waiting
        FROM
                general;
        '''
    )

    (default_allow_scripts, advanced_user, repo_refresh_seconds,
     mapping_use_mode, locale, update_waiting), = cursor.fetchall()

    popup_settings_dict = {}

    for page_type in _POPUP_PAGE_TYPES:
        try:
            cursor.execute(
                f'''
                SELECT
                        default_popup_{page_type}_onkeyboard,
                        default_popup_{page_type}_style
                FROM
                        general;
                '''
            )

            (onkeyboard, style), = cursor.fetchall()
        except (sqlite3.Error, ValueError):
            # Best-effort fallback for database errors (e.g. missing columns)
            # and for an unexpected number of rows. Unrelated exceptions like
            # KeyboardInterrupt are no longer swallowed here.
            onkeyboard, style = True, 'T'

        popup_settings_dict[f'default_popup_{page_type}'] = st.PopupSettings(
            keyboard_trigger = onkeyboard,
            style = st.PopupStyle(style)
        )

    return st.HaketiloGlobalSettings(
        default_allow_scripts = default_allow_scripts,
        advanced_user = advanced_user,
        repo_refresh_seconds = repo_refresh_seconds,
        mapping_use_mode = st.MappingUseMode(mapping_use_mode),
        locale = locale,
        update_waiting = update_waiting,

        **popup_settings_dict
    )

@dc.dataclass
class ConcreteHaketiloState(base.HaketiloStateWithFields):
    """
    Instantiatable Haketilo state that keeps its data in an SQLite database.

    Use the `make()` static method to obtain an instance.
    """
    def __post_init__(self) -> None:
        """Build the policy tree and payloads data right after construction."""
        self.rebuild_structures()

    def import_items(self, malcontent_path: Path, repo_id: int = 1) -> None:
        """
        Load packages from `malcontent_path` into the repository with id
        `repo_id`, mark an update as waiting and rebuild in-memory structures.
        """
        with self.cursor(transaction=(repo_id == 1)) as cursor:
            # This method without the repo_id argument exposed is part of the
            # state API. As such, calls with repo_id = 1 (imports of local
            # semirepo packages) create a new transaction. Calls with different
            # values of repo_id are assumed to originate from within the state
            # implementation code and expect an existing transaction. Here, we
            # verify the transaction is indeed present.
            assert self.connection.in_transaction

            _operations._load_packages_no_state_update(
                cursor = cursor,
                malcontent_path = malcontent_path,
                repo_id = repo_id
            )

            cursor.execute('UPDATE general SET update_waiting = TRUE;')
            self.settings = dc.replace(self.settings, update_waiting=True)

            self.rebuild_structures(rules=False)

    def count_orphan_items(self) -> st.OrphanItemsStats:
        """
        Count item versions that belong to orphan iterations and whose
        `active` value is not 'R'. Mappings (item type 'M') and resources
        (item type 'R') are counted separately.
        """
        with self.cursor() as cursor:
            cursor.execute(
                '''
                SELECT
                        COALESCE(SUM(i.type = 'M'), 0),
                        COALESCE(SUM(i.type = 'R'), 0)
                FROM
                        item_versions AS iv
                        JOIN items AS i USING (item_id)
                        JOIN orphan_iterations AS oi USING (repo_iteration_id)
                WHERE
                        iv.active != 'R';
                '''
            )

            (orphan_mappings, orphan_resources), = cursor.fetchall()

            return st.OrphanItemsStats(orphan_mappings, orphan_resources)

    def prune_orphan_items(self) -> None:
        """
        Prune orphans aggressively in a new transaction and recompute
        dependencies afterwards.
        """
        with self.cursor(transaction=True) as cursor:
            _operations.prune_orphans(cursor, aggressive=True)

            self.recompute_dependencies()

    def soft_prune_orphan_items(self) -> None:
        """
        Prune orphans non-aggressively. Must be called with a transaction
        already open.
        """
        with self.cursor() as cursor:
            assert self.connection.in_transaction

            _operations.prune_orphans(cursor)

    # NOTE(review): the mutable `[]` default is kept for signature
    # compatibility. This method does not mutate it but does pass it on to
    # `_operations`; confirm it is not mutated there.
    def recompute_dependencies(
            self,
            unlocked_required_mappings: base.NoLockArg = []
    ) -> None:
        """
        Recompute dependencies and rebuild in-memory structures. Must be called
        with a transaction already open.

        When `unlocked_required_mappings` equals 'all_mappings_unlocked', the
        `update_waiting` setting additionally gets cleared.
        """
        with self.cursor() as cursor:
            assert self.connection.in_transaction

            _operations._recompute_dependencies_no_state_update(
                cursor = cursor,
                unlocked_required_mappings = unlocked_required_mappings
            )

            if unlocked_required_mappings == 'all_mappings_unlocked':
                cursor.execute('UPDATE general SET update_waiting = FALSE;')
                self.settings = dc.replace(self.settings, update_waiting=False)

            self.rebuild_structures(rules=False)

    def upate_all_items(self) -> None:
        """
        Recompute dependencies with all mappings unlocked, in a new
        transaction. (The misspelled name is part of the existing interface.)
        """
        with self.cursor(transaction=True):
            self.recompute_dependencies('all_mappings_unlocked')

    def pull_missing_files(self) -> None:
        """
        Delegate to `_operations.pull_missing_files()`. Must be called with a
        transaction already open.
        """
        with self.cursor() as cursor:
            assert self.connection.in_transaction

            _operations.pull_missing_files(cursor)

    def _rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
        """
        Construct a fresh policy tree and payloads data dictionary from the
        database contents and install them as `self.policy_tree` and
        `self.payloads_data`.

        Unique tokens of payloads that were already known are preserved.
        """
        new_policy_tree = base.PolicyTree()

        # Built-in pages first: the web UI, the landing page and mitm.it page.
        web_ui_main_pattern = 'http*://hkt.mitm.it/***'
        web_ui_main_factory = policies.WebUIMainPolicyFactory(builtin=True)

        for parsed_pattern in url_patterns.parse_pattern(web_ui_main_pattern):
            new_policy_tree = new_policy_tree.register(
                parsed_pattern = parsed_pattern,
                item = web_ui_main_factory
            )

        web_ui_landing_pattern = f'{self.efective_listen_addr}/***'
        web_ui_landing_factory = policies.WebUILandingPolicyFactory(
            builtin = True
        )

        try:
            parsed_pattern, = url_patterns.parse_pattern(web_ui_landing_pattern)
        except url_patterns.HaketiloURLException:
            # The listen address might not form a valid pattern; this is not
            # fatal, the landing page just does not get registered.
            fmt = _('warn.proxy.failed_to_register_landing_page_at_{}')
            self.logger.warn(fmt.format(web_ui_landing_pattern))
        else:
            new_policy_tree = new_policy_tree.register(
                parsed_pattern = parsed_pattern,
                item = web_ui_landing_factory
            )

        mitm_it_page_pattern = 'http://mitm.it/***'
        mitm_it_page_factory = policies.MitmItPagePolicyFactory()

        parsed_pattern, = url_patterns.parse_pattern(mitm_it_page_pattern)
        new_policy_tree = new_policy_tree.register(
            parsed_pattern = parsed_pattern,
            item = mitm_it_page_factory
        )

        # Put script blocking/allowing rules in policy tree.
        cursor.execute('SELECT pattern, allow_scripts FROM rules;')

        for pattern, allow_scripts in cursor.fetchall():
            for parsed_pattern in url_patterns.parse_pattern(pattern):
                factory: policies.PolicyFactory
                if allow_scripts:
                    factory = policies.RuleAllowPolicyFactory(
                        builtin = False,
                        pattern = parsed_pattern
                    )
                else:
                    factory = policies.RuleBlockPolicyFactory(
                        builtin = False,
                        pattern = parsed_pattern
                    )

                new_policy_tree = new_policy_tree.register(
                    parsed_pattern = parsed_pattern,
                    item = factory
                )

        # Put script payload rules in policy tree.
        cursor.execute(
            '''
            SELECT
                    p.payload_id,
                    p.pattern,
                    p.eval_allowed,
                    p.cors_bypass_allowed,
                    ms.enabled,
                    i.identifier
            FROM
                    payloads AS p
                    JOIN item_versions AS iv
                            ON p.mapping_item_id = iv.item_version_id
                    JOIN items AS i
                            USING (item_id)
                    JOIN mapping_statuses AS ms
                            USING (item_id);
            '''
        )

        new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}

        for (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
             enabled_status, identifier) in cursor.fetchall():
            payload_ref = payloads.ConcretePayloadRef(str(payload_id_int), self)

            # Reuse the token of an already-known payload; only payloads seen
            # for the first time get a newly generated one.
            previous_data = self.payloads_data.get(payload_ref)
            if previous_data is not None:
                token = previous_data.unique_token
            else:
                token = secrets.token_urlsafe(8)

            payload_key = st.PayloadKey(payload_ref, identifier)

            # NOTE(review): path segments of the last parsed variant are used
            # below. If parse_pattern() ever yielded nothing, the variable
            # would be unbound (or stale from a previous payload) - confirm
            # that parse_pattern() always yields at least one pattern.
            for parsed_pattern in url_patterns.parse_pattern(pattern):
                new_policy_tree = new_policy_tree.register_payload(
                    parsed_pattern,
                    payload_key,
                    token
                )

                pattern_path_segments = parsed_pattern.path_segments

            payload_data = st.PayloadData(
                ref = payload_ref,
                explicitly_enabled = enabled_status == 'E',
                unique_token = token,
                mapping_identifier = identifier,
                pattern = pattern,
                pattern_path_segments = pattern_path_segments,
                eval_allowed = eval_allowed,
                cors_bypass_allowed = cors_bypass_allowed,
                global_secret = self.secret
            )

            new_payloads_data[payload_ref] = payload_data

        self.policy_tree = new_policy_tree
        self.payloads_data = new_payloads_data

    def rebuild_structures(self, *, payloads: bool = True, rules: bool = True) \
        -> None:
        """Rebuild the policy tree and payloads data from the database."""
        # The `payloads` and `rules` args will be useful for optimization but
        # for now we're not yet using them.
        with self.cursor() as cursor:
            self._rebuild_structures(cursor)

    def rule_store(self) -> st.RuleStore:
        """Return a rule store backed by this state."""
        return rules.ConcreteRuleStore(self)

    def repo_store(self) -> st.RepoStore:
        """Return a repo store backed by this state."""
        return repos.ConcreteRepoStore(self)

    def mapping_store(self) -> st.MappingStore:
        """Return a mapping store backed by this state."""
        return items.ConcreteMappingStore(self)

    def mapping_version_store(self) -> st.MappingVersionStore:
        """Return a mapping version store backed by this state."""
        return items.ConcreteMappingVersionStore(self)

    def resource_store(self) -> st.ResourceStore:
        """Return a resource store backed by this state."""
        return items.ConcreteResourceStore(self)

    def resource_version_store(self) -> st.ResourceVersionStore:
        """Return a resource version store backed by this state."""
        return items.ConcreteResourceVersionStore(self)

    def payload_store(self) -> st.PayloadStore:
        """Return a payload store backed by this state."""
        return payloads.ConcretePayloadStore(self)

    def get_secret(self) -> bytes:
        """Return the secret held by this state."""
        return self.secret

    def get_settings(self) -> st.HaketiloGlobalSettings:
        """Return current global settings (read while holding the lock)."""
        with self.lock:
            return self.settings

    # NOTE(review): the mutable `{}` default is kept for signature
    # compatibility; it is only ever read here (through `.get()`).
    def update_settings(
            self,
            *,
            mapping_use_mode: t.Optional[st.MappingUseMode] = None,
            default_allow_scripts: t.Optional[bool] = None,
            advanced_user: t.Optional[bool] = None,
            repo_refresh_seconds: t.Optional[int] = None,
            locale: t.Optional[str] = None,
            default_popup_settings: t.Mapping[str, st.PopupSettings] = {}
    ) -> None:
        """
        Store the settings that were passed (i.e. are not None) in the
        database, in a new transaction, and reload `self.settings` from it.

        `default_popup_settings` is looked up by page type: 'jsallowed',
        'jsblocked' and 'payloadon'. Other keys are ignored.
        """
        with self.cursor(transaction=True) as cursor:
            def set_opt(col_name: str, val: t.Union[bool, int, str]) -> None:
                # `col_name` only ever comes from the literals below, never
                # from outside; the value itself is passed as SQL parameter.
                cursor.execute(f'UPDATE general SET {col_name} = ?;', (val,))

            if mapping_use_mode is not None:
                set_opt('mapping_use_mode', mapping_use_mode.value)
            if default_allow_scripts is not None:
                set_opt('default_allow_scripts', default_allow_scripts)
            if advanced_user is not None:
                set_opt('advanced_user', advanced_user)
            if repo_refresh_seconds is not None:
                set_opt('repo_refresh_seconds', repo_refresh_seconds)
            if locale is not None:
                set_opt('locale', locale)

            for page_type in _POPUP_PAGE_TYPES:
                popup_settings = default_popup_settings.get(page_type)
                if popup_settings is not None:
                    trigger_col_name = f'default_popup_{page_type}_onkeyboard'
                    set_opt(trigger_col_name, popup_settings.keyboard_trigger)

                    style_col_name = f'default_popup_{page_type}_style'
                    set_opt(style_col_name, popup_settings.style.value)

            self.settings = load_settings(cursor)

    @staticmethod
    def make(
            store_dir: Path,
            listen_host: str,
            listen_port: int,
            logger: st.Logger
    ) -> 'ConcreteHaketiloState':
        """
        Create a state object that uses the 'sqlite3.db' database file inside
        `store_dir`. The directory gets created if it does not exist.

        The database connection gets closed if preparing the database or
        loading the settings fails; the exception is then re-raised.
        """
        store_dir.mkdir(parents=True, exist_ok=True)

        connection = sqlite3.connect(
            str(store_dir / 'sqlite3.db'),
            isolation_level = None,
            check_same_thread = False
        )

        try:
            _prepare_database(connection)

            settings_cursor = connection.cursor()
            try:
                global_settings = load_settings(settings_cursor)
            finally:
                settings_cursor.close()
        except BaseException:
            # Nobody else holds this connection yet, so it would leak.
            connection.close()
            raise

        return ConcreteHaketiloState(
            store_dir = store_dir,
            _logger = logger,
            _listen_host = listen_host,
            _listen_port = listen_port,
            connection = connection,
            settings = global_settings
        )