# SPDX-License-Identifier: GPL-3.0-or-later

# Haketilo proxy data and configuration (definition of fields of a class that
# will implement HaketiloState).
#
# This file is part of Hydrilla&Haketilo.
#
# Copyright (C) 2022 Wojtek Kosior
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
#
# I, Wojtek Kosior, thereby promise not to sue for violation of this
# file's license. Although I request that you do not make use this code
# in a proprietary program, I am not going to enforce this in court.

"""
This module defines fields that will later be part of a concrete HaketiloState
subtype.
"""

# Enable using with Python 3.7.
from __future__ import annotations

import sqlite3
import secrets
import threading
import dataclasses as dc
import typing as t

from pathlib import Path
from contextlib import contextmanager
from abc import abstractmethod

from immutables import Map

from ... import url_patterns
from ... import pattern_tree
from .. import simple_dependency_satisfying as sds
from .. import state as st
from .. import policies


PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]

PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData]


@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcretePayloadRef(st.PayloadRef):
    """
    PayloadRef implementation that looks payload information up in the
    HaketiloStateWithFields object stored in its `state` field (in-memory
    `payloads_data` and the state's SQLite database).
    """
    # Excluded from hashing and comparison, so two refs compare equal based
    # only on the fields inherited from st.PayloadRef (presumably the payload
    # id) - this lets rebuild_structures() find a payload's previous data.
    state: 'HaketiloStateWithFields' = dc.field(hash=False, compare=False)

    def get_data(self) -> st.PayloadData:
        """
        Return the PayloadData stored for this payload in
        `state.payloads_data`.

        Raises st.MissingItemError if there is no entry for this payload.
        """
        try:
            return self.state.payloads_data[self]
        except KeyError:
            raise st.MissingItemError()

    def get_mapping(self) -> st.MappingVersionRef:
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError()

    def get_script_paths(self) \
        -> t.Iterable[t.Sequence[str]]:
        """
        Return the paths of script files (file uses of type 'W') of the
        resources this payload depends on.

        Paths are ordered by resource index and then by file index. Each path
        is a sequence whose first element is the resource identifier and whose
        remaining elements are the '/'-separated segments of the file name.

        Returns an empty tuple if the payload exists but has no script files.
        Raises st.MissingItemError if no payload with this id exists.
        """
        with self.state.cursor() as cursor:
            cursor.execute(
                '''
                SELECT
                        i.identifier, fu.name
                FROM
                        payloads AS p
                        LEFT JOIN resolved_depended_resources AS rdd
                            USING (payload_id)
                        LEFT JOIN item_versions AS iv
                            ON rdd.resource_item_id = iv.item_version_id
                        LEFT JOIN items AS i
                            USING (item_id)
                        LEFT JOIN file_uses AS fu
                            USING (item_version_id)
                WHERE
                        fu.type = 'W'               AND
                        p.payload_id = ?            AND
                        (fu.idx IS NOT NULL OR rdd.idx IS NULL)
                ORDER BY
                        rdd.idx, fu.idx;
                ''',
                (self.id,)
            )

            rows = cursor.fetchall()

            if rows == []:
                # The `fu.type = 'W'` test in WHERE also discards the
                # all-NULL row that the LEFT JOINs produce for an existing
                # payload without script files. An empty result therefore
                # does not by itself mean the payload is missing - check
                # explicitly.
                cursor.execute(
                    'SELECT 1 FROM payloads WHERE payload_id = ?;',
                    (self.id,)
                )
                if cursor.fetchall() == []:
                    # payload not found
                    raise st.MissingItemError()

                # payload found but it had no script files
                return ()

        paths: list[t.Sequence[str]] = []
        for resource_identifier, file_name in rows:
            if resource_identifier is None:
                # payload found but it had no script files
                return ()

            paths.append((resource_identifier, *file_name.split('/')))

        return paths

    def get_file_data(self, path: t.Sequence[str]) \
        -> t.Optional[st.FileData]:
        """
        Return the contents and MIME type of one of this payload's script
        files.

        `path` has the format produced by get_script_paths(): a resource
        identifier followed by the segments of the file name.

        Raises st.MissingItemError if `path` is empty. Returns None if no
        matching script file is found.
        """
        if len(path) == 0:
            raise st.MissingItemError()

        resource_identifier, *file_name_segments = path

        file_name = '/'.join(file_name_segments)

        with self.state.cursor() as cursor:
            cursor.execute(
                '''
                SELECT
                        f.data, fu.mime_type
                FROM
                        payloads AS p
                        JOIN resolved_depended_resources AS rdd
                            USING (payload_id)
                        JOIN item_versions AS iv
                            ON rdd.resource_item_id = iv.item_version_id
                        JOIN items AS i
                            USING (item_id)
                        JOIN file_uses AS fu
                            USING (item_version_id)
                        JOIN files AS f
                            USING (file_id)
                WHERE
                        p.payload_id = ?       AND
                        i.identifier = ?       AND
                        fu.name = ?            AND
                        fu.type = 'W';
                ''',
                (self.id, resource_identifier, file_name)
            )

            result = cursor.fetchall()

        if result == []:
            return None

        # Exactly one matching row is expected; more would raise ValueError.
        (data, mime_type), = result

        return st.FileData(type=mime_type, name=file_name, contents=data)


