aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/state_impl/concrete_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/hydrilla/proxy/state_impl/concrete_state.py')
-rw-r--r--src/hydrilla/proxy/state_impl/concrete_state.py523
1 files changed, 523 insertions, 0 deletions
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py
new file mode 100644
index 0000000..89a2eb2
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl/concrete_state.py
@@ -0,0 +1,523 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy data and configuration (instantiatable HaketiloState subtype).
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use of this
+# code in a proprietary program, I am not going to enforce this in
+# court.
+
+"""
+This module contains logic for keeping track of all settings, rules, mappings
+and resources.
+"""
+
+import sqlite3
+import secrets
+import typing as t
+import dataclasses as dc
+
+from pathlib import Path
+
+from ...exceptions import HaketiloException
+from ...translations import smart_gettext as _
+from ... import url_patterns
+from ... import item_infos
+from .. import state as st
+from .. import policies
+from .. import simple_dependency_satisfying as sds
+from . import base
+from . import rules
+from . import items
+from . import repos
+from . import payloads
+from . import _operations
+
+
+here = Path(__file__).resolve().parent
+
+
+def _add_popup_settings_columns(cursor: sqlite3.Cursor) -> None:
+ for page_type in ('jsallowed', 'jsblocked', 'payloadon'):
+ cursor.execute(
+ f'''
+ ALTER TABLE general ADD COLUMN
+ default_popup_{page_type}_onkeyboard BOOLEAN NOT NULL DEFAULT TRUE;
+ '''
+ )
+ cursor.execute(
+ f'''
+ ALTER TABLE general ADD COLUMN
+ default_popup_{page_type}_style CHAR(1) NOT NULL DEFAULT 'T'
+ CHECK (default_popup_{page_type}_style IN ('D', 'T'));
+ '''
+ )
+
+def _add_locale_column(cursor: sqlite3.Cursor) -> None:
+ cursor.execute(
+ '''
+ ALTER TABLE general ADD COLUMN
+ locale VARCHAR NOT NULL DEFAULT 'unknown';
+ '''
+ )
+
+def _add_update_waiting_column(cursor: sqlite3.Cursor) -> None:
+ cursor.execute(
+ '''
+ ALTER TABLE general ADD COLUMN
+ update_waiting BOOLEAN NOT NULL DEFAULT TRUE;
+ '''
+ )
+
def _prepare_database(connection: sqlite3.Connection) -> None:
    """
    Make sure the database behind `connection` has the expected schema.

    If there is no `general` table yet, the tables get created by executing
    the `tables.sql` script that resides next to this file. If the table
    already exists, its `haketilo_version` must be '3.0b1'. Columns that might
    be missing from the `general` table (popup settings, `locale`,
    `update_waiting`) are then added inside a single transaction. At the end
    foreign key support gets verified and switched on.

    Raises HaketiloException when the schema version of an existing database
    is not recognized (this includes an existing but empty `general` table)
    or when the `PRAGMA FOREIGN_KEYS;` query yields no rows.
    """
    cursor = connection.cursor()

    try:
        cursor.execute(
            '''
            SELECT
                    COUNT(name)
            FROM
                    sqlite_master
            WHERE
                    name = 'general' AND type = 'table';
            '''
        )

        (db_initialized,), = cursor.fetchall()

        if not db_initialized:
            cursor.executescript((here / 'tables.sql').read_text())

        cursor.execute('BEGIN TRANSACTION;')

        try:
            if db_initialized:
                # If db was initialized before we connected to it, we must
                # check what its schema version is.
                cursor.execute(
                    '''
                    SELECT
                            haketilo_version
                    FROM
                            general;
                    '''
                )

                version_row = cursor.fetchone()

                # fetchone() gives None when the `general` table has no
                # rows. We report that the same way as a wrong version
                # instead of failing with a TypeError on unpacking.
                if version_row is None or version_row[0] != '3.0b1':
                    raise HaketiloException(_('err.proxy.unknown_db_schema'))

            # Each row of TABLE_INFO is (cid, name, type, notnull, dflt_value,
            # pk). We only need the names.
            cursor.execute("PRAGMA TABLE_INFO('general')")
            present_columns = {row[1] for row in cursor.fetchall()}

            # Presence of the first popup column is taken as an indicator
            # that all the popup settings columns are there.
            if 'default_popup_jsallowed_onkeyboard' not in present_columns:
                _add_popup_settings_columns(cursor)

            if 'locale' not in present_columns:
                _add_locale_column(cursor)

            if 'update_waiting' not in present_columns:
                _add_update_waiting_column(cursor)

            cursor.execute('COMMIT TRANSACTION;')
        except BaseException:
            # Undo partial schema changes on any failure (interrupts
            # included) and let the original exception propagate.
            cursor.execute('ROLLBACK TRANSACTION;')
            raise

        # The pragma query yields no rows when the SQLite library in use has
        # no foreign key support.
        cursor.execute('PRAGMA FOREIGN_KEYS;')
        if cursor.fetchall() == []:
            raise HaketiloException(_('err.proxy.no_sqlite_foreign_keys'))

        cursor.execute('PRAGMA FOREIGN_KEYS=ON;')
    finally:
        cursor.close()
