From 4dbbb2aec204a5cccc713e2e2098d6e0a47f8cf6 Mon Sep 17 00:00:00 2001 From: Wojtek Kosior Date: Thu, 25 Aug 2022 11:53:14 +0200 Subject: [proxy] refactor state implementation --- src/hydrilla/proxy/state_impl/base.py | 267 ++++++++++------------------------ 1 file changed, 74 insertions(+), 193 deletions(-) (limited to 'src/hydrilla/proxy/state_impl/base.py') diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py index 92833dd..25fd4c5 100644 --- a/src/hydrilla/proxy/state_impl/base.py +++ b/src/hydrilla/proxy/state_impl/base.py @@ -34,7 +34,6 @@ subtype. from __future__ import annotations import sqlite3 -import secrets import threading import dataclasses as dc import typing as t @@ -43,8 +42,6 @@ from pathlib import Path from contextlib import contextmanager from abc import abstractmethod -from immutables import Map - from ... import url_patterns from ... import pattern_tree from .. import simple_dependency_satisfying as sds @@ -52,132 +49,36 @@ from .. import state as st from .. import policies -PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory] -PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData] - -@dc.dataclass(frozen=True, unsafe_hash=True) -class ConcretePayloadRef(st.PayloadRef): - state: 'HaketiloStateWithFields' = dc.field(hash=False, compare=False) - - def get_data(self) -> st.PayloadData: - try: - return self.state.payloads_data[self] - except KeyError: - raise st.MissingItemError() - - def get_mapping(self) -> st.MappingVersionRef: - raise NotImplementedError() - - def get_script_paths(self) \ - -> t.Iterable[t.Sequence[str]]: - with self.state.cursor() as cursor: - cursor.execute( - ''' - SELECT - i.identifier, fu.name - FROM - payloads AS p - LEFT JOIN resolved_depended_resources AS rdd - USING (payload_id) - LEFT JOIN item_versions AS iv - ON rdd.resource_item_id = iv.item_version_id - LEFT JOIN items AS i - USING (item_id) - LEFT JOIN file_uses AS fu - USING (item_version_id) - WHERE - fu.type = 'W' AND - p.payload_id = ? AND - (fu.idx IS NOT NULL OR rdd.idx IS NULL) - ORDER BY - rdd.idx, fu.idx; - ''', - (self.id,) - ) - - paths: list[t.Sequence[str]] = [] - for resource_identifier, file_name in cursor.fetchall(): - if resource_identifier is None: - # payload found but it had no script files - return () - - paths.append((resource_identifier, *file_name.split('/'))) - - if paths == []: - # payload not found - raise st.MissingItemError() - - return paths - - def get_file_data(self, path: t.Sequence[str]) \ - -> t.Optional[st.FileData]: - if len(path) == 0: - raise st.MissingItemError() - - resource_identifier, *file_name_segments = path - - file_name = '/'.join(file_name_segments) - - with self.state.cursor() as cursor: - cursor.execute( - ''' - SELECT - f.data, fu.mime_type - FROM - payloads AS p - JOIN resolved_depended_resources AS rdd - USING (payload_id) - JOIN item_versions AS iv - ON rdd.resource_item_id = iv.item_version_id - JOIN items AS i - USING (item_id) - JOIN file_uses AS fu - USING (item_version_id) - JOIN files AS f - USING (file_id) - WHERE - p.payload_id = ? AND - i.identifier = ? AND - fu.name = ? AND - fu.type = 'W'; - ''', - (self.id, resource_identifier, file_name) - ) - - result = cursor.fetchall() - - if result == []: - return None - - (data, mime_type), = result +@dc.dataclass(frozen=True) +class PolicyTree(pattern_tree.PatternTree[policies.PolicyFactory]): + SelfType = t.TypeVar('SelfType', bound='PolicyTree') - return st.FileData(type=mime_type, name=file_name, contents=data) + def register_payload( + self: 'SelfType', + pattern: url_patterns.ParsedPattern, + payload_key: st.PayloadKey, + token: str + ) -> 'SelfType': + payload_policy_factory = policies.PayloadPolicyFactory( + builtin = False, + payload_key = payload_key + ) -def register_payload( - policy_tree: PolicyTree, - pattern: url_patterns.ParsedPattern, - payload_key: st.PayloadKey, - token: str -) -> PolicyTree: - """....""" - payload_policy_factory = policies.PayloadPolicyFactory( - builtin = False, - payload_key = payload_key - ) + policy_tree = self.register(pattern, payload_policy_factory) - policy_tree = policy_tree.register(pattern, payload_policy_factory) + resource_policy_factory = policies.PayloadResourcePolicyFactory( + builtin = False, + payload_key = payload_key + ) - resource_policy_factory = policies.PayloadResourcePolicyFactory( - builtin = False, - payload_key = payload_key - ) + policy_tree = policy_tree.register( + pattern.path_append(token, '***'), + resource_policy_factory + ) - policy_tree = policy_tree.register( - pattern.path_append(token, '***'), - resource_policy_factory - ) + return policy_tree - return policy_tree +PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData] # mypy needs to be corrected: # https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 @@ -187,6 +88,7 @@ class HaketiloStateWithFields(st.HaketiloState): store_dir: Path connection: sqlite3.Connection current_cursor: t.Optional[sqlite3.Cursor] = None + #settings: st.HaketiloGlobalSettings policy_tree: PolicyTree = PolicyTree() @@ -224,85 +126,64 @@ class HaketiloStateWithFields(st.HaketiloState): finally: self.current_cursor = None - def rebuild_structures(self) -> None: - """ - Recreation of data structures as done after every recomputation of - dependencies as well as at startup. - """ - with self.cursor(transaction=True) as cursor: - cursor.execute( - ''' - SELECT - p.payload_id, p.pattern, p.eval_allowed, - p.cors_bypass_allowed, - ms.enabled, - i.identifier - FROM - payloads AS p - JOIN item_versions AS iv - ON p.mapping_item_id = iv.item_version_id - JOIN items AS i USING (item_id) - JOIN mapping_statuses AS ms USING (item_id); - ''' - ) - - rows = cursor.fetchall() - - new_policy_tree = PolicyTree() + def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy: + """....""" + with self.lock: + policy_tree = self.policy_tree - ui_factory = policies.WebUIPolicyFactory(builtin=True) - web_ui_pattern = 'http*://hkt.mitm.it/***' - for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern): - new_policy_tree = new_policy_tree.register( - parsed_pattern, - ui_factory + try: + best_priority: int = 0 + best_policy: t.Optional[policies.Policy] = None + + for factories_set in policy_tree.search(url): + for stored_factory in sorted(factories_set): + factory = stored_factory.item + + policy = factory.make_policy(self) + + if policy.priority > best_priority: + best_priority = policy.priority + best_policy = policy + except Exception as e: + return policies.ErrorBlockPolicy( + builtin = True, + error = e ) - new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} - - for row in rows: - (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, - enabled_status, - identifier) = row - - payload_ref = ConcretePayloadRef(str(payload_id_int), self) - - previous_data = self.payloads_data.get(payload_ref) - if previous_data is not None: - token = previous_data.unique_token - else: - token = secrets.token_urlsafe(8) - - payload_key = st.PayloadKey(payload_ref, identifier) - - for parsed_pattern in url_patterns.parse_pattern(pattern): - new_policy_tree = register_payload( - new_policy_tree, - parsed_pattern, - payload_key, - token - ) + if best_policy is not None: + return best_policy - pattern_path_segments = parsed_pattern.path_segments + if self.get_settings().default_allow_scripts: + return policies.FallbackAllowPolicy() + else: + return policies.FallbackBlockPolicy() - payload_data = st.PayloadData( - payload_ref = payload_ref, - explicitly_enabled = enabled_status == 'E', - unique_token = token, - pattern_path_segments = pattern_path_segments, - eval_allowed = eval_allowed, - cors_bypass_allowed = cors_bypass_allowed - ) - - new_payloads_data[payload_ref] = payload_data - - self.policy_tree = new_policy_tree - self.payloads_data = new_payloads_data + @abstractmethod + def import_items(self, malcontent_path: Path, repo_id: int = 1) -> None: + ... @abstractmethod def recompute_dependencies( self, - requirements: t.Iterable[sds.MappingRequirement] = [] + requirements: t.Iterable[sds.MappingRequirement] = [], + prune_orphans: bool = False ) -> None: """....""" ... + + @abstractmethod + def pull_missing_files(self) -> None: + """ + This function checks which packages marked as installed are missing + files in the database. It attempts to restore integrity by downloading + the files from their respective repositories. + """ + ... + + @abstractmethod + def rebuild_structures(self) -> None: + """ + Recreation of data structures as done after every recomputation of + dependencies as well as at startup. + """ + ... -- cgit v1.2.3