author    | Wojtek Kosior <koszko@koszko.org> | 2022-08-22 12:52:59 +0200
committer | Wojtek Kosior <koszko@koszko.org> | 2022-09-09 13:50:40 +0200
commit    | 33606b647eec91128097ea4b5fb59834d1dd9943 (patch)
tree      | b48fa711001b884f5ccce51c7767a1806e090748 /src/hydrilla/proxy/state_impl/concrete_state.py
parent    | 12ff72506af9b3c8cb1ce604d86232600a26e2c2 (diff)
download  | haketilo-hydrilla-33606b647eec91128097ea4b5fb59834d1dd9943.tar.gz
          | haketilo-hydrilla-33606b647eec91128097ea4b5fb59834d1dd9943.zip
allow pulling packages from remote repository
Diffstat (limited to 'src/hydrilla/proxy/state_impl/concrete_state.py')
-rw-r--r-- | src/hydrilla/proxy/state_impl/concrete_state.py | 322
1 file changed, 20 insertions(+), 302 deletions(-)
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py
index 525a702..6a59a75 100644
--- a/src/hydrilla/proxy/state_impl/concrete_state.py
+++ b/src/hydrilla/proxy/state_impl/concrete_state.py
@@ -32,28 +32,23 @@ and resources.
 # Enable using with Python 3.7.
 from __future__ import annotations
 
-import secrets
-import io
-import hashlib
+import sqlite3
 import typing as t
 import dataclasses as dc
 
 from pathlib import Path
 
-import sqlite3
-
 from ...exceptions import HaketiloException
 from ...translations import smart_gettext as _
-from ... import pattern_tree
 from ... import url_patterns
 from ... import item_infos
-from ..simple_dependency_satisfying import compute_payloads, ComputedPayload
 from .. import state as st
 from .. import policies
+from .. import simple_dependency_satisfying as sds
 from . import base
 from . import mappings
 from . import repos
-from .load_packages import load_packages
+from . import _operations
 
 
 here = Path(__file__).resolve().parent
@@ -73,169 +68,14 @@ class ConcreteResourceRef(st.ResourceRef):
 class ConcreteResourceVersionRef(st.ResourceVersionRef):
     pass
 
-
-@dc.dataclass(frozen=True, unsafe_hash=True)
-class ConcretePayloadRef(st.PayloadRef):
-    state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False)
-
-    def get_data(self) -> st.PayloadData:
-        try:
-            return self.state.payloads_data[self]
-        except KeyError:
-            raise st.MissingItemError()
-
-    def get_mapping(self) -> st.MappingVersionRef:
-        raise NotImplementedError()
-
-    def get_script_paths(self) \
-            -> t.Iterable[t.Sequence[str]]:
-        with self.state.cursor() as cursor:
-            cursor.execute(
-                '''
-                SELECT
-                    i.identifier, fu.name
-                FROM
-                    payloads AS p
-                    LEFT JOIN resolved_depended_resources AS rdd
-                        USING (payload_id)
-                    LEFT JOIN item_versions AS iv
-                        ON rdd.resource_item_id = iv.item_version_id
-                    LEFT JOIN items AS i
-                        USING (item_id)
-                    LEFT JOIN file_uses AS fu
-                        USING (item_version_id)
-                WHERE
-                    fu.type = 'W' AND
-                    p.payload_id = ? AND
-                    (fu.idx IS NOT NULL OR rdd.idx IS NULL)
-                ORDER BY
-                    rdd.idx, fu.idx;
-                ''',
-                (self.id,)
-            )
-
-            paths: list[t.Sequence[str]] = []
-            for resource_identifier, file_name in cursor.fetchall():
-                if resource_identifier is None:
-                    # payload found but it had no script files
-                    return ()
-
-                paths.append((resource_identifier, *file_name.split('/')))
-
-            if paths == []:
-                # payload not found
-                raise st.MissingItemError()
-
-            return paths
-
-    def get_file_data(self, path: t.Sequence[str]) \
-            -> t.Optional[st.FileData]:
-        if len(path) == 0:
-            raise st.MissingItemError()
-
-        resource_identifier, *file_name_segments = path
-
-        file_name = '/'.join(file_name_segments)
-
-        with self.state.cursor() as cursor:
-            cursor.execute(
-                '''
-                SELECT
-                    f.data, fu.mime_type
-                FROM
-                    payloads AS p
-                    JOIN resolved_depended_resources AS rdd
-                        USING (payload_id)
-                    JOIN item_versions AS iv
-                        ON rdd.resource_item_id = iv.item_version_id
-                    JOIN items AS i
-                        USING (item_id)
-                    JOIN file_uses AS fu
-                        USING (item_version_id)
-                    JOIN files AS f
-                        USING (file_id)
-                WHERE
-                    p.payload_id = ? AND
-                    i.identifier = ? AND
-                    fu.name = ? AND
-                    fu.type = 'W';
-                ''',
-                (self.id, resource_identifier, file_name)
-            )
-
-            result = cursor.fetchall()
-
-            if result == []:
-                return None
-
-            (data, mime_type), = result
-
-            return st.FileData(type=mime_type, name=file_name, contents=data)
-
-def register_payload(
-        policy_tree: base.PolicyTree,
-        pattern: url_patterns.ParsedPattern,
-        payload_key: st.PayloadKey,
-        token: str
-) -> base.PolicyTree:
-    """...."""
-    payload_policy_factory = policies.PayloadPolicyFactory(
-        builtin = False,
-        payload_key = payload_key
-    )
-
-    policy_tree = policy_tree.register(pattern, payload_policy_factory)
-
-    resource_policy_factory = policies.PayloadResourcePolicyFactory(
-        builtin = False,
-        payload_key = payload_key
-    )
-
-    policy_tree = policy_tree.register(
-        pattern.path_append(token, '***'),
-        resource_policy_factory
-    )
-
-    return policy_tree
-
-AnyInfoVar = t.TypeVar(
-    'AnyInfoVar',
-    item_infos.ResourceInfo,
-    item_infos.MappingInfo
-)
-
-def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \
-        -> t.Mapping[AnyInfoVar, int]:
-    cursor.execute(
-        '''
-        SELECT
-            i.item_id, iv.definition, r.name, ri.iteration
-        FROM
-            item_versions AS iv
-            JOIN items AS i USING (item_id)
-            JOIN repo_iterations AS ri USING (repo_iteration_id)
-            JOIN repos AS r USING (repo_id)
-        WHERE
-            i.type = ?;
-        ''',
-        (info_type.type_name[0].upper(),)
-    )
-
-    result: dict[AnyInfoVar, int] = {}
-
-    for item_id, definition, repo_name, repo_iteration in cursor.fetchall():
-        definition_io = io.StringIO(definition)
-        info = info_type.load(definition_io, repo_name, repo_iteration)
-        result[info] = item_id
-
-    return result
-
 
 @dc.dataclass
 class ConcreteHaketiloState(base.HaketiloStateWithFields):
     def __post_init__(self) -> None:
+        sqlite3.enable_callback_tracebacks(True)
+
         self._prepare_database()
-        self._rebuild_structures()
+        self.rebuild_structures()
 
     def _prepare_database(self) -> None:
         """...."""
@@ -282,147 +122,25 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
 
     def import_packages(self, malcontent_path: Path) -> None:
         with self.cursor(transaction=True) as cursor:
-            load_packages(self, cursor, malcontent_path)
-            self.recompute_payloads(cursor)
-
-    def recompute_payloads(self, cursor: sqlite3.Cursor) -> None:
-        assert self.connection.in_transaction
-
-        resources = get_infos_of_type(cursor, item_infos.ResourceInfo)
-        mappings = get_infos_of_type(cursor, item_infos.MappingInfo)
-
-        payloads = compute_payloads(resources.keys(), mappings.keys())
-
-        payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
-
-        cursor.execute('DELETE FROM payloads;')
-
-        for mapping_info, by_pattern in payloads.items():
-            for num, (pattern, payload) in enumerate(by_pattern.items()):
-                cursor.execute(
-                    '''
-                    INSERT INTO payloads(
-                        mapping_item_id,
-                        pattern,
-                        eval_allowed,
-                        cors_bypass_allowed
-                    )
-                    VALUES (?, ?, ?, ?);
-                    ''',
-                    (
-                        mappings[mapping_info],
-                        pattern.orig_url,
-                        payload.allows_eval,
-                        payload.allows_cors_bypass
-                    )
-                )
+            _operations.load_packages(cursor, malcontent_path, 1)
+            raise NotImplementedError()
+            _operations.prune_packages(cursor)
 
-                cursor.execute(
-                    '''
-                    SELECT
-                        payload_id
-                    FROM
-                        payloads
-                    WHERE
-                        mapping_item_id = ? AND pattern = ?;
-                    ''',
-                    (mappings[mapping_info], pattern.orig_url)
-                )
+            self.recompute_dependencies()
 
-                (payload_id_int,), = cursor.fetchall()
-
-                for res_num, resource_info in enumerate(payload.resources):
-                    cursor.execute(
-                        '''
-                        INSERT INTO resolved_depended_resources(
-                            payload_id,
-                            resource_item_id,
-                            idx
-                        )
-                        VALUES(?, ?, ?);
-                        ''',
-                        (payload_id_int, resources[resource_info], res_num)
-                    )
-
-        self._rebuild_structures(cursor)
-
-    def _rebuild_structures(self, cursor: t.Optional[sqlite3.Cursor] = None) \
-            -> None:
-        """
-        Recreation of data structures as done after every recomputation of
-        dependencies as well as at startup.
-        """
-        if cursor is None:
-            with self.cursor() as new_cursor:
-                return self._rebuild_structures(new_cursor)
-
-        cursor.execute(
-            '''
-            SELECT
-                p.payload_id, p.pattern, p.eval_allowed,
-                p.cors_bypass_allowed,
-                ms.enabled,
-                i.identifier
-            FROM
-                payloads AS p
-                JOIN item_versions AS iv
-                    ON p.mapping_item_id = iv.item_version_id
-                JOIN items AS i USING (item_id)
-                JOIN mapping_statuses AS ms USING (item_id);
-            '''
-        )
-
-        new_policy_tree = base.PolicyTree()
-
-        ui_factory = policies.WebUIPolicyFactory(builtin=True)
-        web_ui_pattern = 'http*://hkt.mitm.it/***'
-        for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern):
-            new_policy_tree = new_policy_tree.register(
-                parsed_pattern,
-                ui_factory
-            )
-
-        new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
-
-        for row in cursor.fetchall():
-            (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
-             enabled_status,
-             identifier) = row
-
-            payload_ref = ConcretePayloadRef(str(payload_id_int), self)
-
-            previous_data = self.payloads_data.get(payload_ref)
-            if previous_data is not None:
-                token = previous_data.unique_token
-            else:
-                token = secrets.token_urlsafe(8)
-
-            payload_key = st.PayloadKey(payload_ref, identifier)
-
-            for parsed_pattern in url_patterns.parse_pattern(pattern):
-                new_policy_tree = register_payload(
-                    new_policy_tree,
-                    parsed_pattern,
-                    payload_key,
-                    token
-                )
-
-            pattern_path_segments = parsed_pattern.path_segments
+    def recompute_dependencies(
+            self,
+            extra_requirements: t.Iterable[sds.MappingRequirement] = []
+    ) -> None:
+        with self.cursor() as cursor:
+            assert self.connection.in_transaction
 
-            payload_data = st.PayloadData(
-                payload_ref = payload_ref,
-                #explicitly_enabled = enabled_status == 'E',
-                explicitly_enabled = True,
-                unique_token = token,
-                pattern_path_segments = pattern_path_segments,
-                eval_allowed = eval_allowed,
-                cors_bypass_allowed = cors_bypass_allowed
+            _operations._recompute_dependencies_no_state_update(
+                cursor,
+                extra_requirements
             )
 
-            new_payloads_data[payload_ref] = payload_data
-
-        self.policy_tree = new_policy_tree
-        self.payloads_data = new_payloads_data
+        self.rebuild_structures()
 
     def repo_store(self) -> st.RepoStore:
         return repos.ConcreteRepoStore(self)
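Note on the change above: ConcreteHaketiloState stops computing and storing payloads itself and instead delegates package loading and dependency recomputation to the new _operations module. The following is a minimal usage sketch, not code from this commit: it assumes an already-constructed ConcreteHaketiloState instance (its constructor is outside this diff) and a hypothetical malcontent directory path, and it notes that in this commit import_packages() still ends in raise NotImplementedError(), so the flow is not yet complete here.

    from pathlib import Path

    # Sketch under the assumptions stated above; `state` and `malcontent_dir`
    # are placeholders, not names defined by this commit.
    def reload_from_malcontent(state, malcontent_dir: Path) -> None:
        # Loads package files into the proxy's SQLite-backed store via
        # _operations.load_packages() (still unfinished in this commit).
        state.import_packages(malcontent_dir)
        # Re-resolves dependencies through _operations and then rebuilds the
        # in-memory policy tree and payload data.
        state.recompute_dependencies()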