diff options
author | Wojtek Kosior <koszko@koszko.org> | 2022-08-22 12:52:59 +0200 |
---|---|---|
committer | Wojtek Kosior <koszko@koszko.org> | 2022-09-28 12:54:51 +0200 |
commit | 8238435825d01ad2ec1a11b6bcaf6d9a9aad5ab5 (patch) | |
tree | 4c4956e45701460bedaa0d8b0be808052152777f /src/hydrilla/proxy/state_impl | |
parent | e1344ae7017b28a54d7714895bd54c8431a20bc6 (diff) | |
download | haketilo-hydrilla-8238435825d01ad2ec1a11b6bcaf6d9a9aad5ab5.tar.gz haketilo-hydrilla-8238435825d01ad2ec1a11b6bcaf6d9a9aad5ab5.zip |
allow pulling packages from remote repository
Diffstat (limited to 'src/hydrilla/proxy/state_impl')
-rw-r--r-- | src/hydrilla/proxy/state_impl/_operations/__init__.py | 9 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/_operations/load_packages.py (renamed from src/hydrilla/proxy/state_impl/load_packages.py) | 124 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/_operations/prune_packages.py (renamed from src/hydrilla/proxy/state_impl/prune_packages.py) | 65 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py | 223 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/base.py | 219 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/concrete_state.py | 322 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/mappings.py | 4 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/repos.py | 135 |
8 files changed, 711 insertions, 390 deletions
diff --git a/src/hydrilla/proxy/state_impl/_operations/__init__.py b/src/hydrilla/proxy/state_impl/_operations/__init__.py new file mode 100644 index 0000000..c147be4 --- /dev/null +++ b/src/hydrilla/proxy/state_impl/_operations/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: CC0-1.0 + +# Copyright (C) 2022 Wojtek Kosior <koszko@koszko.org> +# +# Available under the terms of Creative Commons Zero v1.0 Universal. + +from .load_packages import load_packages, FileResolver +from .prune_packages import prune_packages +from .recompute_dependencies import _recompute_dependencies_no_state_update diff --git a/src/hydrilla/proxy/state_impl/load_packages.py b/src/hydrilla/proxy/state_impl/_operations/load_packages.py index 6983c3e..c294ef0 100644 --- a/src/hydrilla/proxy/state_impl/load_packages.py +++ b/src/hydrilla/proxy/state_impl/_operations/load_packages.py @@ -37,37 +37,49 @@ import dataclasses as dc import typing as t from pathlib import Path +from abc import ABC, abstractmethod import sqlite3 -from ...exceptions import HaketiloException -from ...translations import smart_gettext as _ -from ... import versions -from ... import item_infos -from . import base +from ....exceptions import HaketiloException +from ....translations import smart_gettext as _ +from .... import versions +from .... import item_infos -def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int: +def make_repo_iteration(cursor: sqlite3.Cursor, repo_id: int) -> int: cursor.execute( ''' SELECT - repo_id, next_iteration - 1 + next_iteration FROM repos WHERE - name = ?; + repo_id = ?; ''', - (repo_name,) + (repo_id,) ) - (repo_id, last_iteration), = cursor.fetchall() + (next_iteration,), = cursor.fetchall() cursor.execute( ''' - INSERT OR IGNORE INTO repo_iterations(repo_id, iteration) + UPDATE + repos + SET + next_iteration = ? + WHERE + repo_id = ?; + ''', + (next_iteration + 1, repo_id) + ) + + cursor.execute( + ''' + INSERT INTO repo_iterations(repo_id, iteration) VALUES(?, ?); ''', - (repo_id, last_iteration) + (repo_id, next_iteration) ) cursor.execute( @@ -79,7 +91,7 @@ def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int: WHERE repo_id = ? AND iteration = ?; ''', - (repo_id, last_iteration) + (repo_id, next_iteration) ) (repo_iteration_id,), = cursor.fetchall() @@ -118,7 +130,7 @@ def get_or_make_item_version( item_id: int, repo_iteration_id: int, version: versions.VerTuple, - definition: str + definition: bytes ) -> int: ver_str = versions.version_string(version) @@ -154,8 +166,8 @@ def get_or_make_item_version( def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None: cursor.execute( ''' - INSERT OR IGNORE INTO mapping_statuses(item_id, enabled, frozen) - VALUES(?, 'E', 'R'); + INSERT OR IGNORE INTO mapping_statuses(item_id, enabled, required) + VALUES(?, 'N', FALSE); ''', (item_id,) ) @@ -215,14 +227,18 @@ class _FileInfo: id: int is_ascii: bool +class FileResolver(ABC): + @abstractmethod + def by_sha256(self, sha256: str) -> bytes: + ... + def _add_item( - cursor: sqlite3.Cursor, - files_by_sha256_path: Path, - info: item_infos.AnyInfo, - definition: str + cursor: sqlite3.Cursor, + package_file_resolver: FileResolver, + info: item_infos.AnyInfo, + definition: bytes, + repo_iteration_id: int ) -> None: - repo_iteration_id = get_or_make_repo_iteration(cursor, '<local>') - item_id = get_or_make_item(cursor, info.type_name, info.identifier) item_version_id = get_or_make_item_version( @@ -243,17 +259,7 @@ def _add_item( file_specifiers.extend(info.scripts) for file_spec in file_specifiers: - file_path = files_by_sha256_path / file_spec.sha256 - if not file_path.is_file(): - fmt = _('err.proxy.file_missing_{item_identifier}_{file_name}_{sha256}') - msg = fmt.format( - item_identifier = info.identifier, - file_name = file_spec.name, - sha256 = file_spec.sha256 - ) - raise HaketiloException(msg) - - file_bytes = file_path.read_bytes() + file_bytes = package_file_resolver.by_sha256(file_spec.sha256) sha256 = hashlib.sha256(file_bytes).digest().hex() if sha256 != file_spec.sha256: @@ -309,7 +315,7 @@ AnyInfoVar = t.TypeVar( ) def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \ - -> t.Iterator[tuple[AnyInfoVar, str]]: + -> t.Iterator[tuple[AnyInfoVar, bytes]]: item_type_path = malcontent_path / item_class.type_name if not item_type_path.is_dir(): return @@ -319,8 +325,8 @@ def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \ continue for item_version_path in item_path.iterdir(): - definition = item_version_path.read_text() - item_info = item_class.load(io.StringIO(definition)) + definition = item_version_path.read_bytes() + item_info = item_class.load(definition) assert item_info.identifier == item_path.name assert versions.version_string(item_info.version) == \ @@ -328,17 +334,45 @@ def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \ yield item_info, definition +@dc.dataclass(frozen=True) +class MalcontentFileResolver(FileResolver): + malcontent_dir_path: Path + + def by_sha256(self, sha256: str) -> bytes: + file_path = self.malcontent_dir_path / 'file' / 'sha256' / sha256 + if not file_path.is_file(): + fmt = _('err.proxy.file_missing_{sha256}') + raise HaketiloException(fmt.format(sha256=sha256)) + + return file_path.read_bytes() + def load_packages( - state: base.HaketiloStateWithFields, - cursor: sqlite3.Cursor, - malcontent_path: Path -) -> None: - files_by_sha256_path = malcontent_path / 'file' / 'sha256' + cursor: sqlite3.Cursor, + malcontent_path: Path, + repo_id: int, + package_file_resolver: t.Optional[FileResolver] = None +) -> int: + if package_file_resolver is None: + package_file_resolver = MalcontentFileResolver(malcontent_path) - for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]: + repo_iteration_id = make_repo_iteration(cursor, repo_id) + + types: t.Iterable[t.Type[item_infos.AnyInfo]] = \ + [item_infos.ResourceInfo, item_infos.MappingInfo] + + for info_type in types: info: item_infos.AnyInfo - for info, definition in _read_items( + + for info, definition in _read_items( # type: ignore malcontent_path, - info_type # type: ignore + info_type ): - _add_item(cursor, files_by_sha256_path, info, definition) + _add_item( + cursor, + package_file_resolver, + info, + definition, + repo_iteration_id + ) + + return repo_iteration_id diff --git a/src/hydrilla/proxy/state_impl/prune_packages.py b/src/hydrilla/proxy/state_impl/_operations/prune_packages.py index 1857188..9c2b1d7 100644 --- a/src/hydrilla/proxy/state_impl/prune_packages.py +++ b/src/hydrilla/proxy/state_impl/_operations/prune_packages.py @@ -33,27 +33,42 @@ from __future__ import annotations import sqlite3 +from pathlib import Path -_remove_mapping_versions_sql = ''' -WITH removed_mappings AS ( - SELECT - iv.item_version_id - FROM - item_versions AS iv - JOIN items AS i - USING (item_id) - JOIN orphan_iterations AS oi - USING (repo_iteration_id) - LEFT JOIN payloads AS p - ON p.mapping_item_id = iv.item_version_id - WHERE - i.type = 'M' AND p.payload_id IS NULL -) -DELETE FROM - item_versions -WHERE - item_version_id IN removed_mappings; -''' + +_remove_mapping_versions_sqls = [ + ''' + CREATE TEMPORARY TABLE removed_mappings( + item_version_id INTEGER PRIMARY KEY + ); + ''', ''' + INSERT INTO + removed_mappings + SELECT + iv.item_version_id + FROM + item_versions AS iv + JOIN items AS i USING (item_id) + JOIN mapping_statuses AS ms USING (item_id) + JOIN orphan_iterations AS oi USING (repo_iteration_id) + WHERE + NOT ms.required; + ''', ''' + UPDATE + mapping_statuses + SET + active_version_id = NULL + WHERE + active_version_id IN removed_mappings; + ''', ''' + DELETE FROM + item_versions + WHERE + item_version_id IN removed_mappings; + ''', ''' + DROP TABLE removed_mappings; + ''' +] _remove_resource_versions_sql = ''' WITH removed_resources AS ( @@ -134,7 +149,7 @@ WITH removed_repos AS ( repos AS r LEFT JOIN repo_iterations AS ri USING (repo_id) WHERE - r.deleted AND ri.repo_iteration_id IS NULL + r.deleted AND ri.repo_iteration_id IS NULL AND r.repo_id != 1 ) DELETE FROM repos @@ -142,9 +157,11 @@ WHERE repo_id IN removed_repos; ''' -def prune(cursor: sqlite3.Cursor) -> None: - """....""" - cursor.execute(_remove_mapping_versions_sql) +def prune_packages(cursor: sqlite3.Cursor) -> None: + assert cursor.connection.in_transaction + + for sql in _remove_mapping_versions_sqls: + cursor.execute(sql) cursor.execute(_remove_resource_versions_sql) cursor.execute(_remove_items_sql) cursor.execute(_remove_files_sql) diff --git a/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py b/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py new file mode 100644 index 0000000..4093f12 --- /dev/null +++ b/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Haketilo proxy data and configuration (update of dependency tree in the db). +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +.... +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import typing as t + +import sqlite3 + +from .... import item_infos +from ... import simple_dependency_satisfying as sds + + +AnyInfoVar = t.TypeVar( + 'AnyInfoVar', + item_infos.ResourceInfo, + item_infos.MappingInfo +) + +def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \ + -> t.Mapping[int, AnyInfoVar]: + cursor.execute( + ''' + SELECT + i.item_id, iv.definition, r.name, ri.iteration + FROM + item_versions AS iv + JOIN items AS i USING (item_id) + JOIN repo_iterations AS ri USING (repo_iteration_id) + JOIN repos AS r USING (repo_id) + WHERE + i.type = ?; + ''', + (info_type.type_name[0].upper(),) + ) + + result: dict[int, AnyInfoVar] = {} + + for item_id, definition, repo_name, repo_iteration in cursor.fetchall(): + info = info_type.load(definition, repo_name, repo_iteration) + result[item_id] = info + + return result + +def _recompute_dependencies_no_state_update( + cursor: sqlite3.Cursor, + extra_requirements: t.Iterable[sds.MappingRequirement] +) -> None: + cursor.execute('DELETE FROM payloads;') + + ids_to_resources = get_infos_of_type(cursor, item_infos.ResourceInfo) + ids_to_mappings = get_infos_of_type(cursor, item_infos.MappingInfo) + + resources = ids_to_resources.items() + resources_to_ids = dict((info.identifier, id) for id, info in resources) + + mappings = ids_to_mappings.items() + mappings_to_ids = dict((info.identifier, id) for id, info in mappings) + + requirements = [*extra_requirements] + + cursor.execute( + ''' + SELECT + i.identifier + FROM + mapping_statuses AS ms + JOIN items AS i USING(item_id) + WHERE + ms.enabled = 'E' AND ms.frozen = 'N'; + ''' + ) + + for mapping_identifier, in cursor.fetchall(): + requirements.append(sds.MappingRequirement(mapping_identifier)) + + cursor.execute( + ''' + SELECT + active_version_id, frozen + FROM + mapping_statuses + WHERE + enabled = 'E' AND frozen IN ('R', 'E'); + ''' + ) + + for active_version_id, frozen in cursor.fetchall(): + info = ids_to_mappings[active_version_id] + + requirement: sds.MappingRequirement + + if frozen == 'R': + requirement = sds.MappingRepoRequirement(info.identifier, info.repo) + else: + requirement = sds.MappingVersionRequirement(info.identifier, info) + + requirements.append(requirement) + + mapping_choices = sds.compute_payloads( + ids_to_resources.values(), + ids_to_mappings.values(), + requirements + ) + + cursor.execute( + ''' + UPDATE + mapping_statuses + SET + required = FALSE, + active_version_id = NULL + WHERE + enabled != 'E'; + ''' + ) + + cursor.execute('DELETE FROM payloads;') + + for choice in mapping_choices.values(): + mapping_ver_id = mappings_to_ids[choice.info.identifier] + + cursor.execute( + ''' + SELECT + item_id + FROM + item_versions + WHERE + item_version_id = ?; + ''', + (mapping_ver_id,) + ) + + (mapping_item_id,), = cursor.fetchall() + + cursor.execute( + ''' + UPDATE + mapping_statuses + SET + required = ?, + active_version_id = ? + WHERE + item_id = ?; + ''', + (choice.required, mapping_ver_id, mapping_item_id) + ) + + for num, (pattern, payload) in enumerate(choice.payloads.items()): + cursor.execute( + ''' + INSERT INTO payloads( + mapping_item_id, + pattern, + eval_allowed, + cors_bypass_allowed + ) + VALUES (?, ?, ?, ?); + ''', + ( + mapping_ver_id, + pattern.orig_url, + payload.allows_eval, + payload.allows_cors_bypass + ) + ) + + cursor.execute( + ''' + SELECT + payload_id + FROM + payloads + WHERE + mapping_item_id = ? AND pattern = ?; + ''', + (mapping_ver_id, pattern.orig_url) + ) + + (payload_id,), = cursor.fetchall() + + for res_num, resource_info in enumerate(payload.resources): + resource_ver_id = resources_to_ids[resource_info.identifier] + cursor.execute( + ''' + INSERT INTO resolved_depended_resources( + payload_id, + resource_item_id, + idx + ) + VALUES(?, ?, ?); + ''', + (payload_id, resource_ver_id, res_num) + ) diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py index 1ae08cb..92833dd 100644 --- a/src/hydrilla/proxy/state_impl/base.py +++ b/src/hydrilla/proxy/state_impl/base.py @@ -34,34 +34,160 @@ subtype. from __future__ import annotations import sqlite3 +import secrets import threading import dataclasses as dc import typing as t from pathlib import Path from contextlib import contextmanager - -import sqlite3 +from abc import abstractmethod from immutables import Map +from ... import url_patterns from ... import pattern_tree -from .. import state +from .. import simple_dependency_satisfying as sds +from .. import state as st from .. import policies PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory] -PayloadsData = t.Mapping[state.PayloadRef, state.PayloadData] +PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData] + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcretePayloadRef(st.PayloadRef): + state: 'HaketiloStateWithFields' = dc.field(hash=False, compare=False) + + def get_data(self) -> st.PayloadData: + try: + return self.state.payloads_data[self] + except KeyError: + raise st.MissingItemError() + + def get_mapping(self) -> st.MappingVersionRef: + raise NotImplementedError() + + def get_script_paths(self) \ + -> t.Iterable[t.Sequence[str]]: + with self.state.cursor() as cursor: + cursor.execute( + ''' + SELECT + i.identifier, fu.name + FROM + payloads AS p + LEFT JOIN resolved_depended_resources AS rdd + USING (payload_id) + LEFT JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + LEFT JOIN items AS i + USING (item_id) + LEFT JOIN file_uses AS fu + USING (item_version_id) + WHERE + fu.type = 'W' AND + p.payload_id = ? AND + (fu.idx IS NOT NULL OR rdd.idx IS NULL) + ORDER BY + rdd.idx, fu.idx; + ''', + (self.id,) + ) + + paths: list[t.Sequence[str]] = [] + for resource_identifier, file_name in cursor.fetchall(): + if resource_identifier is None: + # payload found but it had no script files + return () + + paths.append((resource_identifier, *file_name.split('/'))) + + if paths == []: + # payload not found + raise st.MissingItemError() + + return paths + + def get_file_data(self, path: t.Sequence[str]) \ + -> t.Optional[st.FileData]: + if len(path) == 0: + raise st.MissingItemError() + + resource_identifier, *file_name_segments = path + + file_name = '/'.join(file_name_segments) + + with self.state.cursor() as cursor: + cursor.execute( + ''' + SELECT + f.data, fu.mime_type + FROM + payloads AS p + JOIN resolved_depended_resources AS rdd + USING (payload_id) + JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + JOIN items AS i + USING (item_id) + JOIN file_uses AS fu + USING (item_version_id) + JOIN files AS f + USING (file_id) + WHERE + p.payload_id = ? AND + i.identifier = ? AND + fu.name = ? AND + fu.type = 'W'; + ''', + (self.id, resource_identifier, file_name) + ) + + result = cursor.fetchall() + + if result == []: + return None + + (data, mime_type), = result + + return st.FileData(type=mime_type, name=file_name, contents=data) + +def register_payload( + policy_tree: PolicyTree, + pattern: url_patterns.ParsedPattern, + payload_key: st.PayloadKey, + token: str +) -> PolicyTree: + """....""" + payload_policy_factory = policies.PayloadPolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register(pattern, payload_policy_factory) + + resource_policy_factory = policies.PayloadResourcePolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register( + pattern.path_append(token, '***'), + resource_policy_factory + ) + + return policy_tree # mypy needs to be corrected: # https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 @dc.dataclass # type: ignore[misc] -class HaketiloStateWithFields(state.HaketiloState): +class HaketiloStateWithFields(st.HaketiloState): """....""" store_dir: Path connection: sqlite3.Connection current_cursor: t.Optional[sqlite3.Cursor] = None - #settings: state.HaketiloGlobalSettings + #settings: st.HaketiloGlobalSettings policy_tree: PolicyTree = PolicyTree() payloads_data: PayloadsData = dc.field(default_factory=dict) @@ -98,6 +224,85 @@ class HaketiloStateWithFields(state.HaketiloState): finally: self.current_cursor = None - def recompute_payloads(self, cursor: sqlite3.Cursor) -> None: + def rebuild_structures(self) -> None: + """ + Recreation of data structures as done after every recomputation of + dependencies as well as at startup. + """ + with self.cursor(transaction=True) as cursor: + cursor.execute( + ''' + SELECT + p.payload_id, p.pattern, p.eval_allowed, + p.cors_bypass_allowed, + ms.enabled, + i.identifier + FROM + payloads AS p + JOIN item_versions AS iv + ON p.mapping_item_id = iv.item_version_id + JOIN items AS i USING (item_id) + JOIN mapping_statuses AS ms USING (item_id); + ''' + ) + + rows = cursor.fetchall() + + new_policy_tree = PolicyTree() + + ui_factory = policies.WebUIPolicyFactory(builtin=True) + web_ui_pattern = 'http*://hkt.mitm.it/***' + for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern): + new_policy_tree = new_policy_tree.register( + parsed_pattern, + ui_factory + ) + + new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} + + for row in rows: + (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, + enabled_status, + identifier) = row + + payload_ref = ConcretePayloadRef(str(payload_id_int), self) + + previous_data = self.payloads_data.get(payload_ref) + if previous_data is not None: + token = previous_data.unique_token + else: + token = secrets.token_urlsafe(8) + + payload_key = st.PayloadKey(payload_ref, identifier) + + for parsed_pattern in url_patterns.parse_pattern(pattern): + new_policy_tree = register_payload( + new_policy_tree, + parsed_pattern, + payload_key, + token + ) + + pattern_path_segments = parsed_pattern.path_segments + + payload_data = st.PayloadData( + payload_ref = payload_ref, + explicitly_enabled = enabled_status == 'E', + unique_token = token, + pattern_path_segments = pattern_path_segments, + eval_allowed = eval_allowed, + cors_bypass_allowed = cors_bypass_allowed + ) + + new_payloads_data[payload_ref] = payload_data + + self.policy_tree = new_policy_tree + self.payloads_data = new_payloads_data + + @abstractmethod + def recompute_dependencies( + self, + requirements: t.Iterable[sds.MappingRequirement] = [] + ) -> None: """....""" ... diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py index 525a702..6a59a75 100644 --- a/src/hydrilla/proxy/state_impl/concrete_state.py +++ b/src/hydrilla/proxy/state_impl/concrete_state.py @@ -32,28 +32,23 @@ and resources. # Enable using with Python 3.7. from __future__ import annotations -import secrets -import io -import hashlib +import sqlite3 import typing as t import dataclasses as dc from pathlib import Path -import sqlite3 - from ...exceptions import HaketiloException from ...translations import smart_gettext as _ -from ... import pattern_tree from ... import url_patterns from ... import item_infos -from ..simple_dependency_satisfying import compute_payloads, ComputedPayload from .. import state as st from .. import policies +from .. import simple_dependency_satisfying as sds from . import base from . import mappings from . import repos -from .load_packages import load_packages +from . import _operations here = Path(__file__).resolve().parent @@ -73,169 +68,14 @@ class ConcreteResourceRef(st.ResourceRef): class ConcreteResourceVersionRef(st.ResourceVersionRef): pass - -@dc.dataclass(frozen=True, unsafe_hash=True) -class ConcretePayloadRef(st.PayloadRef): - state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False) - - def get_data(self) -> st.PayloadData: - try: - return self.state.payloads_data[self] - except KeyError: - raise st.MissingItemError() - - def get_mapping(self) -> st.MappingVersionRef: - raise NotImplementedError() - - def get_script_paths(self) \ - -> t.Iterable[t.Sequence[str]]: - with self.state.cursor() as cursor: - cursor.execute( - ''' - SELECT - i.identifier, fu.name - FROM - payloads AS p - LEFT JOIN resolved_depended_resources AS rdd - USING (payload_id) - LEFT JOIN item_versions AS iv - ON rdd.resource_item_id = iv.item_version_id - LEFT JOIN items AS i - USING (item_id) - LEFT JOIN file_uses AS fu - USING (item_version_id) - WHERE - fu.type = 'W' AND - p.payload_id = ? AND - (fu.idx IS NOT NULL OR rdd.idx IS NULL) - ORDER BY - rdd.idx, fu.idx; - ''', - (self.id,) - ) - - paths: list[t.Sequence[str]] = [] - for resource_identifier, file_name in cursor.fetchall(): - if resource_identifier is None: - # payload found but it had no script files - return () - - paths.append((resource_identifier, *file_name.split('/'))) - - if paths == []: - # payload not found - raise st.MissingItemError() - - return paths - - def get_file_data(self, path: t.Sequence[str]) \ - -> t.Optional[st.FileData]: - if len(path) == 0: - raise st.MissingItemError() - - resource_identifier, *file_name_segments = path - - file_name = '/'.join(file_name_segments) - - with self.state.cursor() as cursor: - cursor.execute( - ''' - SELECT - f.data, fu.mime_type - FROM - payloads AS p - JOIN resolved_depended_resources AS rdd - USING (payload_id) - JOIN item_versions AS iv - ON rdd.resource_item_id = iv.item_version_id - JOIN items AS i - USING (item_id) - JOIN file_uses AS fu - USING (item_version_id) - JOIN files AS f - USING (file_id) - WHERE - p.payload_id = ? AND - i.identifier = ? AND - fu.name = ? AND - fu.type = 'W'; - ''', - (self.id, resource_identifier, file_name) - ) - - result = cursor.fetchall() - - if result == []: - return None - - (data, mime_type), = result - - return st.FileData(type=mime_type, name=file_name, contents=data) - -def register_payload( - policy_tree: base.PolicyTree, - pattern: url_patterns.ParsedPattern, - payload_key: st.PayloadKey, - token: str -) -> base.PolicyTree: - """....""" - payload_policy_factory = policies.PayloadPolicyFactory( - builtin = False, - payload_key = payload_key - ) - - policy_tree = policy_tree.register(pattern, payload_policy_factory) - - resource_policy_factory = policies.PayloadResourcePolicyFactory( - builtin = False, - payload_key = payload_key - ) - - policy_tree = policy_tree.register( - pattern.path_append(token, '***'), - resource_policy_factory - ) - - return policy_tree - -AnyInfoVar = t.TypeVar( - 'AnyInfoVar', - item_infos.ResourceInfo, - item_infos.MappingInfo -) - -def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \ - -> t.Mapping[AnyInfoVar, int]: - cursor.execute( - ''' - SELECT - i.item_id, iv.definition, r.name, ri.iteration - FROM - item_versions AS iv - JOIN items AS i USING (item_id) - JOIN repo_iterations AS ri USING (repo_iteration_id) - JOIN repos AS r USING (repo_id) - WHERE - i.type = ?; - ''', - (info_type.type_name[0].upper(),) - ) - - result: dict[AnyInfoVar, int] = {} - - for item_id, definition, repo_name, repo_iteration in cursor.fetchall(): - definition_io = io.StringIO(definition) - info = info_type.load(definition_io, repo_name, repo_iteration) - result[info] = item_id - - return result - @dc.dataclass class ConcreteHaketiloState(base.HaketiloStateWithFields): def __post_init__(self) -> None: + sqlite3.enable_callback_tracebacks(True) + self._prepare_database() - self._rebuild_structures() + self.rebuild_structures() def _prepare_database(self) -> None: """....""" @@ -282,147 +122,25 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): def import_packages(self, malcontent_path: Path) -> None: with self.cursor(transaction=True) as cursor: - load_packages(self, cursor, malcontent_path) - self.recompute_payloads(cursor) - - def recompute_payloads(self, cursor: sqlite3.Cursor) -> None: - assert self.connection.in_transaction - - resources = get_infos_of_type(cursor, item_infos.ResourceInfo) - mappings = get_infos_of_type(cursor, item_infos.MappingInfo) - - payloads = compute_payloads(resources.keys(), mappings.keys()) - - payloads_data: dict[st.PayloadRef, st.PayloadData] = {} - - cursor.execute('DELETE FROM payloads;') - - for mapping_info, by_pattern in payloads.items(): - for num, (pattern, payload) in enumerate(by_pattern.items()): - cursor.execute( - ''' - INSERT INTO payloads( - mapping_item_id, - pattern, - eval_allowed, - cors_bypass_allowed - ) - VALUES (?, ?, ?, ?); - ''', - ( - mappings[mapping_info], - pattern.orig_url, - payload.allows_eval, - payload.allows_cors_bypass - ) - ) + _operations.load_packages(cursor, malcontent_path, 1) + raise NotImplementedError() + _operations.prune_packages(cursor) - cursor.execute( - ''' - SELECT - payload_id - FROM - payloads - WHERE - mapping_item_id = ? AND pattern = ?; - ''', - (mappings[mapping_info], pattern.orig_url) - ) + self.recompute_dependencies() - (payload_id_int,), = cursor.fetchall() - - for res_num, resource_info in enumerate(payload.resources): - cursor.execute( - ''' - INSERT INTO resolved_depended_resources( - payload_id, - resource_item_id, - idx - ) - VALUES(?, ?, ?); - ''', - (payload_id_int, resources[resource_info], res_num) - ) - - self._rebuild_structures(cursor) - - def _rebuild_structures(self, cursor: t.Optional[sqlite3.Cursor] = None) \ - -> None: - """ - Recreation of data structures as done after every recomputation of - dependencies as well as at startup. - """ - if cursor is None: - with self.cursor() as new_cursor: - return self._rebuild_structures(new_cursor) - - cursor.execute( - ''' - SELECT - p.payload_id, p.pattern, p.eval_allowed, - p.cors_bypass_allowed, - ms.enabled, - i.identifier - FROM - payloads AS p - JOIN item_versions AS iv - ON p.mapping_item_id = iv.item_version_id - JOIN items AS i USING (item_id) - JOIN mapping_statuses AS ms USING (item_id); - ''' - ) - - new_policy_tree = base.PolicyTree() - - ui_factory = policies.WebUIPolicyFactory(builtin=True) - web_ui_pattern = 'http*://hkt.mitm.it/***' - for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern): - new_policy_tree = new_policy_tree.register( - parsed_pattern, - ui_factory - ) - - new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} - - for row in cursor.fetchall(): - (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, - enabled_status, - identifier) = row - - payload_ref = ConcretePayloadRef(str(payload_id_int), self) - - previous_data = self.payloads_data.get(payload_ref) - if previous_data is not None: - token = previous_data.unique_token - else: - token = secrets.token_urlsafe(8) - - payload_key = st.PayloadKey(payload_ref, identifier) - - for parsed_pattern in url_patterns.parse_pattern(pattern): - new_policy_tree = register_payload( - new_policy_tree, - parsed_pattern, - payload_key, - token - ) - - pattern_path_segments = parsed_pattern.path_segments + def recompute_dependencies( + self, + extra_requirements: t.Iterable[sds.MappingRequirement] = [] + ) -> None: + with self.cursor() as cursor: + assert self.connection.in_transaction - payload_data = st.PayloadData( - payload_ref = payload_ref, - #explicitly_enabled = enabled_status == 'E', - explicitly_enabled = True, - unique_token = token, - pattern_path_segments = pattern_path_segments, - eval_allowed = eval_allowed, - cors_bypass_allowed = cors_bypass_allowed + _operations._recompute_dependencies_no_state_update( + cursor, + extra_requirements ) - new_payloads_data[payload_ref] = payload_data - - self.policy_tree = new_policy_tree - self.payloads_data = new_payloads_data + self.rebuild_structures() def repo_store(self) -> st.RepoStore: return repos.ConcreteRepoStore(self) diff --git a/src/hydrilla/proxy/state_impl/mappings.py b/src/hydrilla/proxy/state_impl/mappings.py index 3668784..cce2a36 100644 --- a/src/hydrilla/proxy/state_impl/mappings.py +++ b/src/hydrilla/proxy/state_impl/mappings.py @@ -76,7 +76,7 @@ class ConcreteMappingVersionRef(st.MappingVersionRef): (status_letter, definition, repo, repo_iteration, is_orphan), = rows item_info = item_infos.MappingInfo.load( - io.StringIO(definition), + definition, repo, repo_iteration ) @@ -120,7 +120,7 @@ class ConcreteMappingVersionStore(st.MappingVersionStore): ref = ConcreteMappingVersionRef(str(item_version_id), self.state) item_info = item_infos.MappingInfo.load( - io.StringIO(definition), + definition, repo, repo_iteration ) diff --git a/src/hydrilla/proxy/state_impl/repos.py b/src/hydrilla/proxy/state_impl/repos.py index 5553ec2..f4c7c71 100644 --- a/src/hydrilla/proxy/state_impl/repos.py +++ b/src/hydrilla/proxy/state_impl/repos.py @@ -33,20 +33,27 @@ inside Haketilo. from __future__ import annotations import re +import json +import tempfile +import requests +import sqlite3 import typing as t import dataclasses as dc -from urllib.parse import urlparse +from urllib.parse import urlparse, urljoin from datetime import datetime +from pathlib import Path -import sqlite3 - +from ... import json_instances +from ... import item_infos +from ... import versions from .. import state as st +from .. import simple_dependency_satisfying as sds from . import base -from . import prune_packages +from . import _operations -def validate_repo_url(url: str) -> None: +def sanitize_repo_url(url: str) -> str: try: parsed = urlparse(url) except: @@ -55,6 +62,11 @@ def validate_repo_url(url: str) -> None: if parsed.scheme not in ('http', 'https'): raise st.RepoUrlInvalid() + if url[-1] != '/': + url = url + '/' + + return url + def ensure_repo_not_deleted(cursor: sqlite3.Cursor, repo_id: str) -> None: cursor.execute( @@ -73,6 +85,53 @@ def ensure_repo_not_deleted(cursor: sqlite3.Cursor, repo_id: str) -> None: raise st.MissingItemError() +def sync_remote_repo_definitions(repo_url: str, dest: Path) -> None: + try: + list_all_response = requests.get(urljoin(repo_url, 'list_all')) + assert list_all_response.ok + + list_instance = list_all_response.json() + except: + raise st.RepoCommunicationError() + + try: + json_instances.validate_instance( + list_instance, + 'api_package_list-{}.schema.json' + ) + except json_instances.UnknownSchemaError: + raise st.RepoApiVersionUnsupported() + except: + raise st.RepoCommunicationError() + + ref: dict[str, t.Any] + + for item_type_name in ('resource', 'mapping'): + for ref in list_instance[item_type_name + 's']: + ver = versions.version_string(versions.normalize(ref['version'])) + item_rel_path = f'{item_type_name}/{ref["identifier"]}/{ver}' + + try: + item_response = requests.get(urljoin(repo_url, item_rel_path)) + assert item_response.ok + except: + raise st.RepoCommunicationError() + + item_path = dest / item_rel_path + item_path.parent.mkdir(parents=True, exist_ok=True) + item_path.write_bytes(item_response.content) + + +@dc.dataclass(frozen=True) +class RemoteFileResolver(_operations.FileResolver): + repo_url: str + + def by_sha256(self, sha256: str) -> bytes: + response = requests.get(urljoin(self.repo_url, f'file/sha256/{sha256}')) + assert response.ok + return response.content + + def make_repo_display_info( ref: st.RepoRef, name: str, @@ -122,6 +181,37 @@ class ConcreteRepoRef(st.RepoRef): (self.id,) ) + _operations.prune_packages(cursor) + + # For mappings explicitly enabled by the user (+ all mappings they + # recursively depend on) let's make sure that their exact same + # versions will be enabled after the change. + cursor.execute( + ''' + SELECT + iv.definition, r.name, ri.iteration + FROM + mapping_statuses AS ms + JOIN item_versions AS iv + ON ms.active_version_id = iv.item_version_id + JOIN repo_iterations AS ri + USING (repo_iteration_id) + JOIN repos AS r + USING (repo_id) + WHERE + ms.required + ''' + ) + + requirements = [] + + for definition, repo, iteration in cursor.fetchall(): + info = item_infos.MappingInfo.load(definition, repo, iteration) + req = sds.MappingVersionRequirement(info.identifier, info) + requirements.append(req) + + self.state.recompute_dependencies(requirements) + def update( self, *, @@ -134,7 +224,7 @@ class ConcreteRepoRef(st.RepoRef): if url is None: return - validate_repo_url(url) + url = sanitize_repo_url(url) with self.state.cursor(transaction=True) as cursor: ensure_repo_not_deleted(cursor, self.id) @@ -144,10 +234,30 @@ class ConcreteRepoRef(st.RepoRef): (url, self.id) ) - prune_packages.prune(cursor) - def refresh(self) -> st.RepoIterationRef: - raise NotImplementedError() + with self.state.cursor(transaction=True) as cursor: + ensure_repo_not_deleted(cursor, self.id) + + cursor.execute( + 'SELECT url from repos where repo_id = ?;', + (self.id,) + ) + + (repo_url,), = cursor.fetchall() + + with tempfile.TemporaryDirectory() as tmpdir_str: + tmpdir = Path(tmpdir_str) + sync_remote_repo_definitions(repo_url, tmpdir) + new_iteration_id = _operations.load_packages( + cursor, + tmpdir, + int(self.id), + RemoteFileResolver(repo_url) + ) + + self.state.recompute_dependencies() + + return ConcreteRepoIterationRef(str(new_iteration_id), self.state) def get_display_info(self) -> st.RepoDisplayInfo: with self.state.cursor() as cursor: @@ -199,7 +309,7 @@ class ConcreteRepoStore(st.RepoStore): if repo_name_regex.match(name) is None: raise st.RepoNameInvalid() - validate_repo_url(url) + url = sanitize_repo_url(url) with self.state.cursor(transaction=True) as cursor: cursor.execute( @@ -272,3 +382,8 @@ class ConcreteRepoStore(st.RepoStore): result.append(make_repo_display_info(ref, *rest)) return result + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteRepoIterationRef(st.RepoIterationRef): + state: base.HaketiloStateWithFields |