def register_payload(
        policy_tree: PolicyTree,
        pattern:     url_patterns.ParsedPattern,
        payload_key: st.PayloadKey,
        token:       str
) -> PolicyTree:
    """
    Register the policy factories of a single payload under `pattern`.

    A PayloadPolicyFactory is registered for `pattern` itself and a
    PayloadResourcePolicyFactory for `pattern` with the path segments
    `token` and '***' appended (as done by `pattern.path_append()`).

    Returns the tree obtained from the `register()` calls; the tree is used
    in a functional style (each `register()` returns a new tree).
    """
    payload_policy_factory = policies.PayloadPolicyFactory(
        builtin     = False,
        payload_key = payload_key
    )

    policy_tree = policy_tree.register(pattern, payload_policy_factory)

    resource_policy_factory = policies.PayloadResourcePolicyFactory(
        builtin     = False,
        payload_key = payload_key
    )

    policy_tree = policy_tree.register(
        pattern.path_append(token, '***'),
        resource_policy_factory
    )

    return policy_tree


# mypy needs to be corrected:
# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
@dc.dataclass # type: ignore[misc]
class HaketiloStateWithFields(st.HaketiloState):
    """
    Partial HaketiloState implementation that holds the state's fields:
    database connection, the cursor/locking machinery, and the in-memory
    policy tree and payload data rebuilt from the database.
    """
    # Not read by the code in this module.
    store_dir:      Path
    # SQLite connection that all cursors are obtained from.
    connection:     sqlite3.Connection
    # Cursor handed out by the currently active cursor() context, else None.
    current_cursor: t.Optional[sqlite3.Cursor] = None
    #settings: st.HaketiloGlobalSettings

    # Replaced wholesale by rebuild_structures().
    policy_tree:   PolicyTree   = PolicyTree()
    # Replaced wholesale by rebuild_structures().
    payloads_data: PayloadsData = dc.field(default_factory=dict)

    # Reentrant, so a thread already inside cursor() may enter it again.
    lock: threading.RLock = dc.field(default_factory=threading.RLock)

    @contextmanager
    def cursor(self, transaction: bool = False) \
        -> t.Iterator[sqlite3.Cursor]:
        """
        Yield a database cursor while holding `self.lock`.

        A nested call made from within an active cursor() context (only
        possible from the same thread, as the lock is held) gets the same
        cursor back and its `transaction` argument is ignored.

        Otherwise a new cursor is created. If `transaction` is true and the
        connection is not already in a transaction, an explicit transaction
        is begun, committed when the `with` body finishes normally and rolled
        back if it raises.
        """
        with self.lock:
            if self.current_cursor is not None:
                yield self.current_cursor
                return

            # Evaluated only while holding the lock - checking beforehand
            # would race with another thread finishing its transaction.
            start_transaction = \
                transaction and not self.connection.in_transaction

            try:
                self.current_cursor = self.connection.cursor()

                if start_transaction:
                    self.current_cursor.execute('BEGIN TRANSACTION;')

                try:
                    yield self.current_cursor

                    if start_transaction:
                        assert self.connection.in_transaction
                        self.current_cursor.execute('COMMIT TRANSACTION;')
                except BaseException:
                    if start_transaction:
                        self.current_cursor.execute('ROLLBACK TRANSACTION;')
                    raise
            finally:
                self.current_cursor = None

    def rebuild_structures(self) -> None:
        """
        Recreation of data structures as done after every recomputation of
        dependencies as well as at startup.

        Builds a new policy tree (web UI policy plus the policies of every
        payload in the database) and new payload data, then assigns both to
        `policy_tree` and `payloads_data`. A payload keeps the unique token it
        had in the previous `payloads_data`; new payloads get a fresh random
        token.
        """
        with self.cursor(transaction=True) as cursor:
            cursor.execute(
                '''
                SELECT
                        p.payload_id, p.pattern, p.eval_allowed,
                        p.cors_bypass_allowed,
                        ms.enabled,
                        i.identifier
                FROM
                        payloads AS p
                        JOIN item_versions AS iv
                            ON p.mapping_item_id = iv.item_version_id
                        JOIN items AS i USING (item_id)
                        JOIN mapping_statuses AS ms USING (item_id);
                '''
            )

            rows = cursor.fetchall()

        new_policy_tree = PolicyTree()

        ui_factory = policies.WebUIPolicyFactory(builtin=True)
        web_ui_pattern = 'http*://hkt.mitm.it/***'
        for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern):
            new_policy_tree = new_policy_tree.register(
                parsed_pattern,
                ui_factory
            )

        new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}

        for row in rows:
            (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
             enabled_status,
             identifier) = row

            payload_ref = ConcretePayloadRef(str(payload_id_int), self)

            # Reuse the token across rebuilds so the resource URLs
            # registered for this payload stay the same.
            previous_data = self.payloads_data.get(payload_ref)
            if previous_data is not None:
                token = previous_data.unique_token
            else:
                token = secrets.token_urlsafe(8)

            payload_key = st.PayloadKey(payload_ref, identifier)

            parsed_patterns = tuple(url_patterns.parse_pattern(pattern))

            for parsed_pattern in parsed_patterns:
                new_policy_tree = register_payload(
                    new_policy_tree,
                    parsed_pattern,
                    payload_key,
                    token
                )

            # Use this payload's own last parsed pattern explicitly rather
            # than a leaked loop variable (which could otherwise be left over
            # from the web UI loop or a previous row). NOTE(review): assumes
            # parse_pattern() yields at least one pattern and that all of
            # them share path segments - confirm.
            pattern_path_segments = parsed_patterns[-1].path_segments

            payload_data = st.PayloadData(
                payload_ref           = payload_ref,
                explicitly_enabled    = enabled_status == 'E',
                unique_token          = token,
                pattern_path_segments = pattern_path_segments,
                eval_allowed          = eval_allowed,
                cors_bypass_allowed   = cors_bypass_allowed
            )

            new_payloads_data[payload_ref] = payload_data

        self.policy_tree   = new_policy_tree
        self.payloads_data = new_payloads_data

    @abstractmethod
    def recompute_dependencies(
            self,
            requirements: t.Iterable[sds.MappingRequirement] = ()
    ) -> None:
        """
        Recompute dependencies, taking `requirements` into account.

        Must be implemented by concrete subclasses.
        """
        ...