diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/hydrilla/item_infos.py | 31 | ||||
-rw-r--r-- | src/hydrilla/json_instances.py | 34 | ||||
-rw-r--r-- | src/hydrilla/proxy/simple_dependency_satisfying.py | 177 | ||||
-rw-r--r-- | src/hydrilla/proxy/state.py | 6 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/_operations/__init__.py | 9 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/_operations/load_packages.py (renamed from src/hydrilla/proxy/state_impl/load_packages.py) | 124 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/_operations/prune_packages.py (renamed from src/hydrilla/proxy/state_impl/prune_packages.py) | 65 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py | 223 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/base.py | 219 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/concrete_state.py | 322 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/mappings.py | 4 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/repos.py | 135 | ||||
-rw-r--r-- | src/hydrilla/proxy/tables.sql | 29 | ||||
-rw-r--r-- | src/hydrilla/server/malcontent.py | 2 | ||||
-rw-r--r-- | src/hydrilla/versions.py | 4 |
15 files changed, 902 insertions, 482 deletions
diff --git a/src/hydrilla/item_infos.py b/src/hydrilla/item_infos.py index a01fe3a..2b89600 100644 --- a/src/hydrilla/item_infos.py +++ b/src/hydrilla/item_infos.py @@ -181,7 +181,7 @@ class ItemInfoBase(ABC, ItemIdentity, Categorizable): item_obj['source_copyright'] ) - version = versions.normalize_version(item_obj['version']) + version = versions.normalize(item_obj['version']) perms_obj = item_obj.get('permissions', {}) @@ -262,14 +262,14 @@ class ResourceInfo(ItemInfoBase): @staticmethod def load( - instance_or_path: json_instances.InstanceOrPathOrIO, - repo: str = '<dummyrepo>', - repo_iteration: int = -1 + instance_source: json_instances.InstanceSource, + repo: str = '<dummyrepo>', + repo_iteration: int = -1 ) -> 'ResourceInfo': """....""" return _load_item_info( ResourceInfo, - instance_or_path, + instance_source, repo, repo_iteration ) @@ -291,7 +291,8 @@ class MappingInfo(ItemInfoBase): """....""" type_name: t.ClassVar[str] = 'mapping' - payloads: t.Mapping[ParsedPattern, ItemSpecifier] = dc.field(hash=False, compare=False) + payloads: t.Mapping[ParsedPattern, ItemSpecifier] = \ + dc.field(hash=False, compare=False) @staticmethod def make( @@ -316,14 +317,14 @@ class MappingInfo(ItemInfoBase): @staticmethod def load( - instance_or_path: json_instances.InstanceOrPathOrIO, - repo: str = '<dummyrepo>', - repo_iteration: int = -1 + instance_source: json_instances.InstanceSource, + repo: str = '<dummyrepo>', + repo_iteration: int = -1 ) -> 'MappingInfo': """....""" return _load_item_info( MappingInfo, - instance_or_path, + instance_source, repo, repo_iteration ) @@ -349,13 +350,13 @@ AnyInfo = t.Union[ResourceInfo, MappingInfo] LoadedType = t.TypeVar('LoadedType', ResourceInfo, MappingInfo) def _load_item_info( - info_type: t.Type[LoadedType], - instance_or_path: json_instances.InstanceOrPathOrIO, - repo: str, - repo_iteration: int + info_type: t.Type[LoadedType], + instance_source: json_instances.InstanceSource, + repo: str, + repo_iteration: int ) -> LoadedType: """Read, validate and autocomplete a mapping/resource description.""" - instance = json_instances.read_instance(instance_or_path) + instance = json_instances.read_instance(instance_source) schema_fmt = f'api_{info_type.type_name}_description-{{}}.schema.json' diff --git a/src/hydrilla/json_instances.py b/src/hydrilla/json_instances.py index be8dbc6..8bec808 100644 --- a/src/hydrilla/json_instances.py +++ b/src/hydrilla/json_instances.py @@ -127,11 +127,14 @@ schema_paths.update([(f'https://hydrilla.koszko.org/schemas/{name}', path) schemas: dict[Path, dict[str, t.Any]] = {} +class UnknownSchemaError(HaketiloException): + pass + def _get_schema(schema_name: str) -> dict[str, t.Any]: """Return loaded JSON of the requested schema. Cache results.""" path = schema_paths.get(schema_name) if path is None: - raise HaketiloException(_('unknown_schema_{}').format(schema_name)) + raise UnknownSchemaError(_('unknown_schema_{}').format(schema_name)) if path not in schemas: schemas[path] = json.loads(path.read_text()) @@ -159,28 +162,33 @@ def parse_instance(text: str) -> object: """Parse 'text' as JSON with additional '//' comments support.""" return json.loads(strip_json_comments(text)) -InstanceOrPathOrIO = t.Union[Path, str, io.TextIOBase, dict[str, t.Any]] +InstanceSource = t.Union[Path, str, io.TextIOBase, dict[str, t.Any], bytes] -def read_instance(instance_or_path: InstanceOrPathOrIO) -> object: +def read_instance(instance_or_path: InstanceSource) -> object: """....""" if isinstance(instance_or_path, dict): return instance_or_path - if isinstance(instance_or_path, io.TextIOBase): - handle = instance_or_path + if isinstance(instance_or_path, bytes): + encoding = json.detect_encoding(instance_or_path) + text = instance_or_path.decode(encoding) + elif isinstance(instance_or_path, io.TextIOBase): + try: + text = instance_or_path.read() + finally: + instance_or_path.close() else: - handle = t.cast(io.TextIOBase, open(instance_or_path, 'rt')) - - try: - text = handle.read() - finally: - handle.close() + text = Path(instance_or_path).read_text() try: return parse_instance(text) except: - fmt = _('err.util.text_in_{}_not_valid_json') - raise HaketiloException(fmt.format(instance_or_path)) + if isinstance(instance_or_path, str) or \ + isinstance(instance_or_path, Path): + fmt = _('err.util.text_in_{}_not_valid_json') + raise HaketiloException(fmt.format(instance_or_path)) + else: + raise HaketiloException(_('err.util.text_not_valid_json')) def get_schema_version(instance: object) -> tuple[int, ...]: """ diff --git a/src/hydrilla/proxy/simple_dependency_satisfying.py b/src/hydrilla/proxy/simple_dependency_satisfying.py index 889ae98..f1371db 100644 --- a/src/hydrilla/proxy/simple_dependency_satisfying.py +++ b/src/hydrilla/proxy/simple_dependency_satisfying.py @@ -34,9 +34,40 @@ from __future__ import annotations import dataclasses as dc import typing as t +from ..exceptions import HaketiloException from .. import item_infos from .. import url_patterns + +class ImpossibleSituation(HaketiloException): + pass + + +@dc.dataclass(frozen=True) +class MappingRequirement: + identifier: str + + def is_fulfilled_by(self, info: item_infos.MappingInfo) -> bool: + return True + +@dc.dataclass(frozen=True) +class MappingRepoRequirement(MappingRequirement): + repo: str + + def is_fulfilled_by(self, info: item_infos.MappingInfo) -> bool: + return info.repo == self.repo + +@dc.dataclass(frozen=True) +class MappingVersionRequirement(MappingRequirement): + version_info: item_infos.MappingInfo + + def __post_init__(self): + assert self.version_info.identifier == self.identifier + + def is_fulfilled_by(self, info: item_infos.MappingInfo) -> bool: + return info == self.version_info + + @dc.dataclass class ComputedPayload: resources: list[item_infos.ResourceInfo] = dc.field(default_factory=list) @@ -44,22 +75,40 @@ class ComputedPayload: allows_eval: bool = False allows_cors_bypass: bool = False -SingleMappingPayloads = t.Mapping[ - url_patterns.ParsedPattern, - ComputedPayload -] +@dc.dataclass +class MappingChoice: + info: item_infos.MappingInfo + required: bool = False + payloads: dict[url_patterns.ParsedPattern, ComputedPayload] = \ + dc.field(default_factory=dict) + -ComputedPayloadsDict = dict[ - item_infos.MappingInfo, - SingleMappingPayloads +MappingsGraph = t.Union[ + t.Mapping[str, set[str]], + t.Mapping[str, frozenset[str]] ] -empty_identifiers_set: set[str] = set() +def _mark_mappings( + identifier: str, + mappings_graph: MappingsGraph, + marked_mappings: set[str] +) -> None: + if identifier in marked_mappings: + return + + marked_mappings.add(identifier) + + for next_mapping in mappings_graph.get(identifier, ()): + _mark_mappings(next_mapping, mappings_graph, marked_mappings) + + +ComputedChoices = dict[str, MappingChoice] @dc.dataclass(frozen=True) -class _ItemsCollection: +class _ComputationData: resources: t.Mapping[str, item_infos.ResourceInfo] mappings: t.Mapping[str, item_infos.MappingInfo] + required: frozenset[str] def _satisfy_payload_resource_rec( self, @@ -108,11 +157,11 @@ class _ItemsCollection: ComputedPayload() ) - def _compute_payloads_no_mapping_requirements(self) -> ComputedPayloadsDict: - computed_result: ComputedPayloadsDict = ComputedPayloadsDict() + def _compute_payloads_no_mapping_requirements(self) -> ComputedChoices: + computed_result: ComputedChoices = ComputedChoices() for mapping_info in self.mappings.values(): - by_pattern: dict[url_patterns.ParsedPattern, ComputedPayload] = {} + mapping_choice = MappingChoice(mapping_info) failure = False @@ -130,63 +179,66 @@ class _ItemsCollection: if mapping_info.allows_cors_bypass: computed_payload.allows_cors_bypass = True - by_pattern[pattern] = computed_payload + mapping_choice.payloads[pattern] = computed_payload if not failure: - computed_result[mapping_info] = by_pattern + computed_result[mapping_info.identifier] = mapping_choice return computed_result - def _mark_mappings_bad( - self, - identifier: str, - reverse_mapping_deps: t.Mapping[str, set[str]], - bad_mappings: set[str] - ) -> None: - if identifier in bad_mappings: - return + def _compute_inter_mapping_deps(self, choices: ComputedChoices) \ + -> dict[str, frozenset[str]]: + mapping_deps: dict[str, frozenset[str]] = {} - bad_mappings.add(identifier) + for mapping_choice in choices.values(): + specs_to_resolve = [*mapping_choice.info.required_mappings] - for requiring in reverse_mapping_deps.get(identifier, ()): - self._mark_mappings_bad( - requiring, - reverse_mapping_deps, - bad_mappings - ) + for computed_payload in mapping_choice.payloads.values(): + for resource_info in computed_payload.resources: + specs_to_resolve.extend(resource_info.required_mappings) - def compute_payloads(self) -> ComputedPayloadsDict: - computed_result = self._compute_payloads_no_mapping_requirements() + depended = frozenset(spec.identifier for spec in specs_to_resolve) + mapping_deps[mapping_choice.info.identifier] = depended - reverse_mapping_deps: dict[str, set[str]] = {} + return mapping_deps - for mapping_info, by_pattern in computed_result.items(): - specs_to_resolve = [*mapping_info.required_mappings] + def compute_payloads(self) -> ComputedChoices: + choices = self._compute_payloads_no_mapping_requirements() - for computed_payload in by_pattern.values(): - for resource_info in computed_payload.resources: - specs_to_resolve.extend(resource_info.required_mappings) + mapping_deps = self._compute_inter_mapping_deps(choices) + + reverse_deps: dict[str, set[str]] = {} - for required_mapping_spec in specs_to_resolve: - identifier = required_mapping_spec.identifier - requiring = reverse_mapping_deps.setdefault(identifier, set()) - requiring.add(mapping_info.identifier) + for depending, depended_set in mapping_deps.items(): + for depended in depended_set: + reverse_deps.setdefault(depended, set()).add(depending) bad_mappings: set[str] = set() - for required_identifier in reverse_mapping_deps.keys(): - if self.mappings.get(required_identifier) not in computed_result: - self._mark_mappings_bad( - required_identifier, - reverse_mapping_deps, - bad_mappings - ) + for depended_identifier in reverse_deps.keys(): + if self.mappings.get(depended_identifier) not in choices: + _mark_mappings(depended_identifier, reverse_deps, bad_mappings) + + if any(identifier in self.required for identifier in bad_mappings): + raise ImpossibleSituation() for identifier in bad_mappings: + if identifier in self.required: + raise ImpossibleSituation() + if identifier in self.mappings: - computed_result.pop(self.mappings[identifier], None) + choices.pop(identifier, None) + + required_mappings: set[str] = set() + + for identifier in self.required: + _mark_mappings(identifier, mapping_deps, required_mappings) + + for identifier in required_mappings: + choices[identifier].required = True + + return choices - return computed_result AnyInfoVar = t.TypeVar( 'AnyInfoVar', @@ -207,10 +259,25 @@ def _choose_newest(infos: t.Iterable[AnyInfoVar]) -> dict[str, AnyInfoVar]: return best_versions def compute_payloads( - resources: t.Iterable[item_infos.ResourceInfo], - mappings: t.Iterable[item_infos.MappingInfo] -) -> ComputedPayloadsDict: + resources: t.Iterable[item_infos.ResourceInfo], + mappings: t.Iterable[item_infos.MappingInfo], + requirements: t.Iterable[MappingRequirement] +) -> ComputedChoices: + reqs_by_identifier = dict((req.identifier, req) for req in requirements) + + filtered_mappings = [] + + for mapping_info in mappings: + req = reqs_by_identifier.get(mapping_info.identifier) + if req is not None and not req.is_fulfilled_by(mapping_info): + continue + + filtered_mappings.append(mapping_info) + best_resources = _choose_newest(resources) - best_mappings = _choose_newest(mappings) + best_mappings = _choose_newest(filtered_mappings) + + required = frozenset(reqs_by_identifier.keys()) - return _ItemsCollection(best_resources, best_mappings).compute_payloads() + return _ComputationData(best_resources, best_mappings, required)\ + .compute_payloads() diff --git a/src/hydrilla/proxy/state.py b/src/hydrilla/proxy/state.py index 6414ae8..42ca998 100644 --- a/src/hydrilla/proxy/state.py +++ b/src/hydrilla/proxy/state.py @@ -91,6 +91,12 @@ class RepoNameTaken(HaketiloException): class RepoUrlInvalid(HaketiloException): pass +class RepoCommunicationError(HaketiloException): + pass + +class RepoApiVersionUnsupported(HaketiloException): + pass + # mypy needs to be corrected: # https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 @dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] diff --git a/src/hydrilla/proxy/state_impl/_operations/__init__.py b/src/hydrilla/proxy/state_impl/_operations/__init__.py new file mode 100644 index 0000000..c147be4 --- /dev/null +++ b/src/hydrilla/proxy/state_impl/_operations/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: CC0-1.0 + +# Copyright (C) 2022 Wojtek Kosior <koszko@koszko.org> +# +# Available under the terms of Creative Commons Zero v1.0 Universal. + +from .load_packages import load_packages, FileResolver +from .prune_packages import prune_packages +from .recompute_dependencies import _recompute_dependencies_no_state_update diff --git a/src/hydrilla/proxy/state_impl/load_packages.py b/src/hydrilla/proxy/state_impl/_operations/load_packages.py index 6983c3e..c294ef0 100644 --- a/src/hydrilla/proxy/state_impl/load_packages.py +++ b/src/hydrilla/proxy/state_impl/_operations/load_packages.py @@ -37,37 +37,49 @@ import dataclasses as dc import typing as t from pathlib import Path +from abc import ABC, abstractmethod import sqlite3 -from ...exceptions import HaketiloException -from ...translations import smart_gettext as _ -from ... import versions -from ... import item_infos -from . import base +from ....exceptions import HaketiloException +from ....translations import smart_gettext as _ +from .... import versions +from .... import item_infos -def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int: +def make_repo_iteration(cursor: sqlite3.Cursor, repo_id: int) -> int: cursor.execute( ''' SELECT - repo_id, next_iteration - 1 + next_iteration FROM repos WHERE - name = ?; + repo_id = ?; ''', - (repo_name,) + (repo_id,) ) - (repo_id, last_iteration), = cursor.fetchall() + (next_iteration,), = cursor.fetchall() cursor.execute( ''' - INSERT OR IGNORE INTO repo_iterations(repo_id, iteration) + UPDATE + repos + SET + next_iteration = ? + WHERE + repo_id = ?; + ''', + (next_iteration + 1, repo_id) + ) + + cursor.execute( + ''' + INSERT INTO repo_iterations(repo_id, iteration) VALUES(?, ?); ''', - (repo_id, last_iteration) + (repo_id, next_iteration) ) cursor.execute( @@ -79,7 +91,7 @@ def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int: WHERE repo_id = ? AND iteration = ?; ''', - (repo_id, last_iteration) + (repo_id, next_iteration) ) (repo_iteration_id,), = cursor.fetchall() @@ -118,7 +130,7 @@ def get_or_make_item_version( item_id: int, repo_iteration_id: int, version: versions.VerTuple, - definition: str + definition: bytes ) -> int: ver_str = versions.version_string(version) @@ -154,8 +166,8 @@ def get_or_make_item_version( def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None: cursor.execute( ''' - INSERT OR IGNORE INTO mapping_statuses(item_id, enabled, frozen) - VALUES(?, 'E', 'R'); + INSERT OR IGNORE INTO mapping_statuses(item_id, enabled, required) + VALUES(?, 'N', FALSE); ''', (item_id,) ) @@ -215,14 +227,18 @@ class _FileInfo: id: int is_ascii: bool +class FileResolver(ABC): + @abstractmethod + def by_sha256(self, sha256: str) -> bytes: + ... + def _add_item( - cursor: sqlite3.Cursor, - files_by_sha256_path: Path, - info: item_infos.AnyInfo, - definition: str + cursor: sqlite3.Cursor, + package_file_resolver: FileResolver, + info: item_infos.AnyInfo, + definition: bytes, + repo_iteration_id: int ) -> None: - repo_iteration_id = get_or_make_repo_iteration(cursor, '<local>') - item_id = get_or_make_item(cursor, info.type_name, info.identifier) item_version_id = get_or_make_item_version( @@ -243,17 +259,7 @@ def _add_item( file_specifiers.extend(info.scripts) for file_spec in file_specifiers: - file_path = files_by_sha256_path / file_spec.sha256 - if not file_path.is_file(): - fmt = _('err.proxy.file_missing_{item_identifier}_{file_name}_{sha256}') - msg = fmt.format( - item_identifier = info.identifier, - file_name = file_spec.name, - sha256 = file_spec.sha256 - ) - raise HaketiloException(msg) - - file_bytes = file_path.read_bytes() + file_bytes = package_file_resolver.by_sha256(file_spec.sha256) sha256 = hashlib.sha256(file_bytes).digest().hex() if sha256 != file_spec.sha256: @@ -309,7 +315,7 @@ AnyInfoVar = t.TypeVar( ) def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \ - -> t.Iterator[tuple[AnyInfoVar, str]]: + -> t.Iterator[tuple[AnyInfoVar, bytes]]: item_type_path = malcontent_path / item_class.type_name if not item_type_path.is_dir(): return @@ -319,8 +325,8 @@ def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \ continue for item_version_path in item_path.iterdir(): - definition = item_version_path.read_text() - item_info = item_class.load(io.StringIO(definition)) + definition = item_version_path.read_bytes() + item_info = item_class.load(definition) assert item_info.identifier == item_path.name assert versions.version_string(item_info.version) == \ @@ -328,17 +334,45 @@ def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \ yield item_info, definition +@dc.dataclass(frozen=True) +class MalcontentFileResolver(FileResolver): + malcontent_dir_path: Path + + def by_sha256(self, sha256: str) -> bytes: + file_path = self.malcontent_dir_path / 'file' / 'sha256' / sha256 + if not file_path.is_file(): + fmt = _('err.proxy.file_missing_{sha256}') + raise HaketiloException(fmt.format(sha256=sha256)) + + return file_path.read_bytes() + def load_packages( - state: base.HaketiloStateWithFields, - cursor: sqlite3.Cursor, - malcontent_path: Path -) -> None: - files_by_sha256_path = malcontent_path / 'file' / 'sha256' + cursor: sqlite3.Cursor, + malcontent_path: Path, + repo_id: int, + package_file_resolver: t.Optional[FileResolver] = None +) -> int: + if package_file_resolver is None: + package_file_resolver = MalcontentFileResolver(malcontent_path) - for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]: + repo_iteration_id = make_repo_iteration(cursor, repo_id) + + types: t.Iterable[t.Type[item_infos.AnyInfo]] = \ + [item_infos.ResourceInfo, item_infos.MappingInfo] + + for info_type in types: info: item_infos.AnyInfo - for info, definition in _read_items( + + for info, definition in _read_items( # type: ignore malcontent_path, - info_type # type: ignore + info_type ): - _add_item(cursor, files_by_sha256_path, info, definition) + _add_item( + cursor, + package_file_resolver, + info, + definition, + repo_iteration_id + ) + + return repo_iteration_id diff --git a/src/hydrilla/proxy/state_impl/prune_packages.py b/src/hydrilla/proxy/state_impl/_operations/prune_packages.py index 1857188..9c2b1d7 100644 --- a/src/hydrilla/proxy/state_impl/prune_packages.py +++ b/src/hydrilla/proxy/state_impl/_operations/prune_packages.py @@ -33,27 +33,42 @@ from __future__ import annotations import sqlite3 +from pathlib import Path -_remove_mapping_versions_sql = ''' -WITH removed_mappings AS ( - SELECT - iv.item_version_id - FROM - item_versions AS iv - JOIN items AS i - USING (item_id) - JOIN orphan_iterations AS oi - USING (repo_iteration_id) - LEFT JOIN payloads AS p - ON p.mapping_item_id = iv.item_version_id - WHERE - i.type = 'M' AND p.payload_id IS NULL -) -DELETE FROM - item_versions -WHERE - item_version_id IN removed_mappings; -''' + +_remove_mapping_versions_sqls = [ + ''' + CREATE TEMPORARY TABLE removed_mappings( + item_version_id INTEGER PRIMARY KEY + ); + ''', ''' + INSERT INTO + removed_mappings + SELECT + iv.item_version_id + FROM + item_versions AS iv + JOIN items AS i USING (item_id) + JOIN mapping_statuses AS ms USING (item_id) + JOIN orphan_iterations AS oi USING (repo_iteration_id) + WHERE + NOT ms.required; + ''', ''' + UPDATE + mapping_statuses + SET + active_version_id = NULL + WHERE + active_version_id IN removed_mappings; + ''', ''' + DELETE FROM + item_versions + WHERE + item_version_id IN removed_mappings; + ''', ''' + DROP TABLE removed_mappings; + ''' +] _remove_resource_versions_sql = ''' WITH removed_resources AS ( @@ -134,7 +149,7 @@ WITH removed_repos AS ( repos AS r LEFT JOIN repo_iterations AS ri USING (repo_id) WHERE - r.deleted AND ri.repo_iteration_id IS NULL + r.deleted AND ri.repo_iteration_id IS NULL AND r.repo_id != 1 ) DELETE FROM repos @@ -142,9 +157,11 @@ WHERE repo_id IN removed_repos; ''' -def prune(cursor: sqlite3.Cursor) -> None: - """....""" - cursor.execute(_remove_mapping_versions_sql) +def prune_packages(cursor: sqlite3.Cursor) -> None: + assert cursor.connection.in_transaction + + for sql in _remove_mapping_versions_sqls: + cursor.execute(sql) cursor.execute(_remove_resource_versions_sql) cursor.execute(_remove_items_sql) cursor.execute(_remove_files_sql) diff --git a/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py b/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py new file mode 100644 index 0000000..4093f12 --- /dev/null +++ b/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Haketilo proxy data and configuration (update of dependency tree in the db). +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +.... +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import typing as t + +import sqlite3 + +from .... import item_infos +from ... import simple_dependency_satisfying as sds + + +AnyInfoVar = t.TypeVar( + 'AnyInfoVar', + item_infos.ResourceInfo, + item_infos.MappingInfo +) + +def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \ + -> t.Mapping[int, AnyInfoVar]: + cursor.execute( + ''' + SELECT + i.item_id, iv.definition, r.name, ri.iteration + FROM + item_versions AS iv + JOIN items AS i USING (item_id) + JOIN repo_iterations AS ri USING (repo_iteration_id) + JOIN repos AS r USING (repo_id) + WHERE + i.type = ?; + ''', + (info_type.type_name[0].upper(),) + ) + + result: dict[int, AnyInfoVar] = {} + + for item_id, definition, repo_name, repo_iteration in cursor.fetchall(): + info = info_type.load(definition, repo_name, repo_iteration) + result[item_id] = info + + return result + +def _recompute_dependencies_no_state_update( + cursor: sqlite3.Cursor, + extra_requirements: t.Iterable[sds.MappingRequirement] +) -> None: + cursor.execute('DELETE FROM payloads;') + + ids_to_resources = get_infos_of_type(cursor, item_infos.ResourceInfo) + ids_to_mappings = get_infos_of_type(cursor, item_infos.MappingInfo) + + resources = ids_to_resources.items() + resources_to_ids = dict((info.identifier, id) for id, info in resources) + + mappings = ids_to_mappings.items() + mappings_to_ids = dict((info.identifier, id) for id, info in mappings) + + requirements = [*extra_requirements] + + cursor.execute( + ''' + SELECT + i.identifier + FROM + mapping_statuses AS ms + JOIN items AS i USING(item_id) + WHERE + ms.enabled = 'E' AND ms.frozen = 'N'; + ''' + ) + + for mapping_identifier, in cursor.fetchall(): + requirements.append(sds.MappingRequirement(mapping_identifier)) + + cursor.execute( + ''' + SELECT + active_version_id, frozen + FROM + mapping_statuses + WHERE + enabled = 'E' AND frozen IN ('R', 'E'); + ''' + ) + + for active_version_id, frozen in cursor.fetchall(): + info = ids_to_mappings[active_version_id] + + requirement: sds.MappingRequirement + + if frozen == 'R': + requirement = sds.MappingRepoRequirement(info.identifier, info.repo) + else: + requirement = sds.MappingVersionRequirement(info.identifier, info) + + requirements.append(requirement) + + mapping_choices = sds.compute_payloads( + ids_to_resources.values(), + ids_to_mappings.values(), + requirements + ) + + cursor.execute( + ''' + UPDATE + mapping_statuses + SET + required = FALSE, + active_version_id = NULL + WHERE + enabled != 'E'; + ''' + ) + + cursor.execute('DELETE FROM payloads;') + + for choice in mapping_choices.values(): + mapping_ver_id = mappings_to_ids[choice.info.identifier] + + cursor.execute( + ''' + SELECT + item_id + FROM + item_versions + WHERE + item_version_id = ?; + ''', + (mapping_ver_id,) + ) + + (mapping_item_id,), = cursor.fetchall() + + cursor.execute( + ''' + UPDATE + mapping_statuses + SET + required = ?, + active_version_id = ? + WHERE + item_id = ?; + ''', + (choice.required, mapping_ver_id, mapping_item_id) + ) + + for num, (pattern, payload) in enumerate(choice.payloads.items()): + cursor.execute( + ''' + INSERT INTO payloads( + mapping_item_id, + pattern, + eval_allowed, + cors_bypass_allowed + ) + VALUES (?, ?, ?, ?); + ''', + ( + mapping_ver_id, + pattern.orig_url, + payload.allows_eval, + payload.allows_cors_bypass + ) + ) + + cursor.execute( + ''' + SELECT + payload_id + FROM + payloads + WHERE + mapping_item_id = ? AND pattern = ?; + ''', + (mapping_ver_id, pattern.orig_url) + ) + + (payload_id,), = cursor.fetchall() + + for res_num, resource_info in enumerate(payload.resources): + resource_ver_id = resources_to_ids[resource_info.identifier] + cursor.execute( + ''' + INSERT INTO resolved_depended_resources( + payload_id, + resource_item_id, + idx + ) + VALUES(?, ?, ?); + ''', + (payload_id, resource_ver_id, res_num) + ) diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py index 1ae08cb..92833dd 100644 --- a/src/hydrilla/proxy/state_impl/base.py +++ b/src/hydrilla/proxy/state_impl/base.py @@ -34,34 +34,160 @@ subtype. from __future__ import annotations import sqlite3 +import secrets import threading import dataclasses as dc import typing as t from pathlib import Path from contextlib import contextmanager - -import sqlite3 +from abc import abstractmethod from immutables import Map +from ... import url_patterns from ... import pattern_tree -from .. import state +from .. import simple_dependency_satisfying as sds +from .. import state as st from .. import policies PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory] -PayloadsData = t.Mapping[state.PayloadRef, state.PayloadData] +PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData] + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcretePayloadRef(st.PayloadRef): + state: 'HaketiloStateWithFields' = dc.field(hash=False, compare=False) + + def get_data(self) -> st.PayloadData: + try: + return self.state.payloads_data[self] + except KeyError: + raise st.MissingItemError() + + def get_mapping(self) -> st.MappingVersionRef: + raise NotImplementedError() + + def get_script_paths(self) \ + -> t.Iterable[t.Sequence[str]]: + with self.state.cursor() as cursor: + cursor.execute( + ''' + SELECT + i.identifier, fu.name + FROM + payloads AS p + LEFT JOIN resolved_depended_resources AS rdd + USING (payload_id) + LEFT JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + LEFT JOIN items AS i + USING (item_id) + LEFT JOIN file_uses AS fu + USING (item_version_id) + WHERE + fu.type = 'W' AND + p.payload_id = ? AND + (fu.idx IS NOT NULL OR rdd.idx IS NULL) + ORDER BY + rdd.idx, fu.idx; + ''', + (self.id,) + ) + + paths: list[t.Sequence[str]] = [] + for resource_identifier, file_name in cursor.fetchall(): + if resource_identifier is None: + # payload found but it had no script files + return () + + paths.append((resource_identifier, *file_name.split('/'))) + + if paths == []: + # payload not found + raise st.MissingItemError() + + return paths + + def get_file_data(self, path: t.Sequence[str]) \ + -> t.Optional[st.FileData]: + if len(path) == 0: + raise st.MissingItemError() + + resource_identifier, *file_name_segments = path + + file_name = '/'.join(file_name_segments) + + with self.state.cursor() as cursor: + cursor.execute( + ''' + SELECT + f.data, fu.mime_type + FROM + payloads AS p + JOIN resolved_depended_resources AS rdd + USING (payload_id) + JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + JOIN items AS i + USING (item_id) + JOIN file_uses AS fu + USING (item_version_id) + JOIN files AS f + USING (file_id) + WHERE + p.payload_id = ? AND + i.identifier = ? AND + fu.name = ? AND + fu.type = 'W'; + ''', + (self.id, resource_identifier, file_name) + ) + + result = cursor.fetchall() + + if result == []: + return None + + (data, mime_type), = result + + return st.FileData(type=mime_type, name=file_name, contents=data) + +def register_payload( + policy_tree: PolicyTree, + pattern: url_patterns.ParsedPattern, + payload_key: st.PayloadKey, + token: str +) -> PolicyTree: + """....""" + payload_policy_factory = policies.PayloadPolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register(pattern, payload_policy_factory) + + resource_policy_factory = policies.PayloadResourcePolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register( + pattern.path_append(token, '***'), + resource_policy_factory + ) + + return policy_tree # mypy needs to be corrected: # https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 @dc.dataclass # type: ignore[misc] -class HaketiloStateWithFields(state.HaketiloState): +class HaketiloStateWithFields(st.HaketiloState): """....""" store_dir: Path connection: sqlite3.Connection current_cursor: t.Optional[sqlite3.Cursor] = None - #settings: state.HaketiloGlobalSettings + #settings: st.HaketiloGlobalSettings policy_tree: PolicyTree = PolicyTree() payloads_data: PayloadsData = dc.field(default_factory=dict) @@ -98,6 +224,85 @@ class HaketiloStateWithFields(state.HaketiloState): finally: self.current_cursor = None - def recompute_payloads(self, cursor: sqlite3.Cursor) -> None: + def rebuild_structures(self) -> None: + """ + Recreation of data structures as done after every recomputation of + dependencies as well as at startup. + """ + with self.cursor(transaction=True) as cursor: + cursor.execute( + ''' + SELECT + p.payload_id, p.pattern, p.eval_allowed, + p.cors_bypass_allowed, + ms.enabled, + i.identifier + FROM + payloads AS p + JOIN item_versions AS iv + ON p.mapping_item_id = iv.item_version_id + JOIN items AS i USING (item_id) + JOIN mapping_statuses AS ms USING (item_id); + ''' + ) + + rows = cursor.fetchall() + + new_policy_tree = PolicyTree() + + ui_factory = policies.WebUIPolicyFactory(builtin=True) + web_ui_pattern = 'http*://hkt.mitm.it/***' + for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern): + new_policy_tree = new_policy_tree.register( + parsed_pattern, + ui_factory + ) + + new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} + + for row in rows: + (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, + enabled_status, + identifier) = row + + payload_ref = ConcretePayloadRef(str(payload_id_int), self) + + previous_data = self.payloads_data.get(payload_ref) + if previous_data is not None: + token = previous_data.unique_token + else: + token = secrets.token_urlsafe(8) + + payload_key = st.PayloadKey(payload_ref, identifier) + + for parsed_pattern in url_patterns.parse_pattern(pattern): + new_policy_tree = register_payload( + new_policy_tree, + parsed_pattern, + payload_key, + token + ) + + pattern_path_segments = parsed_pattern.path_segments + + payload_data = st.PayloadData( + payload_ref = payload_ref, + explicitly_enabled = enabled_status == 'E', + unique_token = token, + pattern_path_segments = pattern_path_segments, + eval_allowed = eval_allowed, + cors_bypass_allowed = cors_bypass_allowed + ) + + new_payloads_data[payload_ref] = payload_data + + self.policy_tree = new_policy_tree + self.payloads_data = new_payloads_data + + @abstractmethod + def recompute_dependencies( + self, + requirements: t.Iterable[sds.MappingRequirement] = [] + ) -> None: """....""" ... diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py index 525a702..6a59a75 100644 --- a/src/hydrilla/proxy/state_impl/concrete_state.py +++ b/src/hydrilla/proxy/state_impl/concrete_state.py @@ -32,28 +32,23 @@ and resources. # Enable using with Python 3.7. from __future__ import annotations -import secrets -import io -import hashlib +import sqlite3 import typing as t import dataclasses as dc from pathlib import Path -import sqlite3 - from ...exceptions import HaketiloException from ...translations import smart_gettext as _ -from ... import pattern_tree from ... import url_patterns from ... import item_infos -from ..simple_dependency_satisfying import compute_payloads, ComputedPayload from .. import state as st from .. import policies +from .. import simple_dependency_satisfying as sds from . import base from . import mappings from . import repos -from .load_packages import load_packages +from . import _operations here = Path(__file__).resolve().parent @@ -73,169 +68,14 @@ class ConcreteResourceRef(st.ResourceRef): class ConcreteResourceVersionRef(st.ResourceVersionRef): pass - -@dc.dataclass(frozen=True, unsafe_hash=True) -class ConcretePayloadRef(st.PayloadRef): - state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False) - - def get_data(self) -> st.PayloadData: - try: - return self.state.payloads_data[self] - except KeyError: - raise st.MissingItemError() - - def get_mapping(self) -> st.MappingVersionRef: - raise NotImplementedError() - - def get_script_paths(self) \ - -> t.Iterable[t.Sequence[str]]: - with self.state.cursor() as cursor: - cursor.execute( - ''' - SELECT - i.identifier, fu.name - FROM - payloads AS p - LEFT JOIN resolved_depended_resources AS rdd - USING (payload_id) - LEFT JOIN item_versions AS iv - ON rdd.resource_item_id = iv.item_version_id - LEFT JOIN items AS i - USING (item_id) - LEFT JOIN file_uses AS fu - USING (item_version_id) - WHERE - fu.type = 'W' AND - p.payload_id = ? AND - (fu.idx IS NOT NULL OR rdd.idx IS NULL) - ORDER BY - rdd.idx, fu.idx; - ''', - (self.id,) - ) - - paths: list[t.Sequence[str]] = [] - for resource_identifier, file_name in cursor.fetchall(): - if resource_identifier is None: - # payload found but it had no script files - return () - - paths.append((resource_identifier, *file_name.split('/'))) - - if paths == []: - # payload not found - raise st.MissingItemError() - - return paths - - def get_file_data(self, path: t.Sequence[str]) \ - -> t.Optional[st.FileData]: - if len(path) == 0: - raise st.MissingItemError() - - resource_identifier, *file_name_segments = path - - file_name = '/'.join(file_name_segments) - - with self.state.cursor() as cursor: - cursor.execute( - ''' - SELECT - f.data, fu.mime_type - FROM - payloads AS p - JOIN resolved_depended_resources AS rdd - USING (payload_id) - JOIN item_versions AS iv - ON rdd.resource_item_id = iv.item_version_id - JOIN items AS i - USING (item_id) - JOIN file_uses AS fu - USING (item_version_id) - JOIN files AS f - USING (file_id) - WHERE - p.payload_id = ? AND - i.identifier = ? AND - fu.name = ? AND - fu.type = 'W'; - ''', - (self.id, resource_identifier, file_name) - ) - - result = cursor.fetchall() - - if result == []: - return None - - (data, mime_type), = result - - return st.FileData(type=mime_type, name=file_name, contents=data) - -def register_payload( - policy_tree: base.PolicyTree, - pattern: url_patterns.ParsedPattern, - payload_key: st.PayloadKey, - token: str -) -> base.PolicyTree: - """....""" - payload_policy_factory = policies.PayloadPolicyFactory( - builtin = False, - payload_key = payload_key - ) - - policy_tree = policy_tree.register(pattern, payload_policy_factory) - - resource_policy_factory = policies.PayloadResourcePolicyFactory( - builtin = False, - payload_key = payload_key - ) - - policy_tree = policy_tree.register( - pattern.path_append(token, '***'), - resource_policy_factory - ) - - return policy_tree - -AnyInfoVar = t.TypeVar( - 'AnyInfoVar', - item_infos.ResourceInfo, - item_infos.MappingInfo -) - -def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \ - -> t.Mapping[AnyInfoVar, int]: - cursor.execute( - ''' - SELECT - i.item_id, iv.definition, r.name, ri.iteration - FROM - item_versions AS iv - JOIN items AS i USING (item_id) - JOIN repo_iterations AS ri USING (repo_iteration_id) - JOIN repos AS r USING (repo_id) - WHERE - i.type = ?; - ''', - (info_type.type_name[0].upper(),) - ) - - result: dict[AnyInfoVar, int] = {} - - for item_id, definition, repo_name, repo_iteration in cursor.fetchall(): - definition_io = io.StringIO(definition) - info = info_type.load(definition_io, repo_name, repo_iteration) - result[info] = item_id - - return result - @dc.dataclass class ConcreteHaketiloState(base.HaketiloStateWithFields): def __post_init__(self) -> None: + sqlite3.enable_callback_tracebacks(True) + self._prepare_database() - self._rebuild_structures() + self.rebuild_structures() def _prepare_database(self) -> None: """....""" @@ -282,147 +122,25 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): def import_packages(self, malcontent_path: Path) -> None: with self.cursor(transaction=True) as cursor: - load_packages(self, cursor, malcontent_path) - self.recompute_payloads(cursor) - - def recompute_payloads(self, cursor: sqlite3.Cursor) -> None: - assert self.connection.in_transaction - - resources = get_infos_of_type(cursor, item_infos.ResourceInfo) - mappings = get_infos_of_type(cursor, item_infos.MappingInfo) - - payloads = compute_payloads(resources.keys(), mappings.keys()) - - payloads_data: dict[st.PayloadRef, st.PayloadData] = {} - - cursor.execute('DELETE FROM payloads;') - - for mapping_info, by_pattern in payloads.items(): - for num, (pattern, payload) in enumerate(by_pattern.items()): - cursor.execute( - ''' - INSERT INTO payloads( - mapping_item_id, - pattern, - eval_allowed, - cors_bypass_allowed - ) - VALUES (?, ?, ?, ?); - ''', - ( - mappings[mapping_info], - pattern.orig_url, - payload.allows_eval, - payload.allows_cors_bypass - ) - ) + _operations.load_packages(cursor, malcontent_path, 1) + raise NotImplementedError() + _operations.prune_packages(cursor) - cursor.execute( - ''' - SELECT - payload_id - FROM - payloads - WHERE - mapping_item_id = ? AND pattern = ?; - ''', - (mappings[mapping_info], pattern.orig_url) - ) + self.recompute_dependencies() - (payload_id_int,), = cursor.fetchall() - - for res_num, resource_info in enumerate(payload.resources): - cursor.execute( - ''' - INSERT INTO resolved_depended_resources( - payload_id, - resource_item_id, - idx - ) - VALUES(?, ?, ?); - ''', - (payload_id_int, resources[resource_info], res_num) - ) - - self._rebuild_structures(cursor) - - def _rebuild_structures(self, cursor: t.Optional[sqlite3.Cursor] = None) \ - -> None: - """ - Recreation of data structures as done after every recomputation of - dependencies as well as at startup. - """ - if cursor is None: - with self.cursor() as new_cursor: - return self._rebuild_structures(new_cursor) - - cursor.execute( - ''' - SELECT - p.payload_id, p.pattern, p.eval_allowed, - p.cors_bypass_allowed, - ms.enabled, - i.identifier - FROM - payloads AS p - JOIN item_versions AS iv - ON p.mapping_item_id = iv.item_version_id - JOIN items AS i USING (item_id) - JOIN mapping_statuses AS ms USING (item_id); - ''' - ) - - new_policy_tree = base.PolicyTree() - - ui_factory = policies.WebUIPolicyFactory(builtin=True) - web_ui_pattern = 'http*://hkt.mitm.it/***' - for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern): - new_policy_tree = new_policy_tree.register( - parsed_pattern, - ui_factory - ) - - new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} - - for row in cursor.fetchall(): - (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, - enabled_status, - identifier) = row - - payload_ref = ConcretePayloadRef(str(payload_id_int), self) - - previous_data = self.payloads_data.get(payload_ref) - if previous_data is not None: - token = previous_data.unique_token - else: - token = secrets.token_urlsafe(8) - - payload_key = st.PayloadKey(payload_ref, identifier) - - for parsed_pattern in url_patterns.parse_pattern(pattern): - new_policy_tree = register_payload( - new_policy_tree, - parsed_pattern, - payload_key, - token - ) - - pattern_path_segments = parsed_pattern.path_segments + def recompute_dependencies( + self, + extra_requirements: t.Iterable[sds.MappingRequirement] = [] + ) -> None: + with self.cursor() as cursor: + assert self.connection.in_transaction - payload_data = st.PayloadData( - payload_ref = payload_ref, - #explicitly_enabled = enabled_status == 'E', - explicitly_enabled = True, - unique_token = token, - pattern_path_segments = pattern_path_segments, - eval_allowed = eval_allowed, - cors_bypass_allowed = cors_bypass_allowed + _operations._recompute_dependencies_no_state_update( + cursor, + extra_requirements ) - new_payloads_data[payload_ref] = payload_data - - self.policy_tree = new_policy_tree - self.payloads_data = new_payloads_data + self.rebuild_structures() def repo_store(self) -> st.RepoStore: return repos.ConcreteRepoStore(self) diff --git a/src/hydrilla/proxy/state_impl/mappings.py b/src/hydrilla/proxy/state_impl/mappings.py index 3668784..cce2a36 100644 --- a/src/hydrilla/proxy/state_impl/mappings.py +++ b/src/hydrilla/proxy/state_impl/mappings.py @@ -76,7 +76,7 @@ class ConcreteMappingVersionRef(st.MappingVersionRef): (status_letter, definition, repo, repo_iteration, is_orphan), = rows item_info = item_infos.MappingInfo.load( - io.StringIO(definition), + definition, repo, repo_iteration ) @@ -120,7 +120,7 @@ class ConcreteMappingVersionStore(st.MappingVersionStore): ref = ConcreteMappingVersionRef(str(item_version_id), self.state) item_info = item_infos.MappingInfo.load( - io.StringIO(definition), + definition, repo, repo_iteration ) diff --git a/src/hydrilla/proxy/state_impl/repos.py b/src/hydrilla/proxy/state_impl/repos.py index 5553ec2..f4c7c71 100644 --- a/src/hydrilla/proxy/state_impl/repos.py +++ b/src/hydrilla/proxy/state_impl/repos.py @@ -33,20 +33,27 @@ inside Haketilo. from __future__ import annotations import re +import json +import tempfile +import requests +import sqlite3 import typing as t import dataclasses as dc -from urllib.parse import urlparse +from urllib.parse import urlparse, urljoin from datetime import datetime +from pathlib import Path -import sqlite3 - +from ... import json_instances +from ... import item_infos +from ... import versions from .. import state as st +from .. import simple_dependency_satisfying as sds from . import base -from . import prune_packages +from . import _operations -def validate_repo_url(url: str) -> None: +def sanitize_repo_url(url: str) -> str: try: parsed = urlparse(url) except: @@ -55,6 +62,11 @@ def validate_repo_url(url: str) -> None: if parsed.scheme not in ('http', 'https'): raise st.RepoUrlInvalid() + if url[-1] != '/': + url = url + '/' + + return url + def ensure_repo_not_deleted(cursor: sqlite3.Cursor, repo_id: str) -> None: cursor.execute( @@ -73,6 +85,53 @@ def ensure_repo_not_deleted(cursor: sqlite3.Cursor, repo_id: str) -> None: raise st.MissingItemError() +def sync_remote_repo_definitions(repo_url: str, dest: Path) -> None: + try: + list_all_response = requests.get(urljoin(repo_url, 'list_all')) + assert list_all_response.ok + + list_instance = list_all_response.json() + except: + raise st.RepoCommunicationError() + + try: + json_instances.validate_instance( + list_instance, + 'api_package_list-{}.schema.json' + ) + except json_instances.UnknownSchemaError: + raise st.RepoApiVersionUnsupported() + except: + raise st.RepoCommunicationError() + + ref: dict[str, t.Any] + + for item_type_name in ('resource', 'mapping'): + for ref in list_instance[item_type_name + 's']: + ver = versions.version_string(versions.normalize(ref['version'])) + item_rel_path = f'{item_type_name}/{ref["identifier"]}/{ver}' + + try: + item_response = requests.get(urljoin(repo_url, item_rel_path)) + assert item_response.ok + except: + raise st.RepoCommunicationError() + + item_path = dest / item_rel_path + item_path.parent.mkdir(parents=True, exist_ok=True) + item_path.write_bytes(item_response.content) + + +@dc.dataclass(frozen=True) +class RemoteFileResolver(_operations.FileResolver): + repo_url: str + + def by_sha256(self, sha256: str) -> bytes: + response = requests.get(urljoin(self.repo_url, f'file/sha256/{sha256}')) + assert response.ok + return response.content + + def make_repo_display_info( ref: st.RepoRef, name: str, @@ -122,6 +181,37 @@ class ConcreteRepoRef(st.RepoRef): (self.id,) ) + _operations.prune_packages(cursor) + + # For mappings explicitly enabled by the user (+ all mappings they + # recursively depend on) let's make sure that their exact same + # versions will be enabled after the change. + cursor.execute( + ''' + SELECT + iv.definition, r.name, ri.iteration + FROM + mapping_statuses AS ms + JOIN item_versions AS iv + ON ms.active_version_id = iv.item_version_id + JOIN repo_iterations AS ri + USING (repo_iteration_id) + JOIN repos AS r + USING (repo_id) + WHERE + ms.required + ''' + ) + + requirements = [] + + for definition, repo, iteration in cursor.fetchall(): + info = item_infos.MappingInfo.load(definition, repo, iteration) + req = sds.MappingVersionRequirement(info.identifier, info) + requirements.append(req) + + self.state.recompute_dependencies(requirements) + def update( self, *, @@ -134,7 +224,7 @@ class ConcreteRepoRef(st.RepoRef): if url is None: return - validate_repo_url(url) + url = sanitize_repo_url(url) with self.state.cursor(transaction=True) as cursor: ensure_repo_not_deleted(cursor, self.id) @@ -144,10 +234,30 @@ class ConcreteRepoRef(st.RepoRef): (url, self.id) ) - prune_packages.prune(cursor) - def refresh(self) -> st.RepoIterationRef: - raise NotImplementedError() + with self.state.cursor(transaction=True) as cursor: + ensure_repo_not_deleted(cursor, self.id) + + cursor.execute( + 'SELECT url from repos where repo_id = ?;', + (self.id,) + ) + + (repo_url,), = cursor.fetchall() + + with tempfile.TemporaryDirectory() as tmpdir_str: + tmpdir = Path(tmpdir_str) + sync_remote_repo_definitions(repo_url, tmpdir) + new_iteration_id = _operations.load_packages( + cursor, + tmpdir, + int(self.id), + RemoteFileResolver(repo_url) + ) + + self.state.recompute_dependencies() + + return ConcreteRepoIterationRef(str(new_iteration_id), self.state) def get_display_info(self) -> st.RepoDisplayInfo: with self.state.cursor() as cursor: @@ -199,7 +309,7 @@ class ConcreteRepoStore(st.RepoStore): if repo_name_regex.match(name) is None: raise st.RepoNameInvalid() - validate_repo_url(url) + url = sanitize_repo_url(url) with self.state.cursor(transaction=True) as cursor: cursor.execute( @@ -272,3 +382,8 @@ class ConcreteRepoStore(st.RepoStore): result.append(make_repo_display_info(ref, *rest)) return result + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteRepoIterationRef(st.RepoIterationRef): + state: base.HaketiloStateWithFields diff --git a/src/hydrilla/proxy/tables.sql b/src/hydrilla/proxy/tables.sql index fc7c65c..0c604fc 100644 --- a/src/hydrilla/proxy/tables.sql +++ b/src/hydrilla/proxy/tables.sql @@ -132,7 +132,9 @@ CREATE TABLE items( ); CREATE TABLE mapping_statuses( - -- The item with this id shall be a mapping ("type" = 'M'). + -- The item with this id shall be a mapping ("type" = 'M'). For each + -- mapping row in "items" there must be an accompanying row in this + -- table. item_id INTEGER PRIMARY KEY, -- "enabled" determines whether mapping's status is ENABLED, @@ -142,19 +144,32 @@ CREATE TABLE mapping_statuses( -- EXACT_VERSION, is to be updated only with versions from the same -- REPOSITORY or is NOT_FROZEN at all. frozen CHAR(1) NULL, + -- The last 2 fields defined below shall be updated when dependency tree + -- is recomputed. + -- When "required" is TRUE, the mapping is assumed to either be enabled + -- or be (directly or indirectly) required by another mapping which is + -- enabled (i.e. has "enabled" set to 'E'). + required BOOLEAN NOT NULL, + -- Only one version of a mapping is allowed to be active at any time. + -- "active_version_id" indicates which version it is. Only a mapping + -- version referenced by "active_version_id" is allowed to have rows + -- in the "payloads" table reference it. active_version_id INTEGER NULL, CHECK (enabled IN ('E', 'D', 'N')), CHECK ((frozen IS NULL) = (enabled != 'E')), CHECK (frozen IS NULL OR frozen in ('E', 'R', 'N')), - CHECK (enabled != 'E' OR active_version_id IS NOT NULL) - CHECK (enabled != 'D' OR active_version_id IS NULL) + CHECK (enabled != 'E' OR required), + CHECK (enabled != 'D' OR NOT required), + CHECK (not required OR active_version_id IS NOT NULL), + CHECK (enabled != 'D' OR active_version_id IS NULL), FOREIGN KEY (item_id) REFERENCES items (item_id) ON DELETE CASCADE, FOREIGN KEY (active_version_id, item_id) REFERENCES item_versions (item_version_id, item_id) + ON DELETE SET NULL ); CREATE TABLE item_versions( @@ -163,7 +178,7 @@ CREATE TABLE item_versions( item_id INTEGER NOT NULL, version VARCHAR NOT NULL, repo_iteration_id INTEGER NOT NULL, - definition TEXT NOT NULL, + definition BLOB NOT NULL, UNIQUE (item_id, version, repo_iteration_id), -- Constraint below needed to allow foreign key from "mapping_statuses". @@ -189,8 +204,10 @@ FROM GROUP BY r.repo_id, r.name, r.url, r.deleted, r.last_refreshed; --- Every time a repository gets refreshed, or a mapping gets enabled/disabled, --- all dependencies the "payloads" table and those that reference it are +-- Every time a repository gets refreshed or a mapping gets enabled/disabled, +-- the dependency tree is recomputed. In the process the "payloads" table gets +-- cleare and repopulated together with the "resolved_depended_resources" that +-- depends on it. CREATE TABLE payloads( payload_id INTEGER PRIMARY KEY, diff --git a/src/hydrilla/server/malcontent.py b/src/hydrilla/server/malcontent.py index 49c0fb4..ce24330 100644 --- a/src/hydrilla/server/malcontent.py +++ b/src/hydrilla/server/malcontent.py @@ -79,7 +79,7 @@ class VersionedItemInfo( Find and return info of the specified version of the item (or None if absent). """ - return self.items.get(versions.normalize_version(ver)) + return self.items.get(versions.normalize(ver)) def get_all(self) -> t.Iterable[VersionedType]: """Generate item info for all its versions, from oldest to newest.""" diff --git a/src/hydrilla/versions.py b/src/hydrilla/versions.py index 93f395d..12d1c18 100644 --- a/src/hydrilla/versions.py +++ b/src/hydrilla/versions.py @@ -36,7 +36,7 @@ import typing as t VerTuple = t.NewType('VerTuple', 'tuple[int, ...]') -def normalize_version(ver: t.Sequence[int]) -> VerTuple: +def normalize(ver: t.Sequence[int]) -> VerTuple: """Strip rightmost zeroes from 'ver'.""" new_len = 0 for i, num in enumerate(ver): @@ -57,7 +57,7 @@ def parse_normalize(ver_str: str) -> VerTuple: Convert 'ver_str' into a VerTuple representation, e.g. for ver_str="4.6.13.0" return (4, 6, 13). """ - return normalize_version(parse(ver_str)) + return normalize(parse(ver_str)) def version_string(ver: VerTuple, rev: t.Optional[int] = None) -> str: """ |