From 4dbbb2aec204a5cccc713e2e2098d6e0a47f8cf6 Mon Sep 17 00:00:00 2001 From: Wojtek Kosior Date: Thu, 25 Aug 2022 11:53:14 +0200 Subject: [proxy] refactor state implementation --- src/hydrilla/proxy/state_impl/base.py | 267 +++++++----------------- src/hydrilla/proxy/state_impl/concrete_state.py | 130 ++++++++---- src/hydrilla/proxy/state_impl/mappings.py | 3 +- src/hydrilla/proxy/state_impl/payloads.py | 137 ++++++++++++ src/hydrilla/proxy/state_impl/repos.py | 17 +- 5 files changed, 309 insertions(+), 245 deletions(-) create mode 100644 src/hydrilla/proxy/state_impl/payloads.py (limited to 'src/hydrilla/proxy/state_impl') diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py index 92833dd..25fd4c5 100644 --- a/src/hydrilla/proxy/state_impl/base.py +++ b/src/hydrilla/proxy/state_impl/base.py @@ -34,7 +34,6 @@ subtype. from __future__ import annotations import sqlite3 -import secrets import threading import dataclasses as dc import typing as t @@ -43,8 +42,6 @@ from pathlib import Path from contextlib import contextmanager from abc import abstractmethod -from immutables import Map - from ... import url_patterns from ... import pattern_tree from .. import simple_dependency_satisfying as sds @@ -52,132 +49,36 @@ from .. import state as st from .. import policies -PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory] -PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData] - -@dc.dataclass(frozen=True, unsafe_hash=True) -class ConcretePayloadRef(st.PayloadRef): - state: 'HaketiloStateWithFields' = dc.field(hash=False, compare=False) - - def get_data(self) -> st.PayloadData: - try: - return self.state.payloads_data[self] - except KeyError: - raise st.MissingItemError() - - def get_mapping(self) -> st.MappingVersionRef: - raise NotImplementedError() - - def get_script_paths(self) \ - -> t.Iterable[t.Sequence[str]]: - with self.state.cursor() as cursor: - cursor.execute( - ''' - SELECT - i.identifier, fu.name - FROM - payloads AS p - LEFT JOIN resolved_depended_resources AS rdd - USING (payload_id) - LEFT JOIN item_versions AS iv - ON rdd.resource_item_id = iv.item_version_id - LEFT JOIN items AS i - USING (item_id) - LEFT JOIN file_uses AS fu - USING (item_version_id) - WHERE - fu.type = 'W' AND - p.payload_id = ? AND - (fu.idx IS NOT NULL OR rdd.idx IS NULL) - ORDER BY - rdd.idx, fu.idx; - ''', - (self.id,) - ) - - paths: list[t.Sequence[str]] = [] - for resource_identifier, file_name in cursor.fetchall(): - if resource_identifier is None: - # payload found but it had no script files - return () - - paths.append((resource_identifier, *file_name.split('/'))) - - if paths == []: - # payload not found - raise st.MissingItemError() - - return paths - - def get_file_data(self, path: t.Sequence[str]) \ - -> t.Optional[st.FileData]: - if len(path) == 0: - raise st.MissingItemError() - - resource_identifier, *file_name_segments = path - - file_name = '/'.join(file_name_segments) - - with self.state.cursor() as cursor: - cursor.execute( - ''' - SELECT - f.data, fu.mime_type - FROM - payloads AS p - JOIN resolved_depended_resources AS rdd - USING (payload_id) - JOIN item_versions AS iv - ON rdd.resource_item_id = iv.item_version_id - JOIN items AS i - USING (item_id) - JOIN file_uses AS fu - USING (item_version_id) - JOIN files AS f - USING (file_id) - WHERE - p.payload_id = ? AND - i.identifier = ? AND - fu.name = ? 
AND - fu.type = 'W'; - ''', - (self.id, resource_identifier, file_name) - ) - - result = cursor.fetchall() - - if result == []: - return None - - (data, mime_type), = result +@dc.dataclass(frozen=True) +class PolicyTree(pattern_tree.PatternTree[policies.PolicyFactory]): + SelfType = t.TypeVar('SelfType', bound='PolicyTree') - return st.FileData(type=mime_type, name=file_name, contents=data) + def register_payload( + self: 'SelfType', + pattern: url_patterns.ParsedPattern, + payload_key: st.PayloadKey, + token: str + ) -> 'SelfType': + payload_policy_factory = policies.PayloadPolicyFactory( + builtin = False, + payload_key = payload_key + ) -def register_payload( - policy_tree: PolicyTree, - pattern: url_patterns.ParsedPattern, - payload_key: st.PayloadKey, - token: str -) -> PolicyTree: - """....""" - payload_policy_factory = policies.PayloadPolicyFactory( - builtin = False, - payload_key = payload_key - ) + policy_tree = self.register(pattern, payload_policy_factory) - policy_tree = policy_tree.register(pattern, payload_policy_factory) + resource_policy_factory = policies.PayloadResourcePolicyFactory( + builtin = False, + payload_key = payload_key + ) - resource_policy_factory = policies.PayloadResourcePolicyFactory( - builtin = False, - payload_key = payload_key - ) + policy_tree = policy_tree.register( + pattern.path_append(token, '***'), + resource_policy_factory + ) - policy_tree = policy_tree.register( - pattern.path_append(token, '***'), - resource_policy_factory - ) + return policy_tree - return policy_tree +PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData] # mypy needs to be corrected: # https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 @@ -187,6 +88,7 @@ class HaketiloStateWithFields(st.HaketiloState): store_dir: Path connection: sqlite3.Connection current_cursor: t.Optional[sqlite3.Cursor] = None + #settings: st.HaketiloGlobalSettings policy_tree: PolicyTree = PolicyTree() @@ -224,85 +126,64 @@ class HaketiloStateWithFields(st.HaketiloState): finally: self.current_cursor = None - def rebuild_structures(self) -> None: - """ - Recreation of data structures as done after every recomputation of - dependencies as well as at startup. 
- """ - with self.cursor(transaction=True) as cursor: - cursor.execute( - ''' - SELECT - p.payload_id, p.pattern, p.eval_allowed, - p.cors_bypass_allowed, - ms.enabled, - i.identifier - FROM - payloads AS p - JOIN item_versions AS iv - ON p.mapping_item_id = iv.item_version_id - JOIN items AS i USING (item_id) - JOIN mapping_statuses AS ms USING (item_id); - ''' - ) - - rows = cursor.fetchall() - - new_policy_tree = PolicyTree() + def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy: + """....""" + with self.lock: + policy_tree = self.policy_tree - ui_factory = policies.WebUIPolicyFactory(builtin=True) - web_ui_pattern = 'http*://hkt.mitm.it/***' - for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern): - new_policy_tree = new_policy_tree.register( - parsed_pattern, - ui_factory + try: + best_priority: int = 0 + best_policy: t.Optional[policies.Policy] = None + + for factories_set in policy_tree.search(url): + for stored_factory in sorted(factories_set): + factory = stored_factory.item + + policy = factory.make_policy(self) + + if policy.priority > best_priority: + best_priority = policy.priority + best_policy = policy + except Exception as e: + return policies.ErrorBlockPolicy( + builtin = True, + error = e ) - new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} - - for row in rows: - (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, - enabled_status, - identifier) = row - - payload_ref = ConcretePayloadRef(str(payload_id_int), self) - - previous_data = self.payloads_data.get(payload_ref) - if previous_data is not None: - token = previous_data.unique_token - else: - token = secrets.token_urlsafe(8) - - payload_key = st.PayloadKey(payload_ref, identifier) - - for parsed_pattern in url_patterns.parse_pattern(pattern): - new_policy_tree = register_payload( - new_policy_tree, - parsed_pattern, - payload_key, - token - ) + if best_policy is not None: + return best_policy - pattern_path_segments = parsed_pattern.path_segments + if self.get_settings().default_allow_scripts: + return policies.FallbackAllowPolicy() + else: + return policies.FallbackBlockPolicy() - payload_data = st.PayloadData( - payload_ref = payload_ref, - explicitly_enabled = enabled_status == 'E', - unique_token = token, - pattern_path_segments = pattern_path_segments, - eval_allowed = eval_allowed, - cors_bypass_allowed = cors_bypass_allowed - ) - - new_payloads_data[payload_ref] = payload_data - - self.policy_tree = new_policy_tree - self.payloads_data = new_payloads_data + @abstractmethod + def import_items(self, malcontent_path: Path, repo_id: int = 1) -> None: + ... @abstractmethod def recompute_dependencies( self, - requirements: t.Iterable[sds.MappingRequirement] = [] + requirements: t.Iterable[sds.MappingRequirement] = [], + prune_orphans: bool = False ) -> None: """....""" ... + + @abstractmethod + def pull_missing_files(self) -> None: + """ + This function checks which packages marked as installed are missing + files in the database. It attempts to restore integrity by downloading + the files from their respective repositories. + """ + ... + + @abstractmethod + def rebuild_structures(self) -> None: + """ + Recreation of data structures as done after every recomputation of + dependencies as well as at startup. + """ + ... 
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py index 9e56bff..0de67e0 100644 --- a/src/hydrilla/proxy/state_impl/concrete_state.py +++ b/src/hydrilla/proxy/state_impl/concrete_state.py @@ -33,6 +33,7 @@ and resources. from __future__ import annotations import sqlite3 +import secrets import typing as t import dataclasses as dc @@ -48,6 +49,7 @@ from .. import simple_dependency_satisfying as sds from . import base from . import mappings from . import repos +from . import payloads from . import _operations @@ -120,23 +122,35 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): finally: cursor.close() - def import_packages(self, malcontent_path: Path) -> None: - with self.cursor(transaction=True) as cursor: + def import_items(self, malcontent_path: Path, repo_id: int = 1) -> None: + with self.cursor(transaction=(repo_id == 1)) as cursor: + # This method without the repo_id argument exposed is part of the + # state API. As such, calls with repo_id = 1 (imports of local + # semirepo packages) create a new transaction. Calls with different + # values of repo_id are assumed to originate from within the state + # implementation code and expect an existing transaction. Here, we + # verify the transaction is indeed present. + assert self.connection.in_transaction + _operations._load_packages_no_state_update( cursor = cursor, malcontent_path = malcontent_path, - repo_id = 1 + repo_id = repo_id ) self.rebuild_structures() def recompute_dependencies( self, - extra_requirements: t.Iterable[sds.MappingRequirement] = [] + extra_requirements: t.Iterable[sds.MappingRequirement] = [], + prune_orphans: bool = False, ) -> None: with self.cursor() as cursor: assert self.connection.in_transaction + if prune_orphans: + _operations.prune_packages(cursor) + _operations._recompute_dependencies_no_state_update( cursor = cursor, extra_requirements = extra_requirements @@ -144,6 +158,82 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): self.rebuild_structures() + def pull_missing_files(self) -> None: + with self.cursor() as cursor: + assert self.connection.in_transaction + + _operations.pull_missing_files(cursor) + + def rebuild_structures(self) -> None: + with self.cursor(transaction=True) as cursor: + cursor.execute( + ''' + SELECT + p.payload_id, p.pattern, p.eval_allowed, + p.cors_bypass_allowed, + ms.enabled, + i.identifier + FROM + payloads AS p + JOIN item_versions AS iv + ON p.mapping_item_id = iv.item_version_id + JOIN items AS i USING (item_id) + JOIN mapping_statuses AS ms USING (item_id); + ''' + ) + + rows = cursor.fetchall() + + new_policy_tree = base.PolicyTree() + + ui_factory = policies.WebUIPolicyFactory(builtin=True) + web_ui_pattern = 'http*://hkt.mitm.it/***' + for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern): + new_policy_tree = new_policy_tree.register( + parsed_pattern, + ui_factory + ) + + new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} + + for row in rows: + (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, + enabled_status, + identifier) = row + + payload_ref = payloads.ConcretePayloadRef(str(payload_id_int), self) + + previous_data = self.payloads_data.get(payload_ref) + if previous_data is not None: + token = previous_data.unique_token + else: + token = secrets.token_urlsafe(8) + + payload_key = st.PayloadKey(payload_ref, identifier) + + for parsed_pattern in url_patterns.parse_pattern(pattern): + new_policy_tree = new_policy_tree.register_payload( + parsed_pattern, + 
payload_key, + token + ) + + pattern_path_segments = parsed_pattern.path_segments + + payload_data = st.PayloadData( + payload_ref = payload_ref, + explicitly_enabled = enabled_status == 'E', + unique_token = token, + pattern_path_segments = pattern_path_segments, + eval_allowed = eval_allowed, + cors_bypass_allowed = cors_bypass_allowed + ) + + new_payloads_data[payload_ref] = payload_data + + self.policy_tree = new_policy_tree + self.payloads_data = new_payloads_data + def repo_store(self) -> st.RepoStore: return repos.ConcreteRepoStore(self) @@ -182,38 +272,6 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): ) -> None: raise NotImplementedError() - def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy: - """....""" - with self.lock: - policy_tree = self.policy_tree - - try: - best_priority: int = 0 - best_policy: t.Optional[policies.Policy] = None - - for factories_set in policy_tree.search(url): - for stored_factory in sorted(factories_set): - factory = stored_factory.item - - policy = factory.make_policy(self) - - if policy.priority > best_priority: - best_priority = policy.priority - best_policy = policy - except Exception as e: - return policies.ErrorBlockPolicy( - builtin = True, - error = e - ) - - if best_policy is not None: - return best_policy - - if self.get_settings().default_allow_scripts: - return policies.FallbackAllowPolicy() - else: - return policies.FallbackBlockPolicy() - @staticmethod def make(store_dir: Path) -> 'ConcreteHaketiloState': connection = sqlite3.connect( diff --git a/src/hydrilla/proxy/state_impl/mappings.py b/src/hydrilla/proxy/state_impl/mappings.py index 7d08e58..8a401b8 100644 --- a/src/hydrilla/proxy/state_impl/mappings.py +++ b/src/hydrilla/proxy/state_impl/mappings.py @@ -38,7 +38,6 @@ import dataclasses as dc from ... import item_infos from .. import state as st from . import base -from . import _operations @dc.dataclass(frozen=True, unsafe_hash=True) @@ -227,7 +226,7 @@ class ConcreteMappingVersionRef(st.MappingVersionRef): self._set_installed_status(cursor, st.InstalledStatus.INSTALLED) - _operations.pull_missing_files(cursor) + self.state.pull_missing_files() def uninstall(self) -> None: raise NotImplementedError() diff --git a/src/hydrilla/proxy/state_impl/payloads.py b/src/hydrilla/proxy/state_impl/payloads.py new file mode 100644 index 0000000..2bee11f --- /dev/null +++ b/src/hydrilla/proxy/state_impl/payloads.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Haketilo proxy data and configuration (PayloadRef subtype). +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. 
+ +""" +This module provides an interface to interact with payloads inside Haketilo. +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import dataclasses as dc +import typing as t + +from .. import state as st +from . import base + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcretePayloadRef(st.PayloadRef): + state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False) + + def get_data(self) -> st.PayloadData: + try: + return self.state.payloads_data[self] + except KeyError: + raise st.MissingItemError() + + def get_mapping(self) -> st.MappingVersionRef: + raise NotImplementedError() + + def get_script_paths(self) \ + -> t.Iterable[t.Sequence[str]]: + with self.state.cursor() as cursor: + cursor.execute( + ''' + SELECT + i.identifier, fu.name + FROM + payloads AS p + LEFT JOIN resolved_depended_resources AS rdd + USING (payload_id) + LEFT JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + LEFT JOIN items AS i + USING (item_id) + LEFT JOIN file_uses AS fu + USING (item_version_id) + WHERE + fu.type = 'W' AND + p.payload_id = ? AND + (fu.idx IS NOT NULL OR rdd.idx IS NULL) + ORDER BY + rdd.idx, fu.idx; + ''', + (self.id,) + ) + + paths: list[t.Sequence[str]] = [] + for resource_identifier, file_name in cursor.fetchall(): + if resource_identifier is None: + # payload found but it had no script files + return () + + paths.append((resource_identifier, *file_name.split('/'))) + + if paths == []: + # payload not found + raise st.MissingItemError() + + return paths + + def get_file_data(self, path: t.Sequence[str]) \ + -> t.Optional[st.FileData]: + if len(path) == 0: + raise st.MissingItemError() + + resource_identifier, *file_name_segments = path + + file_name = '/'.join(file_name_segments) + + with self.state.cursor() as cursor: + cursor.execute( + ''' + SELECT + f.data, fu.mime_type + FROM + payloads AS p + JOIN resolved_depended_resources AS rdd + USING (payload_id) + JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + JOIN items AS i + USING (item_id) + JOIN file_uses AS fu + USING (item_version_id) + JOIN files AS f + USING (file_id) + WHERE + p.payload_id = ? AND + i.identifier = ? AND + fu.name = ? AND + fu.type = 'W'; + ''', + (self.id, resource_identifier, file_name) + ) + + result = cursor.fetchall() + + if result == []: + return None + + (data, mime_type), = result + + return st.FileData(type=mime_type, name=file_name, contents=data) diff --git a/src/hydrilla/proxy/state_impl/repos.py b/src/hydrilla/proxy/state_impl/repos.py index 346e113..838698c 100644 --- a/src/hydrilla/proxy/state_impl/repos.py +++ b/src/hydrilla/proxy/state_impl/repos.py @@ -51,7 +51,6 @@ from ... import versions from .. import state as st from .. import simple_dependency_satisfying as sds from . import base -from . import _operations repo_name_regex = re.compile(r''' @@ -194,8 +193,6 @@ class ConcreteRepoRef(st.RepoRef): (self.id,) ) - _operations.prune_packages(cursor) - # For mappings explicitly enabled by the user (+ all mappings they # recursively depend on) let's make sure that their exact same # versions will be enabled after the change. 
@@ -217,7 +214,7 @@ class ConcreteRepoRef(st.RepoRef): req = sds.MappingVersionRequirement(info.identifier, info) requirements.append(req) - self.state.recompute_dependencies(requirements) + self.state.recompute_dependencies(requirements, prune_orphans=True) def update( self, @@ -260,7 +257,7 @@ class ConcreteRepoRef(st.RepoRef): self.state.recompute_dependencies() - def refresh(self) -> st.RepoIterationRef: + def refresh(self) -> None: with self.state.cursor(transaction=True) as cursor: ensure_repo_not_deleted(cursor, self.id) @@ -274,15 +271,7 @@ class ConcreteRepoRef(st.RepoRef): with tempfile.TemporaryDirectory() as tmpdir_str: tmpdir = Path(tmpdir_str) sync_remote_repo_definitions(repo_url, tmpdir) - new_iteration_id = _operations._load_packages_no_state_update( - cursor = cursor, - malcontent_path = tmpdir, - repo_id = int(self.id) - ) - - self.state.rebuild_structures() - - return ConcreteRepoIterationRef(str(new_iteration_id), self.state) + self.state.import_items(tmpdir, int(self.id)) def get_display_info(self) -> st.RepoDisplayInfo: with self.state.cursor() as cursor: -- cgit v1.2.3
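Note on the repos.py hunks above (not part of the patch): refresh() now delegates to state.import_items(tmpdir, int(self.id)), relying on the transaction convention documented in concrete_state.py — only local semirepo imports (repo_id == 1) open a new transaction, while imports for other repo ids expect the caller to have one already active. The minimal sketch below illustrates that convention; TinyState and its dummy query are invented for illustration and are not Haketilo code.

    import sqlite3
    from contextlib import contextmanager

    class TinyState:
        def __init__(self) -> None:
            # isolation_level=None gives fully manual transaction control.
            self.connection = sqlite3.connect(':memory:', isolation_level=None)

        @contextmanager
        def cursor(self, transaction: bool = False):
            cur = self.connection.cursor()
            try:
                if transaction:
                    cur.execute('BEGIN TRANSACTION;')
                yield cur
                if transaction:
                    cur.execute('COMMIT;')
            except BaseException:
                if transaction and self.connection.in_transaction:
                    cur.execute('ROLLBACK;')
                raise
            finally:
                cur.close()

        def import_items(self, repo_id: int = 1) -> None:
            # repo_id == 1 (local semirepo) starts its own transaction;
            # any other repo_id must be called with one already open.
            with self.cursor(transaction=(repo_id == 1)) as cur:
                assert self.connection.in_transaction
                cur.execute('SELECT 1;')  # stand-in for the real package load

    state = TinyState()
    state.import_items()                     # opens its own transaction
    with state.cursor(transaction=True) as _:
        state.import_items(repo_id = 2)      # reuses the caller's transaction

The prune_orphans flag added to recompute_dependencies() follows the same pattern: pruning of orphaned packages runs inside the caller's already-open transaction rather than in one of its own.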