# SPDX-License-Identifier: GPL-3.0-or-later # Haketilo proxy data and configuration (instantiatable HaketiloState subtype). # # This file is part of Hydrilla&Haketilo. # # Copyright (C) 2022 Wojtek Kosior # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # # # I, Wojtek Kosior, thereby promise not to sue for violation of this # file's license. Although I request that you do not make use this code # in a proprietary program, I am not going to enforce this in court. """ This module contains logic for keeping track of all settings, rules, mappings and resources. """ # Enable using with Python 3.7. from __future__ import annotations import secrets import io import hashlib import typing as t import dataclasses as dc from pathlib import Path import sqlite3 from ...exceptions import HaketiloException from ...translations import smart_gettext as _ from ... import pattern_tree from ... import url_patterns from ... import versions from ... import item_infos from ..simple_dependency_satisfying import compute_payloads, ComputedPayload from .. import state as st from .. import policies from . import base here = Path(__file__).resolve().parent @dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] class ConcreteRepoRef(st.RepoRef): def remove(self, state: st.HaketiloState) -> None: raise NotImplementedError() def update( self, state: st.HaketiloState, *, name: t.Optional[str] = None, url: t.Optional[str] = None ) -> ConcreteRepoRef: raise NotImplementedError() def refresh(self, state: st.HaketiloState) -> ConcreteRepoIterationRef: raise NotImplementedError() @dc.dataclass(frozen=True, unsafe_hash=True) class ConcreteRepoIterationRef(st.RepoIterationRef): pass @dc.dataclass(frozen=True, unsafe_hash=True) class ConcreteMappingRef(st.MappingRef): def disable(self, state: st.HaketiloState) -> None: raise NotImplementedError() def forget_enabled(self, state: st.HaketiloState) -> None: raise NotImplementedError() @dc.dataclass(frozen=True, unsafe_hash=True) class ConcreteMappingVersionRef(st.MappingVersionRef): def enable(self, state: st.HaketiloState) -> None: raise NotImplementedError() @dc.dataclass(frozen=True, unsafe_hash=True) class ConcreteResourceRef(st.ResourceRef): pass @dc.dataclass(frozen=True, unsafe_hash=True) class ConcreteResourceVersionRef(st.ResourceVersionRef): pass @dc.dataclass(frozen=True, unsafe_hash=True) class ConcretePayloadRef(st.PayloadRef): def get_data(self, state: st.HaketiloState) -> st.PayloadData: return t.cast(ConcreteHaketiloState, state).payloads_data[self] def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: return 'to implement' def get_script_paths(self, state: st.HaketiloState) \ -> t.Iterable[t.Sequence[str]]: with t.cast(ConcreteHaketiloState, state).cursor() as cursor: cursor.execute( ''' SELECT i.identifier, fu.name FROM payloads AS p LEFT JOIN resolved_depended_resources AS rdd USING (payload_id) LEFT JOIN item_versions AS iv ON rdd.resource_item_id = iv.item_version_id LEFT JOIN items AS i USING (item_id) LEFT JOIN file_uses AS fu USING (item_version_id) WHERE fu.type = 'W' AND p.payload_id = ? AND (fu.idx IS NOT NULL OR rdd.idx IS NULL) ORDER BY rdd.idx, fu.idx; ''', (self.id,) ) paths: list[t.Sequence[str]] = [] for resource_identifier, file_name in cursor.fetchall(): if resource_identifier is None: # payload found but it had no script files return () paths.append((resource_identifier, *file_name.split('/'))) if paths == []: # payload not found raise st.MissingItemError() return paths def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ -> t.Optional[st.FileData]: if len(path) == 0: raise st.MissingItemError() resource_identifier, *file_name_segments = path file_name = '/'.join(file_name_segments) with t.cast(ConcreteHaketiloState, state).cursor() as cursor: cursor.execute( ''' SELECT f.data, fu.mime_type FROM payloads AS p JOIN resolved_depended_resources AS rdd USING (payload_id) JOIN item_versions AS iv ON rdd.resource_item_id = iv.item_version_id JOIN items AS i USING (item_id) JOIN file_uses AS fu USING (item_version_id) JOIN files AS f USING (file_id) WHERE p.payload_id = ? AND i.identifier = ? AND fu.name = ? AND fu.type = 'W'; ''', (self.id, resource_identifier, file_name) ) result = cursor.fetchall() if result == []: return None (data, mime_type), = result return st.FileData(type=mime_type, name=file_name, contents=data) # @dc.dataclass(frozen=True, unsafe_hash=True) # class ConcretePayloadRef(st.PayloadRef): # computed_payload: ComputedPayload = dc.field(hash=False, compare=False) # def get_data(self, state: st.HaketiloState) -> st.PayloadData: # return t.cast(ConcreteHaketiloState, state).payloads_data[self.id] # def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: # return 'to implement' # def get_script_paths(self, state: st.HaketiloState) \ # -> t.Iterator[t.Sequence[str]]: # for resource_info in self.computed_payload.resources: # for file_spec in resource_info.scripts: # yield (resource_info.identifier, *file_spec.name.split('/')) # def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ # -> t.Optional[st.FileData]: # if len(path) == 0: # raise st.MissingItemError() # resource_identifier, *file_name_segments = path # file_name = '/'.join(file_name_segments) # script_sha256 = '' # matched_resource_info = False # for resource_info in self.computed_payload.resources: # if resource_info.identifier == resource_identifier: # matched_resource_info = True # for script_spec in resource_info.scripts: # if script_spec.name == file_name: # script_sha256 = script_spec.sha256 # break # if not matched_resource_info: # raise st.MissingItemError(resource_identifier) # if script_sha256 == '': # return None # store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir # files_dir_path = store_dir_path / 'temporary_malcontent' / 'file' # file_path = files_dir_path / 'sha256' / script_sha256 # return st.FileData( # type = 'application/javascript', # name = file_name, # contents = file_path.read_bytes() # ) def register_payload( policy_tree: base.PolicyTree, pattern: url_patterns.ParsedPattern, payload_key: st.PayloadKey, token: str ) -> base.PolicyTree: """....""" payload_policy_factory = policies.PayloadPolicyFactory( builtin = False, payload_key = payload_key ) policy_tree = policy_tree.register(pattern, payload_policy_factory) resource_policy_factory = policies.PayloadResourcePolicyFactory( builtin = False, payload_key = payload_key ) policy_tree = policy_tree.register( pattern.path_append(token, '***'), resource_policy_factory ) return policy_tree DataById = t.Mapping[str, st.PayloadData] AnyInfoVar = t.TypeVar( 'AnyInfoVar', item_infos.ResourceInfo, item_infos.MappingInfo ) def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \ -> t.Iterator[tuple[AnyInfoVar, str]]: item_type_path = malcontent_path / item_class.type_name if not item_type_path.is_dir(): return for item_path in item_type_path.iterdir(): if not item_path.is_dir(): continue for item_version_path in item_path.iterdir(): definition = item_version_path.read_text() item_info = item_class.load(io.StringIO(definition)) assert item_info.identifier == item_path.name assert versions.version_string(item_info.version) == \ item_version_path.name yield item_info, definition def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int: cursor.execute( ''' INSERT OR IGNORE INTO repos(name, url, deleted, next_iteration) VALUES(?, '', TRUE, 2); ''', (repo_name,) ) cursor.execute( ''' SELECT repo_id, next_iteration - 1 FROM repos WHERE name = ?; ''', (repo_name,) ) (repo_id, last_iteration), = cursor.fetchall() cursor.execute( ''' INSERT OR IGNORE INTO repo_iterations(repo_id, iteration) VALUES(?, ?); ''', (repo_id, last_iteration) ) cursor.execute( ''' SELECT repo_iteration_id FROM repo_iterations WHERE repo_id = ? AND iteration = ?; ''', (repo_id, last_iteration) ) (repo_iteration_id,), = cursor.fetchall() return repo_iteration_id def get_or_make_item(cursor: sqlite3.Cursor, type: str, identifier: str) -> int: type_letter = {'resource': 'R', 'mapping': 'M'}[type] cursor.execute( ''' INSERT OR IGNORE INTO items(type, identifier) VALUES(?, ?); ''', (type_letter, identifier) ) cursor.execute( ''' SELECT item_id FROM items WHERE type = ? AND identifier = ?; ''', (type_letter, identifier) ) (item_id,), = cursor.fetchall() return item_id def get_or_make_item_version( cursor: sqlite3.Cursor, item_id: int, repo_iteration_id: int, version: versions.VerTuple, definition: str ) -> int: ver_str = versions.version_string(version) cursor.execute( ''' INSERT OR IGNORE INTO item_versions( item_id, version, repo_iteration_id, definition ) VALUES(?, ?, ?, ?); ''', (item_id, ver_str, repo_iteration_id, definition) ) cursor.execute( ''' SELECT item_version_id FROM item_versions WHERE item_id = ? AND version = ? AND repo_iteration_id = ?; ''', (item_id, ver_str, repo_iteration_id) ) (item_version_id,), = cursor.fetchall() return item_version_id def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None: cursor.execute( ''' INSERT OR IGNORE INTO mapping_statuses(item_id, enabled) VALUES(?, 'N'); ''', (item_id,) ) def get_or_make_file(cursor: sqlite3.Cursor, sha256: str, file_bytes: bytes) \ -> int: cursor.execute( ''' INSERT OR IGNORE INTO files(sha256, data) VALUES(?, ?) ''', (sha256, file_bytes) ) cursor.execute( ''' SELECT file_id FROM files WHERE sha256 = ?; ''', (sha256,) ) (file_id,), = cursor.fetchall() return file_id def make_file_use( cursor: sqlite3.Cursor, item_version_id: int, file_id: int, name: str, type: str, mime_type: str, idx: int ) -> None: cursor.execute( ''' INSERT OR IGNORE INTO file_uses( item_version_id, file_id, name, type, mime_type, idx ) VALUES(?, ?, ?, ?, ?, ?); ''', (item_version_id, file_id, name, type, mime_type, idx) ) def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \ -> t.Mapping[AnyInfoVar, int]: cursor.execute( ''' SELECT i.item_id, iv.definition, r.name, ri.iteration FROM item_versions AS iv JOIN items AS i USING (item_id) JOIN repo_iterations AS ri USING (repo_iteration_id) JOIN repos AS r USING (repo_id) WHERE i.type = ?; ''', (info_type.type_name[0].upper(),) ) result: dict[AnyInfoVar, int] = {} for item_id, definition, repo_name, repo_iteration in cursor.fetchall(): definition_io = io.StringIO(definition) info = info_type.load(definition_io, repo_name, repo_iteration) result[info] = item_id return result @dc.dataclass class ConcreteHaketiloState(base.HaketiloStateWithFields): def __post_init__(self) -> None: self._prepare_database() self._populate_database_with_stuff_from_temporary_malcontent_dir() with self.cursor(transaction=True) as cursor: self.recompute_payloads(cursor) def _prepare_database(self) -> None: """....""" cursor = self.connection.cursor() try: cursor.execute( ''' SELECT COUNT(name) FROM sqlite_master WHERE name = 'general' AND type = 'table'; ''' ) (db_initialized,), = cursor.fetchall() if not db_initialized: cursor.executescript((here.parent / 'tables.sql').read_text()) else: cursor.execute( ''' SELECT haketilo_version FROM general; ''' ) (db_haketilo_version,) = cursor.fetchone() if db_haketilo_version != '3.0b1': raise HaketiloException(_('err.unknown_db_schema')) cursor.execute('PRAGMA FOREIGN_KEYS;') if cursor.fetchall() == []: raise HaketiloException(_('err.proxy.no_sqlite_foreign_keys')) cursor.execute('PRAGMA FOREIGN_KEYS=ON;') finally: cursor.close() def _populate_database_with_stuff_from_temporary_malcontent_dir(self) \ -> None: malcontent_dir_path = self.store_dir / 'temporary_malcontent' files_by_sha256_path = malcontent_dir_path / 'file' / 'sha256' with self.cursor(transaction=True) as cursor: for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]: info: item_infos.AnyInfo for info, definition in read_items( malcontent_dir_path, info_type # type: ignore ): repo_iteration_id = get_or_make_repo_iteration( cursor, info.repo ) item_id = get_or_make_item( cursor, info.type_name, info.identifier ) item_version_id = get_or_make_item_version( cursor, item_id, repo_iteration_id, info.version, definition ) if info_type is item_infos.MappingInfo: make_mapping_status(cursor, item_id) file_ids_bytes = {} file_specifiers = [*info.source_copyright] if isinstance(info, item_infos.ResourceInfo): file_specifiers.extend(info.scripts) for file_spec in file_specifiers: file_path = files_by_sha256_path / file_spec.sha256 file_bytes = file_path.read_bytes() sha256 = hashlib.sha256(file_bytes).digest().hex() assert sha256 == file_spec.sha256 file_id = get_or_make_file(cursor, sha256, file_bytes) file_ids_bytes[sha256] = (file_id, file_bytes) for idx, file_spec in enumerate(info.source_copyright): file_id, file_bytes = file_ids_bytes[file_spec.sha256] if file_bytes.isascii(): mime = 'text/plain' else: mime = 'application/octet-stream' make_file_use( cursor, item_version_id = item_version_id, file_id = file_id, name = file_spec.name, type = 'L', mime_type = mime, idx = idx ) if isinstance(info, item_infos.MappingInfo): continue for idx, file_spec in enumerate(info.scripts): file_id, _ = file_ids_bytes[file_spec.sha256] make_file_use( cursor, item_version_id = item_version_id, file_id = file_id, name = file_spec.name, type = 'W', mime_type = 'application/javascript', idx = idx ) def recompute_payloads(self, cursor: sqlite3.Cursor) -> None: assert self.connection.in_transaction resources = get_infos_of_type(cursor, item_infos.ResourceInfo) mappings = get_infos_of_type(cursor, item_infos.MappingInfo) payloads = compute_payloads(resources.keys(), mappings.keys()) payloads_data: dict[st.PayloadRef, st.PayloadData] = {} cursor.execute('DELETE FROM payloads;') for mapping_info, by_pattern in payloads.items(): for num, (pattern, payload) in enumerate(by_pattern.items()): cursor.execute( ''' INSERT INTO payloads( mapping_item_id, pattern, eval_allowed, cors_bypass_allowed ) VALUES (?, ?, ?, ?); ''', ( mappings[mapping_info], pattern.orig_url, payload.allows_eval, payload.allows_cors_bypass ) ) cursor.execute( ''' SELECT payload_id FROM payloads WHERE mapping_item_id = ? AND pattern = ?; ''', (mappings[mapping_info], pattern.orig_url) ) (payload_id_int,), = cursor.fetchall() for res_num, resource_info in enumerate(payload.resources): cursor.execute( ''' INSERT INTO resolved_depended_resources( payload_id, resource_item_id, idx ) VALUES(?, ?, ?); ''', (payload_id_int, resources[resource_info], res_num) ) self._rebuild_structures(cursor) def _rebuild_structures(self, cursor: sqlite3.Cursor) -> None: cursor.execute( ''' SELECT p.payload_id, p.pattern, p.eval_allowed, p.cors_bypass_allowed, ms.enabled, i.identifier FROM payloads AS p JOIN item_versions AS iv ON p.mapping_item_id = iv.item_version_id JOIN items AS i USING (item_id) JOIN mapping_statuses AS ms USING (item_id); ''' ) new_policy_tree = base.PolicyTree() ui_factory = policies.WebUIPolicyFactory(builtin=True) web_ui_pattern = 'http*://hkt.mitm.it/***' for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern): new_policy_tree = new_policy_tree.register( parsed_pattern, ui_factory ) new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} for row in cursor.fetchall(): (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, enabled_status, identifier) = row payload_ref = ConcretePayloadRef(str(payload_id_int)) previous_data = self.payloads_data.get(payload_ref) if previous_data is not None: token = previous_data.unique_token else: token = secrets.token_urlsafe(8) payload_key = st.PayloadKey(payload_ref, identifier) for parsed_pattern in url_patterns.parse_pattern(pattern): new_policy_tree = register_payload( new_policy_tree, parsed_pattern, payload_key, token ) pattern_path_segments = parsed_pattern.path_segments payload_data = st.PayloadData( payload_ref = payload_ref, #explicitly_enabled = enabled_status == 'E', explicitly_enabled = True, unique_token = token, pattern_path_segments = pattern_path_segments, eval_allowed = eval_allowed, cors_bypass_allowed = cors_bypass_allowed ) new_payloads_data[payload_ref] = payload_data self.policy_tree = new_policy_tree self.payloads_data = new_payloads_data def get_repo(self, repo_id: str) -> st.RepoRef: return ConcreteRepoRef(repo_id) def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef: return ConcreteRepoIterationRef(repo_iteration_id) def get_mapping(self, mapping_id: str) -> st.MappingRef: return ConcreteMappingRef(mapping_id) def get_mapping_version(self, mapping_version_id: str) \ -> st.MappingVersionRef: return ConcreteMappingVersionRef(mapping_version_id) def get_resource(self, resource_id: str) -> st.ResourceRef: return ConcreteResourceRef(resource_id) def get_resource_version(self, resource_version_id: str) \ -> st.ResourceVersionRef: return ConcreteResourceVersionRef(resource_version_id) def get_payload(self, payload_id: str) -> st.PayloadRef: return 'not implemented' def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \ -> st.RepoRef: raise NotImplementedError() def get_settings(self) -> st.HaketiloGlobalSettings: return st.HaketiloGlobalSettings( mapping_use_mode = st.MappingUseMode.AUTO, default_allow_scripts = True, repo_refresh_seconds = 0 ) def update_settings( self, *, mapping_use_mode: t.Optional[st.MappingUseMode] = None, default_allow_scripts: t.Optional[bool] = None, repo_refresh_seconds: t.Optional[int] = None ) -> None: raise NotImplementedError() def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy: """....""" with self.lock: policy_tree = self.policy_tree try: best_priority: int = 0 best_policy: t.Optional[policies.Policy] = None for factories_set in policy_tree.search(url): for stored_factory in sorted(factories_set): factory = stored_factory.item policy = factory.make_policy(self) if policy.priority > best_priority: best_priority = policy.priority best_policy = policy except Exception as e: return policies.ErrorBlockPolicy( builtin = True, error = e ) if best_policy is not None: return best_policy if self.get_settings().default_allow_scripts: return policies.FallbackAllowPolicy() else: return policies.FallbackBlockPolicy() @staticmethod def make(store_dir: Path) -> 'ConcreteHaketiloState': connection = sqlite3.connect( str(store_dir / 'sqlite3.db'), isolation_level = None, check_same_thread = False ) return ConcreteHaketiloState( store_dir = store_dir, connection = connection )