From 8238435825d01ad2ec1a11b6bcaf6d9a9aad5ab5 Mon Sep 17 00:00:00 2001 From: Wojtek Kosior Date: Mon, 22 Aug 2022 12:52:59 +0200 Subject: allow pulling packages from remote repository --- src/hydrilla/proxy/state_impl/base.py | 219 ++++++++++++++++++++++++++++++++-- 1 file changed, 212 insertions(+), 7 deletions(-) (limited to 'src/hydrilla/proxy/state_impl/base.py') diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py index 1ae08cb..92833dd 100644 --- a/src/hydrilla/proxy/state_impl/base.py +++ b/src/hydrilla/proxy/state_impl/base.py @@ -34,34 +34,160 @@ subtype. from __future__ import annotations import sqlite3 +import secrets import threading import dataclasses as dc import typing as t from pathlib import Path from contextlib import contextmanager - -import sqlite3 +from abc import abstractmethod from immutables import Map +from ... import url_patterns from ... import pattern_tree -from .. import state +from .. import simple_dependency_satisfying as sds +from .. import state as st from .. import policies PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory] -PayloadsData = t.Mapping[state.PayloadRef, state.PayloadData] +PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData] + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcretePayloadRef(st.PayloadRef): + state: 'HaketiloStateWithFields' = dc.field(hash=False, compare=False) + + def get_data(self) -> st.PayloadData: + try: + return self.state.payloads_data[self] + except KeyError: + raise st.MissingItemError() + + def get_mapping(self) -> st.MappingVersionRef: + raise NotImplementedError() + + def get_script_paths(self) \ + -> t.Iterable[t.Sequence[str]]: + with self.state.cursor() as cursor: + cursor.execute( + ''' + SELECT + i.identifier, fu.name + FROM + payloads AS p + LEFT JOIN resolved_depended_resources AS rdd + USING (payload_id) + LEFT JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + LEFT JOIN items AS i + USING (item_id) + LEFT JOIN file_uses AS fu + USING (item_version_id) + WHERE + fu.type = 'W' AND + p.payload_id = ? AND + (fu.idx IS NOT NULL OR rdd.idx IS NULL) + ORDER BY + rdd.idx, fu.idx; + ''', + (self.id,) + ) + + paths: list[t.Sequence[str]] = [] + for resource_identifier, file_name in cursor.fetchall(): + if resource_identifier is None: + # payload found but it had no script files + return () + + paths.append((resource_identifier, *file_name.split('/'))) + + if paths == []: + # payload not found + raise st.MissingItemError() + + return paths + + def get_file_data(self, path: t.Sequence[str]) \ + -> t.Optional[st.FileData]: + if len(path) == 0: + raise st.MissingItemError() + + resource_identifier, *file_name_segments = path + + file_name = '/'.join(file_name_segments) + + with self.state.cursor() as cursor: + cursor.execute( + ''' + SELECT + f.data, fu.mime_type + FROM + payloads AS p + JOIN resolved_depended_resources AS rdd + USING (payload_id) + JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + JOIN items AS i + USING (item_id) + JOIN file_uses AS fu + USING (item_version_id) + JOIN files AS f + USING (file_id) + WHERE + p.payload_id = ? AND + i.identifier = ? AND + fu.name = ? AND + fu.type = 'W'; + ''', + (self.id, resource_identifier, file_name) + ) + + result = cursor.fetchall() + + if result == []: + return None + + (data, mime_type), = result + + return st.FileData(type=mime_type, name=file_name, contents=data) + +def register_payload( + policy_tree: PolicyTree, + pattern: url_patterns.ParsedPattern, + payload_key: st.PayloadKey, + token: str +) -> PolicyTree: + """....""" + payload_policy_factory = policies.PayloadPolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register(pattern, payload_policy_factory) + + resource_policy_factory = policies.PayloadResourcePolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register( + pattern.path_append(token, '***'), + resource_policy_factory + ) + + return policy_tree # mypy needs to be corrected: # https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 @dc.dataclass # type: ignore[misc] -class HaketiloStateWithFields(state.HaketiloState): +class HaketiloStateWithFields(st.HaketiloState): """....""" store_dir: Path connection: sqlite3.Connection current_cursor: t.Optional[sqlite3.Cursor] = None - #settings: state.HaketiloGlobalSettings + #settings: st.HaketiloGlobalSettings policy_tree: PolicyTree = PolicyTree() payloads_data: PayloadsData = dc.field(default_factory=dict) @@ -98,6 +224,85 @@ class HaketiloStateWithFields(state.HaketiloState): finally: self.current_cursor = None - def recompute_payloads(self, cursor: sqlite3.Cursor) -> None: + def rebuild_structures(self) -> None: + """ + Recreation of data structures as done after every recomputation of + dependencies as well as at startup. + """ + with self.cursor(transaction=True) as cursor: + cursor.execute( + ''' + SELECT + p.payload_id, p.pattern, p.eval_allowed, + p.cors_bypass_allowed, + ms.enabled, + i.identifier + FROM + payloads AS p + JOIN item_versions AS iv + ON p.mapping_item_id = iv.item_version_id + JOIN items AS i USING (item_id) + JOIN mapping_statuses AS ms USING (item_id); + ''' + ) + + rows = cursor.fetchall() + + new_policy_tree = PolicyTree() + + ui_factory = policies.WebUIPolicyFactory(builtin=True) + web_ui_pattern = 'http*://hkt.mitm.it/***' + for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern): + new_policy_tree = new_policy_tree.register( + parsed_pattern, + ui_factory + ) + + new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} + + for row in rows: + (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, + enabled_status, + identifier) = row + + payload_ref = ConcretePayloadRef(str(payload_id_int), self) + + previous_data = self.payloads_data.get(payload_ref) + if previous_data is not None: + token = previous_data.unique_token + else: + token = secrets.token_urlsafe(8) + + payload_key = st.PayloadKey(payload_ref, identifier) + + for parsed_pattern in url_patterns.parse_pattern(pattern): + new_policy_tree = register_payload( + new_policy_tree, + parsed_pattern, + payload_key, + token + ) + + pattern_path_segments = parsed_pattern.path_segments + + payload_data = st.PayloadData( + payload_ref = payload_ref, + explicitly_enabled = enabled_status == 'E', + unique_token = token, + pattern_path_segments = pattern_path_segments, + eval_allowed = eval_allowed, + cors_bypass_allowed = cors_bypass_allowed + ) + + new_payloads_data[payload_ref] = payload_data + + self.policy_tree = new_policy_tree + self.payloads_data = new_payloads_data + + @abstractmethod + def recompute_dependencies( + self, + requirements: t.Iterable[sds.MappingRequirement] = [] + ) -> None: """....""" ... -- cgit v1.2.3