+
+
def load_settings(cursor: sqlite3.Cursor) -> st.HaketiloGlobalSettings:
    """
    Read global settings from the `general` table.

    The table is expected to hold exactly one row; unpacking fails with
    ValueError otherwise. Popup settings of each page type are queried
    separately. If such a query fails with sqlite3.OperationalError (which is
    what SQLite reports when a column does not exist), the defaults are used
    for that page type: keyboard trigger enabled and style 'T'.

    Returns a st.HaketiloGlobalSettings object built from the values read.
    """
    cursor.execute(
        '''
        SELECT
                default_allow_scripts,
                advanced_user,
                repo_refresh_seconds,
                mapping_use_mode,
                locale,
                update_waiting
        FROM
                general;
        '''
    )

    (default_allow_scripts, advanced_user, repo_refresh_seconds,
     mapping_use_mode, locale, update_waiting), = cursor.fetchall()

    popup_settings_dict = {}

    for page_type in ('jsallowed', 'jsblocked', 'payloadon'):
        try:
            cursor.execute(
                f'''
                SELECT
                        default_popup_{page_type}_onkeyboard,
                        default_popup_{page_type}_style
                FROM
                        general;
                '''
            )

            (onkeyboard, style), = cursor.fetchall()
        except sqlite3.OperationalError:
            # Only a failing query triggers the fallback. Other errors (e.g.
            # interrupts or programming mistakes) are no longer silenced.
            onkeyboard, style = True, 'T'

        popup_settings_dict[f'default_popup_{page_type}'] = st.PopupSettings(
            keyboard_trigger = onkeyboard,
            style = st.PopupStyle(style)
        )

    return st.HaketiloGlobalSettings(
        default_allow_scripts = default_allow_scripts,
        advanced_user = advanced_user,
        repo_refresh_seconds = repo_refresh_seconds,
        mapping_use_mode = st.MappingUseMode(mapping_use_mode),
        locale = locale,
        update_waiting = update_waiting,

        **popup_settings_dict
    )
