diff options
Diffstat (limited to 'src/hydrilla/proxy/state_impl/concrete_state.py')
-rw-r--r-- | src/hydrilla/proxy/state_impl/concrete_state.py | 130 |
1 files changed, 94 insertions, 36 deletions
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py index 9e56bff..0de67e0 100644 --- a/src/hydrilla/proxy/state_impl/concrete_state.py +++ b/src/hydrilla/proxy/state_impl/concrete_state.py @@ -33,6 +33,7 @@ and resources. from __future__ import annotations import sqlite3 +import secrets import typing as t import dataclasses as dc @@ -48,6 +49,7 @@ from .. import simple_dependency_satisfying as sds from . import base from . import mappings from . import repos +from . import payloads from . import _operations @@ -120,23 +122,35 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): finally: cursor.close() - def import_packages(self, malcontent_path: Path) -> None: - with self.cursor(transaction=True) as cursor: + def import_items(self, malcontent_path: Path, repo_id: int = 1) -> None: + with self.cursor(transaction=(repo_id == 1)) as cursor: + # This method without the repo_id argument exposed is part of the + # state API. As such, calls with repo_id = 1 (imports of local + # semirepo packages) create a new transaction. Calls with different + # values of repo_id are assumed to originate from within the state + # implementation code and expect an existing transaction. Here, we + # verify the transaction is indeed present. + assert self.connection.in_transaction + _operations._load_packages_no_state_update( cursor = cursor, malcontent_path = malcontent_path, - repo_id = 1 + repo_id = repo_id ) self.rebuild_structures() def recompute_dependencies( self, - extra_requirements: t.Iterable[sds.MappingRequirement] = [] + extra_requirements: t.Iterable[sds.MappingRequirement] = [], + prune_orphans: bool = False, ) -> None: with self.cursor() as cursor: assert self.connection.in_transaction + if prune_orphans: + _operations.prune_packages(cursor) + _operations._recompute_dependencies_no_state_update( cursor = cursor, extra_requirements = extra_requirements @@ -144,6 +158,82 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): self.rebuild_structures() + def pull_missing_files(self) -> None: + with self.cursor() as cursor: + assert self.connection.in_transaction + + _operations.pull_missing_files(cursor) + + def rebuild_structures(self) -> None: + with self.cursor(transaction=True) as cursor: + cursor.execute( + ''' + SELECT + p.payload_id, p.pattern, p.eval_allowed, + p.cors_bypass_allowed, + ms.enabled, + i.identifier + FROM + payloads AS p + JOIN item_versions AS iv + ON p.mapping_item_id = iv.item_version_id + JOIN items AS i USING (item_id) + JOIN mapping_statuses AS ms USING (item_id); + ''' + ) + + rows = cursor.fetchall() + + new_policy_tree = base.PolicyTree() + + ui_factory = policies.WebUIPolicyFactory(builtin=True) + web_ui_pattern = 'http*://hkt.mitm.it/***' + for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern): + new_policy_tree = new_policy_tree.register( + parsed_pattern, + ui_factory + ) + + new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} + + for row in rows: + (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, + enabled_status, + identifier) = row + + payload_ref = payloads.ConcretePayloadRef(str(payload_id_int), self) + + previous_data = self.payloads_data.get(payload_ref) + if previous_data is not None: + token = previous_data.unique_token + else: + token = secrets.token_urlsafe(8) + + payload_key = st.PayloadKey(payload_ref, identifier) + + for parsed_pattern in url_patterns.parse_pattern(pattern): + new_policy_tree = new_policy_tree.register_payload( + parsed_pattern, + payload_key, + token + ) + + pattern_path_segments = parsed_pattern.path_segments + + payload_data = st.PayloadData( + payload_ref = payload_ref, + explicitly_enabled = enabled_status == 'E', + unique_token = token, + pattern_path_segments = pattern_path_segments, + eval_allowed = eval_allowed, + cors_bypass_allowed = cors_bypass_allowed + ) + + new_payloads_data[payload_ref] = payload_data + + self.policy_tree = new_policy_tree + self.payloads_data = new_payloads_data + def repo_store(self) -> st.RepoStore: return repos.ConcreteRepoStore(self) @@ -182,38 +272,6 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): ) -> None: raise NotImplementedError() - def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy: - """....""" - with self.lock: - policy_tree = self.policy_tree - - try: - best_priority: int = 0 - best_policy: t.Optional[policies.Policy] = None - - for factories_set in policy_tree.search(url): - for stored_factory in sorted(factories_set): - factory = stored_factory.item - - policy = factory.make_policy(self) - - if policy.priority > best_priority: - best_priority = policy.priority - best_policy = policy - except Exception as e: - return policies.ErrorBlockPolicy( - builtin = True, - error = e - ) - - if best_policy is not None: - return best_policy - - if self.get_settings().default_allow_scripts: - return policies.FallbackAllowPolicy() - else: - return policies.FallbackBlockPolicy() - @staticmethod def make(store_dir: Path) -> 'ConcreteHaketiloState': connection = sqlite3.connect( |