+
@dc.dataclass
class ConcreteHaketiloState(base.HaketiloStateWithFields):
    """
    Instantiatable Haketilo state that keeps its data in an SQLite database.

    Attributes used by the methods below (`connection`, `settings`,
    `policy_tree`, `payloads_data`, `secret`, `lock`, `logger`,
    `efective_listen_addr`) and the `cursor()` context manager are not
    declared in this file; presumably they come from
    `base.HaketiloStateWithFields`.

    Use the `make()` static method to create an instance.
    """
    def __post_init__(self) -> None:
        """Build the in-memory structures right after the fields are set."""
        self.rebuild_structures()

    def import_items(self, malcontent_path: Path, repo_id: int = 1) -> None:
        """
        Load packages from `malcontent_path` into the repo given by `repo_id`.

        Afterwards the `update_waiting` setting is set to true both in the
        database and in `self.settings`, and the in-memory structures get
        rebuilt.
        """
        with self.cursor(transaction=(repo_id == 1)) as cursor:
            # This method without the repo_id argument exposed is part of the
            # state API. As such, calls with repo_id = 1 (imports of local
            # semirepo packages) create a new transaction. Calls with different
            # values of repo_id are assumed to originate from within the state
            # implementation code and expect an existing transaction. Here, we
            # verify the transaction is indeed present.
            assert self.connection.in_transaction

            _operations._load_packages_no_state_update(
                cursor = cursor,
                malcontent_path = malcontent_path,
                repo_id = repo_id
            )

            cursor.execute('UPDATE general SET update_waiting = TRUE;')
            self.settings = dc.replace(self.settings, update_waiting=True)

            self.rebuild_structures(rules=False)

    def count_orphan_items(self) -> st.OrphanItemsStats:
        """
        Count orphaned mappings and resources.

        Item versions that belong to a repo iteration listed in
        `orphan_iterations` and whose `active` value is different from 'R'
        are counted, separately for items of type 'M' and of type 'R'.
        """
        with self.cursor() as cursor:
            cursor.execute(
                '''
                SELECT
                        COALESCE(SUM(i.type = 'M'), 0),
                        COALESCE(SUM(i.type = 'R'), 0)
                FROM
                             item_versions     AS iv
                        JOIN items             AS i  USING (item_id)
                        JOIN orphan_iterations AS oi USING (repo_iteration_id)
                WHERE
                        iv.active != 'R';
                '''
            )

            (orphan_mappings, orphan_resources), = cursor.fetchall()

            return st.OrphanItemsStats(orphan_mappings, orphan_resources)

    def prune_orphan_items(self) -> None:
        """
        Prune orphans aggressively and recompute dependencies afterwards.

        Both steps are performed under one transaction requested from
        `self.cursor()`.
        """
        with self.cursor(transaction=True) as cursor:
            _operations.prune_orphans(cursor, aggressive=True)

            self.recompute_dependencies()

    def soft_prune_orphan_items(self) -> None:
        """
        Prune orphans in the default (non-aggressive) mode.

        The caller is expected to have a transaction open already.
        """
        with self.cursor() as cursor:
            assert self.connection.in_transaction

            _operations.prune_orphans(cursor)

    def recompute_dependencies(
            self,
            # NOTE(review): the mutable default is kept for interface
            # compatibility. It is not modified in this method; whether
            # `_operations` leaves it intact cannot be told from here.
            unlocked_required_mappings: base.NoLockArg = []
    ) -> None:
        """
        Recompute dependencies and rebuild the in-memory structures.

        When `unlocked_required_mappings` equals 'all_mappings_unlocked', the
        `update_waiting` setting is additionally cleared, both in the database
        and in `self.settings`.

        The caller is expected to have a transaction open already.
        """
        with self.cursor() as cursor:
            assert self.connection.in_transaction

            _operations._recompute_dependencies_no_state_update(
                cursor = cursor,
                unlocked_required_mappings = unlocked_required_mappings
            )

            if unlocked_required_mappings == 'all_mappings_unlocked':
                cursor.execute('UPDATE general SET update_waiting = FALSE;')
                self.settings = dc.replace(self.settings, update_waiting=False)

            self.rebuild_structures(rules=False)

    # NOTE(review): "upate" looks like a misspelling of "update". The name is
    # kept as is because code outside this file may call or declare it.
    def upate_all_items(self) -> None:
        """
        Recompute dependencies with all mappings unlocked, in a new
        transaction.
        """
        with self.cursor(transaction=True):
            self.recompute_dependencies('all_mappings_unlocked')

    def pull_missing_files(self) -> None:
        """
        Delegate to `_operations.pull_missing_files()`.

        The caller is expected to have a transaction open already.
        """
        with self.cursor() as cursor:
            assert self.connection.in_transaction

            _operations.pull_missing_files(cursor)

    def _rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
        """
        Construct a fresh policy tree and payloads data from the database.

        Registered are, in this order: the builtin web UI policies, the
        mitm.it page policy, script allowing/blocking rules from the `rules`
        table and payloads from the `payloads` table. Once everything is
        built, `self.policy_tree` and `self.payloads_data` get replaced.
        """
        new_policy_tree = base.PolicyTree()

        web_ui_main_pattern = 'http*://hkt.mitm.it/***'
        web_ui_main_factory = policies.WebUIMainPolicyFactory(builtin=True)

        for parsed_pattern in url_patterns.parse_pattern(web_ui_main_pattern):
            new_policy_tree = new_policy_tree.register(
                parsed_pattern = parsed_pattern,
                item = web_ui_main_factory
            )

        web_ui_landing_pattern = f'{self.efective_listen_addr}/***'
        web_ui_landing_factory = policies.WebUILandingPolicyFactory(
            builtin = True
        )

        # The landing page pattern depends on the listen address. Failure to
        # parse it only causes a warning to be logged.
        try:
            parsed_pattern, = url_patterns.parse_pattern(web_ui_landing_pattern)
        except url_patterns.HaketiloURLException:
            fmt = _('warn.proxy.failed_to_register_landing_page_at_{}')
            self.logger.warn(fmt.format(web_ui_landing_pattern))
        else:
            new_policy_tree = new_policy_tree.register(
                parsed_pattern = parsed_pattern,
                item = web_ui_landing_factory
            )

        mitm_it_page_pattern = 'http://mitm.it/***'
        mitm_it_page_factory = policies.MitmItPagePolicyFactory()

        parsed_pattern, = url_patterns.parse_pattern(mitm_it_page_pattern)
        new_policy_tree = new_policy_tree.register(
            parsed_pattern = parsed_pattern,
            item = mitm_it_page_factory
        )

        # Put script blocking/allowing rules in policy tree.
        cursor.execute('SELECT pattern, allow_scripts FROM rules;')

        for pattern, allow_scripts in cursor.fetchall():
            for parsed_pattern in url_patterns.parse_pattern(pattern):
                factory: policies.PolicyFactory
                if allow_scripts:
                    factory = policies.RuleAllowPolicyFactory(
                        builtin = False,
                        pattern = parsed_pattern
                    )
                else:
                    factory = policies.RuleBlockPolicyFactory(
                        builtin = False,
                        pattern = parsed_pattern
                    )

                new_policy_tree = new_policy_tree.register(
                    parsed_pattern = parsed_pattern,
                    item = factory
                )

        # Put script payload rules in policy tree.
        cursor.execute(
            '''
            SELECT
                    p.payload_id,
                    p.pattern,
                    p.eval_allowed,
                    p.cors_bypass_allowed,
                    ms.enabled,
                    i.identifier
            FROM
                         payloads         AS p
                    JOIN item_versions    AS iv
                            ON p.mapping_item_id = iv.item_version_id
                    JOIN items            AS i
                            USING (item_id)
                    JOIN mapping_statuses AS ms
                            USING (item_id);
            '''
        )

        new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}

        for (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
             enabled_status, identifier) in cursor.fetchall():
            payload_ref = payloads.ConcretePayloadRef(str(payload_id_int), self)

            # Reuse the token of a payload that was already known before the
            # rebuild. Only new payloads get a freshly generated one.
            previous_data = self.payloads_data.get(payload_ref)
            if previous_data is not None:
                token = previous_data.unique_token
            else:
                token = secrets.token_urlsafe(8)

            payload_key = st.PayloadKey(payload_ref, identifier)

            for parsed_pattern in url_patterns.parse_pattern(pattern):
                new_policy_tree = new_policy_tree.register_payload(
                    parsed_pattern,
                    payload_key,
                    token
                )

                # The segments of the last parsed pattern are the ones that
                # end up in payload data. This presumes parse_pattern() yields
                # at least one result - TODO confirm.
                pattern_path_segments = parsed_pattern.path_segments

            payload_data = st.PayloadData(
                ref = payload_ref,
                explicitly_enabled = enabled_status == 'E',
                unique_token = token,
                mapping_identifier = identifier,
                pattern = pattern,
                pattern_path_segments = pattern_path_segments,
                eval_allowed = eval_allowed,
                cors_bypass_allowed = cors_bypass_allowed,
                global_secret = self.secret
            )

            new_payloads_data[payload_ref] = payload_data

        self.policy_tree = new_policy_tree
        self.payloads_data = new_payloads_data

    def rebuild_structures(self, *, payloads: bool = True, rules: bool = True) \
        -> None:
        """Rebuild the policy tree and payloads data from the database."""
        # The `payloads` and `rules` args will be useful for optimization but
        # for now we're not yet using them.
        with self.cursor() as cursor:
            self._rebuild_structures(cursor)

    def rule_store(self) -> st.RuleStore:
        """Return a rule store operating on this state."""
        return rules.ConcreteRuleStore(self)

    def repo_store(self) -> st.RepoStore:
        """Return a repo store operating on this state."""
        return repos.ConcreteRepoStore(self)

    def mapping_store(self) -> st.MappingStore:
        """Return a mapping store operating on this state."""
        return items.ConcreteMappingStore(self)

    def mapping_version_store(self) -> st.MappingVersionStore:
        """Return a mapping version store operating on this state."""
        return items.ConcreteMappingVersionStore(self)

    def resource_store(self) -> st.ResourceStore:
        """Return a resource store operating on this state."""
        return items.ConcreteResourceStore(self)

    def resource_version_store(self) -> st.ResourceVersionStore:
        """Return a resource version store operating on this state."""
        return items.ConcreteResourceVersionStore(self)

    def payload_store(self) -> st.PayloadStore:
        """Return a payload store operating on this state."""
        return payloads.ConcretePayloadStore(self)

    def get_secret(self) -> bytes:
        """Return the value of `self.secret`."""
        return self.secret

    def get_settings(self) -> st.HaketiloGlobalSettings:
        """Return the current settings object, read while holding the lock."""
        with self.lock:
            return self.settings

    def update_settings(
            self,
            *,
            mapping_use_mode: t.Optional[st.MappingUseMode] = None,
            default_allow_scripts: t.Optional[bool] = None,
            advanced_user: t.Optional[bool] = None,
            repo_refresh_seconds: t.Optional[int] = None,
            locale: t.Optional[str] = None,
            # The default mapping is only read here, never modified.
            default_popup_settings: t.Mapping[str, st.PopupSettings] = {}
    ) -> None:
        """
        Write the given settings to the database and reload `self.settings`.

        Arguments left as None are not touched. `default_popup_settings` is
        looked up by the page type names 'jsallowed', 'jsblocked' and
        'payloadon'; page types absent from the mapping are not touched
        either.
        """
        with self.cursor(transaction=True) as cursor:
            def set_opt(col_name: str, val: t.Union[bool, int, str]) -> None:
                # Column names only come from the fixed strings used below.
                # The value is passed as a bound parameter.
                cursor.execute(f'UPDATE general SET {col_name} = ?;', (val,))

            if mapping_use_mode is not None:
                set_opt('mapping_use_mode', mapping_use_mode.value)
            if default_allow_scripts is not None:
                set_opt('default_allow_scripts', default_allow_scripts)
            if advanced_user is not None:
                set_opt('advanced_user', advanced_user)
            if repo_refresh_seconds is not None:
                set_opt('repo_refresh_seconds', repo_refresh_seconds)
            if locale is not None:
                set_opt('locale', locale)

            for page_type in ('jsallowed', 'jsblocked', 'payloadon'):
                popup_settings = default_popup_settings.get(page_type)
                if popup_settings is not None:
                    trigger_col_name = f'default_popup_{page_type}_onkeyboard'
                    set_opt(trigger_col_name, popup_settings.keyboard_trigger)

                    style_col_name = f'default_popup_{page_type}_style'
                    set_opt(style_col_name, popup_settings.style.value)

            self.settings = load_settings(cursor)

    @staticmethod
    def make(
            store_dir: Path,
            listen_host: str,
            listen_port: int,
            logger: st.Logger
    ) -> 'ConcreteHaketiloState':
        """
        Create a state object that uses the database under `store_dir`.

        The directory gets created if missing. The database file is
        `sqlite3.db` inside it. The connection is opened in autocommit mode
        (`isolation_level = None`) and with the same-thread check disabled.

        If preparing the database, loading the settings or constructing the
        state object fails, the connection is closed before the exception
        propagates.
        """
        store_dir.mkdir(parents=True, exist_ok=True)

        connection = sqlite3.connect(
            str(store_dir / 'sqlite3.db'),
            isolation_level = None,
            check_same_thread = False
        )

        try:
            _prepare_database(connection)

            # Use a dedicated cursor for the initial settings read and
            # release it as soon as we are done with it.
            settings_cursor = connection.cursor()
            try:
                global_settings = load_settings(settings_cursor)
            finally:
                settings_cursor.close()

            return ConcreteHaketiloState(
                store_dir = store_dir,
                _logger = logger,
                _listen_host = listen_host,
                _listen_port = listen_port,
                connection = connection,
                settings = global_settings
            )
        except BaseException:
            # Nobody else holds this connection yet, so it would leak if we
            # did not close it here.
            connection.close()
            raise