aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla
diff options
context:
space:
mode:
authorWojtek Kosior <koszko@koszko.org>2022-08-22 12:52:59 +0200
committerWojtek Kosior <koszko@koszko.org>2022-09-28 12:54:51 +0200
commit8238435825d01ad2ec1a11b6bcaf6d9a9aad5ab5 (patch)
tree4c4956e45701460bedaa0d8b0be808052152777f /src/hydrilla
parente1344ae7017b28a54d7714895bd54c8431a20bc6 (diff)
downloadhaketilo-hydrilla-8238435825d01ad2ec1a11b6bcaf6d9a9aad5ab5.tar.gz
haketilo-hydrilla-8238435825d01ad2ec1a11b6bcaf6d9a9aad5ab5.zip
allow pulling packages from remote repository
Diffstat (limited to 'src/hydrilla')
-rw-r--r--src/hydrilla/item_infos.py31
-rw-r--r--src/hydrilla/json_instances.py34
-rw-r--r--src/hydrilla/proxy/simple_dependency_satisfying.py177
-rw-r--r--src/hydrilla/proxy/state.py6
-rw-r--r--src/hydrilla/proxy/state_impl/_operations/__init__.py9
-rw-r--r--src/hydrilla/proxy/state_impl/_operations/load_packages.py (renamed from src/hydrilla/proxy/state_impl/load_packages.py)124
-rw-r--r--src/hydrilla/proxy/state_impl/_operations/prune_packages.py (renamed from src/hydrilla/proxy/state_impl/prune_packages.py)65
-rw-r--r--src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py223
-rw-r--r--src/hydrilla/proxy/state_impl/base.py219
-rw-r--r--src/hydrilla/proxy/state_impl/concrete_state.py322
-rw-r--r--src/hydrilla/proxy/state_impl/mappings.py4
-rw-r--r--src/hydrilla/proxy/state_impl/repos.py135
-rw-r--r--src/hydrilla/proxy/tables.sql29
-rw-r--r--src/hydrilla/server/malcontent.py2
-rw-r--r--src/hydrilla/versions.py4
15 files changed, 902 insertions, 482 deletions
diff --git a/src/hydrilla/item_infos.py b/src/hydrilla/item_infos.py
index a01fe3a..2b89600 100644
--- a/src/hydrilla/item_infos.py
+++ b/src/hydrilla/item_infos.py
@@ -181,7 +181,7 @@ class ItemInfoBase(ABC, ItemIdentity, Categorizable):
item_obj['source_copyright']
)
- version = versions.normalize_version(item_obj['version'])
+ version = versions.normalize(item_obj['version'])
perms_obj = item_obj.get('permissions', {})
@@ -262,14 +262,14 @@ class ResourceInfo(ItemInfoBase):
@staticmethod
def load(
- instance_or_path: json_instances.InstanceOrPathOrIO,
- repo: str = '<dummyrepo>',
- repo_iteration: int = -1
+ instance_source: json_instances.InstanceSource,
+ repo: str = '<dummyrepo>',
+ repo_iteration: int = -1
) -> 'ResourceInfo':
"""...."""
return _load_item_info(
ResourceInfo,
- instance_or_path,
+ instance_source,
repo,
repo_iteration
)
@@ -291,7 +291,8 @@ class MappingInfo(ItemInfoBase):
"""...."""
type_name: t.ClassVar[str] = 'mapping'
- payloads: t.Mapping[ParsedPattern, ItemSpecifier] = dc.field(hash=False, compare=False)
+ payloads: t.Mapping[ParsedPattern, ItemSpecifier] = \
+ dc.field(hash=False, compare=False)
@staticmethod
def make(
@@ -316,14 +317,14 @@ class MappingInfo(ItemInfoBase):
@staticmethod
def load(
- instance_or_path: json_instances.InstanceOrPathOrIO,
- repo: str = '<dummyrepo>',
- repo_iteration: int = -1
+ instance_source: json_instances.InstanceSource,
+ repo: str = '<dummyrepo>',
+ repo_iteration: int = -1
) -> 'MappingInfo':
"""...."""
return _load_item_info(
MappingInfo,
- instance_or_path,
+ instance_source,
repo,
repo_iteration
)
@@ -349,13 +350,13 @@ AnyInfo = t.Union[ResourceInfo, MappingInfo]
LoadedType = t.TypeVar('LoadedType', ResourceInfo, MappingInfo)
def _load_item_info(
- info_type: t.Type[LoadedType],
- instance_or_path: json_instances.InstanceOrPathOrIO,
- repo: str,
- repo_iteration: int
+ info_type: t.Type[LoadedType],
+ instance_source: json_instances.InstanceSource,
+ repo: str,
+ repo_iteration: int
) -> LoadedType:
"""Read, validate and autocomplete a mapping/resource description."""
- instance = json_instances.read_instance(instance_or_path)
+ instance = json_instances.read_instance(instance_source)
schema_fmt = f'api_{info_type.type_name}_description-{{}}.schema.json'
diff --git a/src/hydrilla/json_instances.py b/src/hydrilla/json_instances.py
index be8dbc6..8bec808 100644
--- a/src/hydrilla/json_instances.py
+++ b/src/hydrilla/json_instances.py
@@ -127,11 +127,14 @@ schema_paths.update([(f'https://hydrilla.koszko.org/schemas/{name}', path)
schemas: dict[Path, dict[str, t.Any]] = {}
+class UnknownSchemaError(HaketiloException):
+ pass
+
def _get_schema(schema_name: str) -> dict[str, t.Any]:
"""Return loaded JSON of the requested schema. Cache results."""
path = schema_paths.get(schema_name)
if path is None:
- raise HaketiloException(_('unknown_schema_{}').format(schema_name))
+ raise UnknownSchemaError(_('unknown_schema_{}').format(schema_name))
if path not in schemas:
schemas[path] = json.loads(path.read_text())
@@ -159,28 +162,33 @@ def parse_instance(text: str) -> object:
"""Parse 'text' as JSON with additional '//' comments support."""
return json.loads(strip_json_comments(text))
-InstanceOrPathOrIO = t.Union[Path, str, io.TextIOBase, dict[str, t.Any]]
+InstanceSource = t.Union[Path, str, io.TextIOBase, dict[str, t.Any], bytes]
-def read_instance(instance_or_path: InstanceOrPathOrIO) -> object:
+def read_instance(instance_or_path: InstanceSource) -> object:
"""...."""
if isinstance(instance_or_path, dict):
return instance_or_path
- if isinstance(instance_or_path, io.TextIOBase):
- handle = instance_or_path
+ if isinstance(instance_or_path, bytes):
+ encoding = json.detect_encoding(instance_or_path)
+ text = instance_or_path.decode(encoding)
+ elif isinstance(instance_or_path, io.TextIOBase):
+ try:
+ text = instance_or_path.read()
+ finally:
+ instance_or_path.close()
else:
- handle = t.cast(io.TextIOBase, open(instance_or_path, 'rt'))
-
- try:
- text = handle.read()
- finally:
- handle.close()
+ text = Path(instance_or_path).read_text()
try:
return parse_instance(text)
except:
- fmt = _('err.util.text_in_{}_not_valid_json')
- raise HaketiloException(fmt.format(instance_or_path))
+ if isinstance(instance_or_path, str) or \
+ isinstance(instance_or_path, Path):
+ fmt = _('err.util.text_in_{}_not_valid_json')
+ raise HaketiloException(fmt.format(instance_or_path))
+ else:
+ raise HaketiloException(_('err.util.text_not_valid_json'))
def get_schema_version(instance: object) -> tuple[int, ...]:
"""
diff --git a/src/hydrilla/proxy/simple_dependency_satisfying.py b/src/hydrilla/proxy/simple_dependency_satisfying.py
index 889ae98..f1371db 100644
--- a/src/hydrilla/proxy/simple_dependency_satisfying.py
+++ b/src/hydrilla/proxy/simple_dependency_satisfying.py
@@ -34,9 +34,40 @@ from __future__ import annotations
import dataclasses as dc
import typing as t
+from ..exceptions import HaketiloException
from .. import item_infos
from .. import url_patterns
+
+class ImpossibleSituation(HaketiloException):
+ pass
+
+
+@dc.dataclass(frozen=True)
+class MappingRequirement:
+ identifier: str
+
+ def is_fulfilled_by(self, info: item_infos.MappingInfo) -> bool:
+ return True
+
+@dc.dataclass(frozen=True)
+class MappingRepoRequirement(MappingRequirement):
+ repo: str
+
+ def is_fulfilled_by(self, info: item_infos.MappingInfo) -> bool:
+ return info.repo == self.repo
+
+@dc.dataclass(frozen=True)
+class MappingVersionRequirement(MappingRequirement):
+ version_info: item_infos.MappingInfo
+
+ def __post_init__(self):
+ assert self.version_info.identifier == self.identifier
+
+ def is_fulfilled_by(self, info: item_infos.MappingInfo) -> bool:
+ return info == self.version_info
+
+
@dc.dataclass
class ComputedPayload:
resources: list[item_infos.ResourceInfo] = dc.field(default_factory=list)
@@ -44,22 +75,40 @@ class ComputedPayload:
allows_eval: bool = False
allows_cors_bypass: bool = False
-SingleMappingPayloads = t.Mapping[
- url_patterns.ParsedPattern,
- ComputedPayload
-]
+@dc.dataclass
+class MappingChoice:
+ info: item_infos.MappingInfo
+ required: bool = False
+ payloads: dict[url_patterns.ParsedPattern, ComputedPayload] = \
+ dc.field(default_factory=dict)
+
-ComputedPayloadsDict = dict[
- item_infos.MappingInfo,
- SingleMappingPayloads
+MappingsGraph = t.Union[
+ t.Mapping[str, set[str]],
+ t.Mapping[str, frozenset[str]]
]
-empty_identifiers_set: set[str] = set()
+def _mark_mappings(
+ identifier: str,
+ mappings_graph: MappingsGraph,
+ marked_mappings: set[str]
+) -> None:
+ if identifier in marked_mappings:
+ return
+
+ marked_mappings.add(identifier)
+
+ for next_mapping in mappings_graph.get(identifier, ()):
+ _mark_mappings(next_mapping, mappings_graph, marked_mappings)
+
+
+ComputedChoices = dict[str, MappingChoice]
@dc.dataclass(frozen=True)
-class _ItemsCollection:
+class _ComputationData:
resources: t.Mapping[str, item_infos.ResourceInfo]
mappings: t.Mapping[str, item_infos.MappingInfo]
+ required: frozenset[str]
def _satisfy_payload_resource_rec(
self,
@@ -108,11 +157,11 @@ class _ItemsCollection:
ComputedPayload()
)
- def _compute_payloads_no_mapping_requirements(self) -> ComputedPayloadsDict:
- computed_result: ComputedPayloadsDict = ComputedPayloadsDict()
+ def _compute_payloads_no_mapping_requirements(self) -> ComputedChoices:
+ computed_result: ComputedChoices = ComputedChoices()
for mapping_info in self.mappings.values():
- by_pattern: dict[url_patterns.ParsedPattern, ComputedPayload] = {}
+ mapping_choice = MappingChoice(mapping_info)
failure = False
@@ -130,63 +179,66 @@ class _ItemsCollection:
if mapping_info.allows_cors_bypass:
computed_payload.allows_cors_bypass = True
- by_pattern[pattern] = computed_payload
+ mapping_choice.payloads[pattern] = computed_payload
if not failure:
- computed_result[mapping_info] = by_pattern
+ computed_result[mapping_info.identifier] = mapping_choice
return computed_result
- def _mark_mappings_bad(
- self,
- identifier: str,
- reverse_mapping_deps: t.Mapping[str, set[str]],
- bad_mappings: set[str]
- ) -> None:
- if identifier in bad_mappings:
- return
+ def _compute_inter_mapping_deps(self, choices: ComputedChoices) \
+ -> dict[str, frozenset[str]]:
+ mapping_deps: dict[str, frozenset[str]] = {}
- bad_mappings.add(identifier)
+ for mapping_choice in choices.values():
+ specs_to_resolve = [*mapping_choice.info.required_mappings]
- for requiring in reverse_mapping_deps.get(identifier, ()):
- self._mark_mappings_bad(
- requiring,
- reverse_mapping_deps,
- bad_mappings
- )
+ for computed_payload in mapping_choice.payloads.values():
+ for resource_info in computed_payload.resources:
+ specs_to_resolve.extend(resource_info.required_mappings)
- def compute_payloads(self) -> ComputedPayloadsDict:
- computed_result = self._compute_payloads_no_mapping_requirements()
+ depended = frozenset(spec.identifier for spec in specs_to_resolve)
+ mapping_deps[mapping_choice.info.identifier] = depended
- reverse_mapping_deps: dict[str, set[str]] = {}
+ return mapping_deps
- for mapping_info, by_pattern in computed_result.items():
- specs_to_resolve = [*mapping_info.required_mappings]
+ def compute_payloads(self) -> ComputedChoices:
+ choices = self._compute_payloads_no_mapping_requirements()
- for computed_payload in by_pattern.values():
- for resource_info in computed_payload.resources:
- specs_to_resolve.extend(resource_info.required_mappings)
+ mapping_deps = self._compute_inter_mapping_deps(choices)
+
+ reverse_deps: dict[str, set[str]] = {}
- for required_mapping_spec in specs_to_resolve:
- identifier = required_mapping_spec.identifier
- requiring = reverse_mapping_deps.setdefault(identifier, set())
- requiring.add(mapping_info.identifier)
+ for depending, depended_set in mapping_deps.items():
+ for depended in depended_set:
+ reverse_deps.setdefault(depended, set()).add(depending)
bad_mappings: set[str] = set()
- for required_identifier in reverse_mapping_deps.keys():
- if self.mappings.get(required_identifier) not in computed_result:
- self._mark_mappings_bad(
- required_identifier,
- reverse_mapping_deps,
- bad_mappings
- )
+ for depended_identifier in reverse_deps.keys():
+ if self.mappings.get(depended_identifier) not in choices:
+ _mark_mappings(depended_identifier, reverse_deps, bad_mappings)
+
+ if any(identifier in self.required for identifier in bad_mappings):
+ raise ImpossibleSituation()
for identifier in bad_mappings:
+ if identifier in self.required:
+ raise ImpossibleSituation()
+
if identifier in self.mappings:
- computed_result.pop(self.mappings[identifier], None)
+ choices.pop(identifier, None)
+
+ required_mappings: set[str] = set()
+
+ for identifier in self.required:
+ _mark_mappings(identifier, mapping_deps, required_mappings)
+
+ for identifier in required_mappings:
+ choices[identifier].required = True
+
+ return choices
- return computed_result
AnyInfoVar = t.TypeVar(
'AnyInfoVar',
@@ -207,10 +259,25 @@ def _choose_newest(infos: t.Iterable[AnyInfoVar]) -> dict[str, AnyInfoVar]:
return best_versions
def compute_payloads(
- resources: t.Iterable[item_infos.ResourceInfo],
- mappings: t.Iterable[item_infos.MappingInfo]
-) -> ComputedPayloadsDict:
+ resources: t.Iterable[item_infos.ResourceInfo],
+ mappings: t.Iterable[item_infos.MappingInfo],
+ requirements: t.Iterable[MappingRequirement]
+) -> ComputedChoices:
+ reqs_by_identifier = dict((req.identifier, req) for req in requirements)
+
+ filtered_mappings = []
+
+ for mapping_info in mappings:
+ req = reqs_by_identifier.get(mapping_info.identifier)
+ if req is not None and not req.is_fulfilled_by(mapping_info):
+ continue
+
+ filtered_mappings.append(mapping_info)
+
best_resources = _choose_newest(resources)
- best_mappings = _choose_newest(mappings)
+ best_mappings = _choose_newest(filtered_mappings)
+
+ required = frozenset(reqs_by_identifier.keys())
- return _ItemsCollection(best_resources, best_mappings).compute_payloads()
+ return _ComputationData(best_resources, best_mappings, required)\
+ .compute_payloads()
diff --git a/src/hydrilla/proxy/state.py b/src/hydrilla/proxy/state.py
index 6414ae8..42ca998 100644
--- a/src/hydrilla/proxy/state.py
+++ b/src/hydrilla/proxy/state.py
@@ -91,6 +91,12 @@ class RepoNameTaken(HaketiloException):
class RepoUrlInvalid(HaketiloException):
pass
+class RepoCommunicationError(HaketiloException):
+ pass
+
+class RepoApiVersionUnsupported(HaketiloException):
+ pass
+
# mypy needs to be corrected:
# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
diff --git a/src/hydrilla/proxy/state_impl/_operations/__init__.py b/src/hydrilla/proxy/state_impl/_operations/__init__.py
new file mode 100644
index 0000000..c147be4
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl/_operations/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-License-Identifier: CC0-1.0
+
+# Copyright (C) 2022 Wojtek Kosior <koszko@koszko.org>
+#
+# Available under the terms of Creative Commons Zero v1.0 Universal.
+
+from .load_packages import load_packages, FileResolver
+from .prune_packages import prune_packages
+from .recompute_dependencies import _recompute_dependencies_no_state_update
diff --git a/src/hydrilla/proxy/state_impl/load_packages.py b/src/hydrilla/proxy/state_impl/_operations/load_packages.py
index 6983c3e..c294ef0 100644
--- a/src/hydrilla/proxy/state_impl/load_packages.py
+++ b/src/hydrilla/proxy/state_impl/_operations/load_packages.py
@@ -37,37 +37,49 @@ import dataclasses as dc
import typing as t
from pathlib import Path
+from abc import ABC, abstractmethod
import sqlite3
-from ...exceptions import HaketiloException
-from ...translations import smart_gettext as _
-from ... import versions
-from ... import item_infos
-from . import base
+from ....exceptions import HaketiloException
+from ....translations import smart_gettext as _
+from .... import versions
+from .... import item_infos
-def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int:
+def make_repo_iteration(cursor: sqlite3.Cursor, repo_id: int) -> int:
cursor.execute(
'''
SELECT
- repo_id, next_iteration - 1
+ next_iteration
FROM
repos
WHERE
- name = ?;
+ repo_id = ?;
''',
- (repo_name,)
+ (repo_id,)
)
- (repo_id, last_iteration), = cursor.fetchall()
+ (next_iteration,), = cursor.fetchall()
cursor.execute(
'''
- INSERT OR IGNORE INTO repo_iterations(repo_id, iteration)
+ UPDATE
+ repos
+ SET
+ next_iteration = ?
+ WHERE
+ repo_id = ?;
+ ''',
+ (next_iteration + 1, repo_id)
+ )
+
+ cursor.execute(
+ '''
+ INSERT INTO repo_iterations(repo_id, iteration)
VALUES(?, ?);
''',
- (repo_id, last_iteration)
+ (repo_id, next_iteration)
)
cursor.execute(
@@ -79,7 +91,7 @@ def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int:
WHERE
repo_id = ? AND iteration = ?;
''',
- (repo_id, last_iteration)
+ (repo_id, next_iteration)
)
(repo_iteration_id,), = cursor.fetchall()
@@ -118,7 +130,7 @@ def get_or_make_item_version(
item_id: int,
repo_iteration_id: int,
version: versions.VerTuple,
- definition: str
+ definition: bytes
) -> int:
ver_str = versions.version_string(version)
@@ -154,8 +166,8 @@ def get_or_make_item_version(
def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None:
cursor.execute(
'''
- INSERT OR IGNORE INTO mapping_statuses(item_id, enabled, frozen)
- VALUES(?, 'E', 'R');
+ INSERT OR IGNORE INTO mapping_statuses(item_id, enabled, required)
+ VALUES(?, 'N', FALSE);
''',
(item_id,)
)
@@ -215,14 +227,18 @@ class _FileInfo:
id: int
is_ascii: bool
+class FileResolver(ABC):
+ @abstractmethod
+ def by_sha256(self, sha256: str) -> bytes:
+ ...
+
def _add_item(
- cursor: sqlite3.Cursor,
- files_by_sha256_path: Path,
- info: item_infos.AnyInfo,
- definition: str
+ cursor: sqlite3.Cursor,
+ package_file_resolver: FileResolver,
+ info: item_infos.AnyInfo,
+ definition: bytes,
+ repo_iteration_id: int
) -> None:
- repo_iteration_id = get_or_make_repo_iteration(cursor, '<local>')
-
item_id = get_or_make_item(cursor, info.type_name, info.identifier)
item_version_id = get_or_make_item_version(
@@ -243,17 +259,7 @@ def _add_item(
file_specifiers.extend(info.scripts)
for file_spec in file_specifiers:
- file_path = files_by_sha256_path / file_spec.sha256
- if not file_path.is_file():
- fmt = _('err.proxy.file_missing_{item_identifier}_{file_name}_{sha256}')
- msg = fmt.format(
- item_identifier = info.identifier,
- file_name = file_spec.name,
- sha256 = file_spec.sha256
- )
- raise HaketiloException(msg)
-
- file_bytes = file_path.read_bytes()
+ file_bytes = package_file_resolver.by_sha256(file_spec.sha256)
sha256 = hashlib.sha256(file_bytes).digest().hex()
if sha256 != file_spec.sha256:
@@ -309,7 +315,7 @@ AnyInfoVar = t.TypeVar(
)
def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
- -> t.Iterator[tuple[AnyInfoVar, str]]:
+ -> t.Iterator[tuple[AnyInfoVar, bytes]]:
item_type_path = malcontent_path / item_class.type_name
if not item_type_path.is_dir():
return
@@ -319,8 +325,8 @@ def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
continue
for item_version_path in item_path.iterdir():
- definition = item_version_path.read_text()
- item_info = item_class.load(io.StringIO(definition))
+ definition = item_version_path.read_bytes()
+ item_info = item_class.load(definition)
assert item_info.identifier == item_path.name
assert versions.version_string(item_info.version) == \
@@ -328,17 +334,45 @@ def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
yield item_info, definition
+@dc.dataclass(frozen=True)
+class MalcontentFileResolver(FileResolver):
+ malcontent_dir_path: Path
+
+ def by_sha256(self, sha256: str) -> bytes:
+ file_path = self.malcontent_dir_path / 'file' / 'sha256' / sha256
+ if not file_path.is_file():
+ fmt = _('err.proxy.file_missing_{sha256}')
+ raise HaketiloException(fmt.format(sha256=sha256))
+
+ return file_path.read_bytes()
+
def load_packages(
- state: base.HaketiloStateWithFields,
- cursor: sqlite3.Cursor,
- malcontent_path: Path
-) -> None:
- files_by_sha256_path = malcontent_path / 'file' / 'sha256'
+ cursor: sqlite3.Cursor,
+ malcontent_path: Path,
+ repo_id: int,
+ package_file_resolver: t.Optional[FileResolver] = None
+) -> int:
+ if package_file_resolver is None:
+ package_file_resolver = MalcontentFileResolver(malcontent_path)
- for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]:
+ repo_iteration_id = make_repo_iteration(cursor, repo_id)
+
+ types: t.Iterable[t.Type[item_infos.AnyInfo]] = \
+ [item_infos.ResourceInfo, item_infos.MappingInfo]
+
+ for info_type in types:
info: item_infos.AnyInfo
- for info, definition in _read_items(
+
+ for info, definition in _read_items( # type: ignore
malcontent_path,
- info_type # type: ignore
+ info_type
):
- _add_item(cursor, files_by_sha256_path, info, definition)
+ _add_item(
+ cursor,
+ package_file_resolver,
+ info,
+ definition,
+ repo_iteration_id
+ )
+
+ return repo_iteration_id
diff --git a/src/hydrilla/proxy/state_impl/prune_packages.py b/src/hydrilla/proxy/state_impl/_operations/prune_packages.py
index 1857188..9c2b1d7 100644
--- a/src/hydrilla/proxy/state_impl/prune_packages.py
+++ b/src/hydrilla/proxy/state_impl/_operations/prune_packages.py
@@ -33,27 +33,42 @@ from __future__ import annotations
import sqlite3
+from pathlib import Path
-_remove_mapping_versions_sql = '''
-WITH removed_mappings AS (
- SELECT
- iv.item_version_id
- FROM
- item_versions AS iv
- JOIN items AS i
- USING (item_id)
- JOIN orphan_iterations AS oi
- USING (repo_iteration_id)
- LEFT JOIN payloads AS p
- ON p.mapping_item_id = iv.item_version_id
- WHERE
- i.type = 'M' AND p.payload_id IS NULL
-)
-DELETE FROM
- item_versions
-WHERE
- item_version_id IN removed_mappings;
-'''
+
+_remove_mapping_versions_sqls = [
+ '''
+ CREATE TEMPORARY TABLE removed_mappings(
+ item_version_id INTEGER PRIMARY KEY
+ );
+ ''', '''
+ INSERT INTO
+ removed_mappings
+ SELECT
+ iv.item_version_id
+ FROM
+ item_versions AS iv
+ JOIN items AS i USING (item_id)
+ JOIN mapping_statuses AS ms USING (item_id)
+ JOIN orphan_iterations AS oi USING (repo_iteration_id)
+ WHERE
+ NOT ms.required;
+ ''', '''
+ UPDATE
+ mapping_statuses
+ SET
+ active_version_id = NULL
+ WHERE
+ active_version_id IN removed_mappings;
+ ''', '''
+ DELETE FROM
+ item_versions
+ WHERE
+ item_version_id IN removed_mappings;
+ ''', '''
+ DROP TABLE removed_mappings;
+ '''
+]
_remove_resource_versions_sql = '''
WITH removed_resources AS (
@@ -134,7 +149,7 @@ WITH removed_repos AS (
repos AS r
LEFT JOIN repo_iterations AS ri USING (repo_id)
WHERE
- r.deleted AND ri.repo_iteration_id IS NULL
+ r.deleted AND ri.repo_iteration_id IS NULL AND r.repo_id != 1
)
DELETE FROM
repos
@@ -142,9 +157,11 @@ WHERE
repo_id IN removed_repos;
'''
-def prune(cursor: sqlite3.Cursor) -> None:
- """...."""
- cursor.execute(_remove_mapping_versions_sql)
+def prune_packages(cursor: sqlite3.Cursor) -> None:
+ assert cursor.connection.in_transaction
+
+ for sql in _remove_mapping_versions_sqls:
+ cursor.execute(sql)
cursor.execute(_remove_resource_versions_sql)
cursor.execute(_remove_items_sql)
cursor.execute(_remove_files_sql)
diff --git a/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py b/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py
new file mode 100644
index 0000000..4093f12
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py
@@ -0,0 +1,223 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy data and configuration (update of dependency tree in the db).
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+....
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import typing as t
+
+import sqlite3
+
+from .... import item_infos
+from ... import simple_dependency_satisfying as sds
+
+
+AnyInfoVar = t.TypeVar(
+ 'AnyInfoVar',
+ item_infos.ResourceInfo,
+ item_infos.MappingInfo
+)
+
+def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \
+ -> t.Mapping[int, AnyInfoVar]:
+ cursor.execute(
+ '''
+ SELECT
+ i.item_id, iv.definition, r.name, ri.iteration
+ FROM
+ item_versions AS iv
+ JOIN items AS i USING (item_id)
+ JOIN repo_iterations AS ri USING (repo_iteration_id)
+ JOIN repos AS r USING (repo_id)
+ WHERE
+ i.type = ?;
+ ''',
+ (info_type.type_name[0].upper(),)
+ )
+
+ result: dict[int, AnyInfoVar] = {}
+
+ for item_id, definition, repo_name, repo_iteration in cursor.fetchall():
+ info = info_type.load(definition, repo_name, repo_iteration)
+ result[item_id] = info
+
+ return result
+
+def _recompute_dependencies_no_state_update(
+ cursor: sqlite3.Cursor,
+ extra_requirements: t.Iterable[sds.MappingRequirement]
+) -> None:
+ cursor.execute('DELETE FROM payloads;')
+
+ ids_to_resources = get_infos_of_type(cursor, item_infos.ResourceInfo)
+ ids_to_mappings = get_infos_of_type(cursor, item_infos.MappingInfo)
+
+ resources = ids_to_resources.items()
+ resources_to_ids = dict((info.identifier, id) for id, info in resources)
+
+ mappings = ids_to_mappings.items()
+ mappings_to_ids = dict((info.identifier, id) for id, info in mappings)
+
+ requirements = [*extra_requirements]
+
+ cursor.execute(
+ '''
+ SELECT
+ i.identifier
+ FROM
+ mapping_statuses AS ms
+ JOIN items AS i USING(item_id)
+ WHERE
+ ms.enabled = 'E' AND ms.frozen = 'N';
+ '''
+ )
+
+ for mapping_identifier, in cursor.fetchall():
+ requirements.append(sds.MappingRequirement(mapping_identifier))
+
+ cursor.execute(
+ '''
+ SELECT
+ active_version_id, frozen
+ FROM
+ mapping_statuses
+ WHERE
+ enabled = 'E' AND frozen IN ('R', 'E');
+ '''
+ )
+
+ for active_version_id, frozen in cursor.fetchall():
+ info = ids_to_mappings[active_version_id]
+
+ requirement: sds.MappingRequirement
+
+ if frozen == 'R':
+ requirement = sds.MappingRepoRequirement(info.identifier, info.repo)
+ else:
+ requirement = sds.MappingVersionRequirement(info.identifier, info)
+
+ requirements.append(requirement)
+
+ mapping_choices = sds.compute_payloads(
+ ids_to_resources.values(),
+ ids_to_mappings.values(),
+ requirements
+ )
+
+ cursor.execute(
+ '''
+ UPDATE
+ mapping_statuses
+ SET
+ required = FALSE,
+ active_version_id = NULL
+ WHERE
+ enabled != 'E';
+ '''
+ )
+
+ cursor.execute('DELETE FROM payloads;')
+
+ for choice in mapping_choices.values():
+ mapping_ver_id = mappings_to_ids[choice.info.identifier]
+
+ cursor.execute(
+ '''
+ SELECT
+ item_id
+ FROM
+ item_versions
+ WHERE
+ item_version_id = ?;
+ ''',
+ (mapping_ver_id,)
+ )
+
+ (mapping_item_id,), = cursor.fetchall()
+
+ cursor.execute(
+ '''
+ UPDATE
+ mapping_statuses
+ SET
+ required = ?,
+ active_version_id = ?
+ WHERE
+ item_id = ?;
+ ''',
+ (choice.required, mapping_ver_id, mapping_item_id)
+ )
+
+ for num, (pattern, payload) in enumerate(choice.payloads.items()):
+ cursor.execute(
+ '''
+ INSERT INTO payloads(
+ mapping_item_id,
+ pattern,
+ eval_allowed,
+ cors_bypass_allowed
+ )
+ VALUES (?, ?, ?, ?);
+ ''',
+ (
+ mapping_ver_id,
+ pattern.orig_url,
+ payload.allows_eval,
+ payload.allows_cors_bypass
+ )
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ payload_id
+ FROM
+ payloads
+ WHERE
+ mapping_item_id = ? AND pattern = ?;
+ ''',
+ (mapping_ver_id, pattern.orig_url)
+ )
+
+ (payload_id,), = cursor.fetchall()
+
+ for res_num, resource_info in enumerate(payload.resources):
+ resource_ver_id = resources_to_ids[resource_info.identifier]
+ cursor.execute(
+ '''
+ INSERT INTO resolved_depended_resources(
+ payload_id,
+ resource_item_id,
+ idx
+ )
+ VALUES(?, ?, ?);
+ ''',
+ (payload_id, resource_ver_id, res_num)
+ )
diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py
index 1ae08cb..92833dd 100644
--- a/src/hydrilla/proxy/state_impl/base.py
+++ b/src/hydrilla/proxy/state_impl/base.py
@@ -34,34 +34,160 @@ subtype.
from __future__ import annotations
import sqlite3
+import secrets
import threading
import dataclasses as dc
import typing as t
from pathlib import Path
from contextlib import contextmanager
-
-import sqlite3
+from abc import abstractmethod
from immutables import Map
+from ... import url_patterns
from ... import pattern_tree
-from .. import state
+from .. import simple_dependency_satisfying as sds
+from .. import state as st
from .. import policies
PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
-PayloadsData = t.Mapping[state.PayloadRef, state.PayloadData]
+PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData]
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcretePayloadRef(st.PayloadRef):
+ state: 'HaketiloStateWithFields' = dc.field(hash=False, compare=False)
+
+ def get_data(self) -> st.PayloadData:
+ try:
+ return self.state.payloads_data[self]
+ except KeyError:
+ raise st.MissingItemError()
+
+ def get_mapping(self) -> st.MappingVersionRef:
+ raise NotImplementedError()
+
+ def get_script_paths(self) \
+ -> t.Iterable[t.Sequence[str]]:
+ with self.state.cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ i.identifier, fu.name
+ FROM
+ payloads AS p
+ LEFT JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ LEFT JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ LEFT JOIN items AS i
+ USING (item_id)
+ LEFT JOIN file_uses AS fu
+ USING (item_version_id)
+ WHERE
+ fu.type = 'W' AND
+ p.payload_id = ? AND
+ (fu.idx IS NOT NULL OR rdd.idx IS NULL)
+ ORDER BY
+ rdd.idx, fu.idx;
+ ''',
+ (self.id,)
+ )
+
+ paths: list[t.Sequence[str]] = []
+ for resource_identifier, file_name in cursor.fetchall():
+ if resource_identifier is None:
+ # payload found but it had no script files
+ return ()
+
+ paths.append((resource_identifier, *file_name.split('/')))
+
+ if paths == []:
+ # payload not found
+ raise st.MissingItemError()
+
+ return paths
+
+ def get_file_data(self, path: t.Sequence[str]) \
+ -> t.Optional[st.FileData]:
+ if len(path) == 0:
+ raise st.MissingItemError()
+
+ resource_identifier, *file_name_segments = path
+
+ file_name = '/'.join(file_name_segments)
+
+ with self.state.cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ f.data, fu.mime_type
+ FROM
+ payloads AS p
+ JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ JOIN items AS i
+ USING (item_id)
+ JOIN file_uses AS fu
+ USING (item_version_id)
+ JOIN files AS f
+ USING (file_id)
+ WHERE
+ p.payload_id = ? AND
+ i.identifier = ? AND
+ fu.name = ? AND
+ fu.type = 'W';
+ ''',
+ (self.id, resource_identifier, file_name)
+ )
+
+ result = cursor.fetchall()
+
+ if result == []:
+ return None
+
+ (data, mime_type), = result
+
+ return st.FileData(type=mime_type, name=file_name, contents=data)
+
+def register_payload(
+ policy_tree: PolicyTree,
+ pattern: url_patterns.ParsedPattern,
+ payload_key: st.PayloadKey,
+ token: str
+) -> PolicyTree:
+ """...."""
+ payload_policy_factory = policies.PayloadPolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(pattern, payload_policy_factory)
+
+ resource_policy_factory = policies.PayloadResourcePolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(
+ pattern.path_append(token, '***'),
+ resource_policy_factory
+ )
+
+ return policy_tree
# mypy needs to be corrected:
# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
@dc.dataclass # type: ignore[misc]
-class HaketiloStateWithFields(state.HaketiloState):
+class HaketiloStateWithFields(st.HaketiloState):
"""...."""
store_dir: Path
connection: sqlite3.Connection
current_cursor: t.Optional[sqlite3.Cursor] = None
- #settings: state.HaketiloGlobalSettings
+ #settings: st.HaketiloGlobalSettings
policy_tree: PolicyTree = PolicyTree()
payloads_data: PayloadsData = dc.field(default_factory=dict)
@@ -98,6 +224,85 @@ class HaketiloStateWithFields(state.HaketiloState):
finally:
self.current_cursor = None
- def recompute_payloads(self, cursor: sqlite3.Cursor) -> None:
+ def rebuild_structures(self) -> None:
+ """
+ Recreation of data structures as done after every recomputation of
+ dependencies as well as at startup.
+ """
+ with self.cursor(transaction=True) as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ p.payload_id, p.pattern, p.eval_allowed,
+ p.cors_bypass_allowed,
+ ms.enabled,
+ i.identifier
+ FROM
+ payloads AS p
+ JOIN item_versions AS iv
+ ON p.mapping_item_id = iv.item_version_id
+ JOIN items AS i USING (item_id)
+ JOIN mapping_statuses AS ms USING (item_id);
+ '''
+ )
+
+ rows = cursor.fetchall()
+
+ new_policy_tree = PolicyTree()
+
+ ui_factory = policies.WebUIPolicyFactory(builtin=True)
+ web_ui_pattern = 'http*://hkt.mitm.it/***'
+ for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern):
+ new_policy_tree = new_policy_tree.register(
+ parsed_pattern,
+ ui_factory
+ )
+
+ new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
+
+ for row in rows:
+ (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
+ enabled_status,
+ identifier) = row
+
+ payload_ref = ConcretePayloadRef(str(payload_id_int), self)
+
+ previous_data = self.payloads_data.get(payload_ref)
+ if previous_data is not None:
+ token = previous_data.unique_token
+ else:
+ token = secrets.token_urlsafe(8)
+
+ payload_key = st.PayloadKey(payload_ref, identifier)
+
+ for parsed_pattern in url_patterns.parse_pattern(pattern):
+ new_policy_tree = register_payload(
+ new_policy_tree,
+ parsed_pattern,
+ payload_key,
+ token
+ )
+
+ pattern_path_segments = parsed_pattern.path_segments
+
+ payload_data = st.PayloadData(
+ payload_ref = payload_ref,
+ explicitly_enabled = enabled_status == 'E',
+ unique_token = token,
+ pattern_path_segments = pattern_path_segments,
+ eval_allowed = eval_allowed,
+ cors_bypass_allowed = cors_bypass_allowed
+ )
+
+ new_payloads_data[payload_ref] = payload_data
+
+ self.policy_tree = new_policy_tree
+ self.payloads_data = new_payloads_data
+
+ @abstractmethod
+ def recompute_dependencies(
+ self,
+ requirements: t.Iterable[sds.MappingRequirement] = []
+ ) -> None:
"""...."""
...
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py
index 525a702..6a59a75 100644
--- a/src/hydrilla/proxy/state_impl/concrete_state.py
+++ b/src/hydrilla/proxy/state_impl/concrete_state.py
@@ -32,28 +32,23 @@ and resources.
# Enable using with Python 3.7.
from __future__ import annotations
-import secrets
-import io
-import hashlib
+import sqlite3
import typing as t
import dataclasses as dc
from pathlib import Path
-import sqlite3
-
from ...exceptions import HaketiloException
from ...translations import smart_gettext as _
-from ... import pattern_tree
from ... import url_patterns
from ... import item_infos
-from ..simple_dependency_satisfying import compute_payloads, ComputedPayload
from .. import state as st
from .. import policies
+from .. import simple_dependency_satisfying as sds
from . import base
from . import mappings
from . import repos
-from .load_packages import load_packages
+from . import _operations
here = Path(__file__).resolve().parent
@@ -73,169 +68,14 @@ class ConcreteResourceRef(st.ResourceRef):
class ConcreteResourceVersionRef(st.ResourceVersionRef):
pass
-
-@dc.dataclass(frozen=True, unsafe_hash=True)
-class ConcretePayloadRef(st.PayloadRef):
- state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False)
-
- def get_data(self) -> st.PayloadData:
- try:
- return self.state.payloads_data[self]
- except KeyError:
- raise st.MissingItemError()
-
- def get_mapping(self) -> st.MappingVersionRef:
- raise NotImplementedError()
-
- def get_script_paths(self) \
- -> t.Iterable[t.Sequence[str]]:
- with self.state.cursor() as cursor:
- cursor.execute(
- '''
- SELECT
- i.identifier, fu.name
- FROM
- payloads AS p
- LEFT JOIN resolved_depended_resources AS rdd
- USING (payload_id)
- LEFT JOIN item_versions AS iv
- ON rdd.resource_item_id = iv.item_version_id
- LEFT JOIN items AS i
- USING (item_id)
- LEFT JOIN file_uses AS fu
- USING (item_version_id)
- WHERE
- fu.type = 'W' AND
- p.payload_id = ? AND
- (fu.idx IS NOT NULL OR rdd.idx IS NULL)
- ORDER BY
- rdd.idx, fu.idx;
- ''',
- (self.id,)
- )
-
- paths: list[t.Sequence[str]] = []
- for resource_identifier, file_name in cursor.fetchall():
- if resource_identifier is None:
- # payload found but it had no script files
- return ()
-
- paths.append((resource_identifier, *file_name.split('/')))
-
- if paths == []:
- # payload not found
- raise st.MissingItemError()
-
- return paths
-
- def get_file_data(self, path: t.Sequence[str]) \
- -> t.Optional[st.FileData]:
- if len(path) == 0:
- raise st.MissingItemError()
-
- resource_identifier, *file_name_segments = path
-
- file_name = '/'.join(file_name_segments)
-
- with self.state.cursor() as cursor:
- cursor.execute(
- '''
- SELECT
- f.data, fu.mime_type
- FROM
- payloads AS p
- JOIN resolved_depended_resources AS rdd
- USING (payload_id)
- JOIN item_versions AS iv
- ON rdd.resource_item_id = iv.item_version_id
- JOIN items AS i
- USING (item_id)
- JOIN file_uses AS fu
- USING (item_version_id)
- JOIN files AS f
- USING (file_id)
- WHERE
- p.payload_id = ? AND
- i.identifier = ? AND
- fu.name = ? AND
- fu.type = 'W';
- ''',
- (self.id, resource_identifier, file_name)
- )
-
- result = cursor.fetchall()
-
- if result == []:
- return None
-
- (data, mime_type), = result
-
- return st.FileData(type=mime_type, name=file_name, contents=data)
-
-def register_payload(
- policy_tree: base.PolicyTree,
- pattern: url_patterns.ParsedPattern,
- payload_key: st.PayloadKey,
- token: str
-) -> base.PolicyTree:
- """...."""
- payload_policy_factory = policies.PayloadPolicyFactory(
- builtin = False,
- payload_key = payload_key
- )
-
- policy_tree = policy_tree.register(pattern, payload_policy_factory)
-
- resource_policy_factory = policies.PayloadResourcePolicyFactory(
- builtin = False,
- payload_key = payload_key
- )
-
- policy_tree = policy_tree.register(
- pattern.path_append(token, '***'),
- resource_policy_factory
- )
-
- return policy_tree
-
-AnyInfoVar = t.TypeVar(
- 'AnyInfoVar',
- item_infos.ResourceInfo,
- item_infos.MappingInfo
-)
-
-def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \
- -> t.Mapping[AnyInfoVar, int]:
- cursor.execute(
- '''
- SELECT
- i.item_id, iv.definition, r.name, ri.iteration
- FROM
- item_versions AS iv
- JOIN items AS i USING (item_id)
- JOIN repo_iterations AS ri USING (repo_iteration_id)
- JOIN repos AS r USING (repo_id)
- WHERE
- i.type = ?;
- ''',
- (info_type.type_name[0].upper(),)
- )
-
- result: dict[AnyInfoVar, int] = {}
-
- for item_id, definition, repo_name, repo_iteration in cursor.fetchall():
- definition_io = io.StringIO(definition)
- info = info_type.load(definition_io, repo_name, repo_iteration)
- result[info] = item_id
-
- return result
-
@dc.dataclass
class ConcreteHaketiloState(base.HaketiloStateWithFields):
def __post_init__(self) -> None:
+ sqlite3.enable_callback_tracebacks(True)
+
self._prepare_database()
- self._rebuild_structures()
+ self.rebuild_structures()
def _prepare_database(self) -> None:
"""...."""
@@ -282,147 +122,25 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
def import_packages(self, malcontent_path: Path) -> None:
with self.cursor(transaction=True) as cursor:
- load_packages(self, cursor, malcontent_path)
- self.recompute_payloads(cursor)
-
- def recompute_payloads(self, cursor: sqlite3.Cursor) -> None:
- assert self.connection.in_transaction
-
- resources = get_infos_of_type(cursor, item_infos.ResourceInfo)
- mappings = get_infos_of_type(cursor, item_infos.MappingInfo)
-
- payloads = compute_payloads(resources.keys(), mappings.keys())
-
- payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
-
- cursor.execute('DELETE FROM payloads;')
-
- for mapping_info, by_pattern in payloads.items():
- for num, (pattern, payload) in enumerate(by_pattern.items()):
- cursor.execute(
- '''
- INSERT INTO payloads(
- mapping_item_id,
- pattern,
- eval_allowed,
- cors_bypass_allowed
- )
- VALUES (?, ?, ?, ?);
- ''',
- (
- mappings[mapping_info],
- pattern.orig_url,
- payload.allows_eval,
- payload.allows_cors_bypass
- )
- )
+ _operations.load_packages(cursor, malcontent_path, 1)
+ raise NotImplementedError()
+ _operations.prune_packages(cursor)
- cursor.execute(
- '''
- SELECT
- payload_id
- FROM
- payloads
- WHERE
- mapping_item_id = ? AND pattern = ?;
- ''',
- (mappings[mapping_info], pattern.orig_url)
- )
+ self.recompute_dependencies()
- (payload_id_int,), = cursor.fetchall()
-
- for res_num, resource_info in enumerate(payload.resources):
- cursor.execute(
- '''
- INSERT INTO resolved_depended_resources(
- payload_id,
- resource_item_id,
- idx
- )
- VALUES(?, ?, ?);
- ''',
- (payload_id_int, resources[resource_info], res_num)
- )
-
- self._rebuild_structures(cursor)
-
- def _rebuild_structures(self, cursor: t.Optional[sqlite3.Cursor] = None) \
- -> None:
- """
- Recreation of data structures as done after every recomputation of
- dependencies as well as at startup.
- """
- if cursor is None:
- with self.cursor() as new_cursor:
- return self._rebuild_structures(new_cursor)
-
- cursor.execute(
- '''
- SELECT
- p.payload_id, p.pattern, p.eval_allowed,
- p.cors_bypass_allowed,
- ms.enabled,
- i.identifier
- FROM
- payloads AS p
- JOIN item_versions AS iv
- ON p.mapping_item_id = iv.item_version_id
- JOIN items AS i USING (item_id)
- JOIN mapping_statuses AS ms USING (item_id);
- '''
- )
-
- new_policy_tree = base.PolicyTree()
-
- ui_factory = policies.WebUIPolicyFactory(builtin=True)
- web_ui_pattern = 'http*://hkt.mitm.it/***'
- for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern):
- new_policy_tree = new_policy_tree.register(
- parsed_pattern,
- ui_factory
- )
-
- new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
-
- for row in cursor.fetchall():
- (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
- enabled_status,
- identifier) = row
-
- payload_ref = ConcretePayloadRef(str(payload_id_int), self)
-
- previous_data = self.payloads_data.get(payload_ref)
- if previous_data is not None:
- token = previous_data.unique_token
- else:
- token = secrets.token_urlsafe(8)
-
- payload_key = st.PayloadKey(payload_ref, identifier)
-
- for parsed_pattern in url_patterns.parse_pattern(pattern):
- new_policy_tree = register_payload(
- new_policy_tree,
- parsed_pattern,
- payload_key,
- token
- )
-
- pattern_path_segments = parsed_pattern.path_segments
+ def recompute_dependencies(
+ self,
+ extra_requirements: t.Iterable[sds.MappingRequirement] = []
+ ) -> None:
+ with self.cursor() as cursor:
+ assert self.connection.in_transaction
- payload_data = st.PayloadData(
- payload_ref = payload_ref,
- #explicitly_enabled = enabled_status == 'E',
- explicitly_enabled = True,
- unique_token = token,
- pattern_path_segments = pattern_path_segments,
- eval_allowed = eval_allowed,
- cors_bypass_allowed = cors_bypass_allowed
+ _operations._recompute_dependencies_no_state_update(
+ cursor,
+ extra_requirements
)
- new_payloads_data[payload_ref] = payload_data
-
- self.policy_tree = new_policy_tree
- self.payloads_data = new_payloads_data
+ self.rebuild_structures()
def repo_store(self) -> st.RepoStore:
return repos.ConcreteRepoStore(self)
diff --git a/src/hydrilla/proxy/state_impl/mappings.py b/src/hydrilla/proxy/state_impl/mappings.py
index 3668784..cce2a36 100644
--- a/src/hydrilla/proxy/state_impl/mappings.py
+++ b/src/hydrilla/proxy/state_impl/mappings.py
@@ -76,7 +76,7 @@ class ConcreteMappingVersionRef(st.MappingVersionRef):
(status_letter, definition, repo, repo_iteration, is_orphan), = rows
item_info = item_infos.MappingInfo.load(
- io.StringIO(definition),
+ definition,
repo,
repo_iteration
)
@@ -120,7 +120,7 @@ class ConcreteMappingVersionStore(st.MappingVersionStore):
ref = ConcreteMappingVersionRef(str(item_version_id), self.state)
item_info = item_infos.MappingInfo.load(
- io.StringIO(definition),
+ definition,
repo,
repo_iteration
)
diff --git a/src/hydrilla/proxy/state_impl/repos.py b/src/hydrilla/proxy/state_impl/repos.py
index 5553ec2..f4c7c71 100644
--- a/src/hydrilla/proxy/state_impl/repos.py
+++ b/src/hydrilla/proxy/state_impl/repos.py
@@ -33,20 +33,27 @@ inside Haketilo.
from __future__ import annotations
import re
+import json
+import tempfile
+import requests
+import sqlite3
import typing as t
import dataclasses as dc
-from urllib.parse import urlparse
+from urllib.parse import urlparse, urljoin
from datetime import datetime
+from pathlib import Path
-import sqlite3
-
+from ... import json_instances
+from ... import item_infos
+from ... import versions
from .. import state as st
+from .. import simple_dependency_satisfying as sds
from . import base
-from . import prune_packages
+from . import _operations
-def validate_repo_url(url: str) -> None:
+def sanitize_repo_url(url: str) -> str:
try:
parsed = urlparse(url)
except:
@@ -55,6 +62,11 @@ def validate_repo_url(url: str) -> None:
if parsed.scheme not in ('http', 'https'):
raise st.RepoUrlInvalid()
+ if url[-1] != '/':
+ url = url + '/'
+
+ return url
+
def ensure_repo_not_deleted(cursor: sqlite3.Cursor, repo_id: str) -> None:
cursor.execute(
@@ -73,6 +85,53 @@ def ensure_repo_not_deleted(cursor: sqlite3.Cursor, repo_id: str) -> None:
raise st.MissingItemError()
+def sync_remote_repo_definitions(repo_url: str, dest: Path) -> None:
+ try:
+ list_all_response = requests.get(urljoin(repo_url, 'list_all'))
+ assert list_all_response.ok
+
+ list_instance = list_all_response.json()
+ except:
+ raise st.RepoCommunicationError()
+
+ try:
+ json_instances.validate_instance(
+ list_instance,
+ 'api_package_list-{}.schema.json'
+ )
+ except json_instances.UnknownSchemaError:
+ raise st.RepoApiVersionUnsupported()
+ except:
+ raise st.RepoCommunicationError()
+
+ ref: dict[str, t.Any]
+
+ for item_type_name in ('resource', 'mapping'):
+ for ref in list_instance[item_type_name + 's']:
+ ver = versions.version_string(versions.normalize(ref['version']))
+ item_rel_path = f'{item_type_name}/{ref["identifier"]}/{ver}'
+
+ try:
+ item_response = requests.get(urljoin(repo_url, item_rel_path))
+ assert item_response.ok
+ except:
+ raise st.RepoCommunicationError()
+
+ item_path = dest / item_rel_path
+ item_path.parent.mkdir(parents=True, exist_ok=True)
+ item_path.write_bytes(item_response.content)
+
+
+@dc.dataclass(frozen=True)
+class RemoteFileResolver(_operations.FileResolver):
+ repo_url: str
+
+ def by_sha256(self, sha256: str) -> bytes:
+ response = requests.get(urljoin(self.repo_url, f'file/sha256/{sha256}'))
+ assert response.ok
+ return response.content
+
+
def make_repo_display_info(
ref: st.RepoRef,
name: str,
@@ -122,6 +181,37 @@ class ConcreteRepoRef(st.RepoRef):
(self.id,)
)
+ _operations.prune_packages(cursor)
+
+ # For mappings explicitly enabled by the user (+ all mappings they
+ # recursively depend on) let's make sure that their exact same
+ # versions will be enabled after the change.
+ cursor.execute(
+ '''
+ SELECT
+ iv.definition, r.name, ri.iteration
+ FROM
+ mapping_statuses AS ms
+ JOIN item_versions AS iv
+ ON ms.active_version_id = iv.item_version_id
+ JOIN repo_iterations AS ri
+ USING (repo_iteration_id)
+ JOIN repos AS r
+ USING (repo_id)
+ WHERE
+ ms.required
+ '''
+ )
+
+ requirements = []
+
+ for definition, repo, iteration in cursor.fetchall():
+ info = item_infos.MappingInfo.load(definition, repo, iteration)
+ req = sds.MappingVersionRequirement(info.identifier, info)
+ requirements.append(req)
+
+ self.state.recompute_dependencies(requirements)
+
def update(
self,
*,
@@ -134,7 +224,7 @@ class ConcreteRepoRef(st.RepoRef):
if url is None:
return
- validate_repo_url(url)
+ url = sanitize_repo_url(url)
with self.state.cursor(transaction=True) as cursor:
ensure_repo_not_deleted(cursor, self.id)
@@ -144,10 +234,30 @@ class ConcreteRepoRef(st.RepoRef):
(url, self.id)
)
- prune_packages.prune(cursor)
-
def refresh(self) -> st.RepoIterationRef:
- raise NotImplementedError()
+ with self.state.cursor(transaction=True) as cursor:
+ ensure_repo_not_deleted(cursor, self.id)
+
+ cursor.execute(
+ 'SELECT url from repos where repo_id = ?;',
+ (self.id,)
+ )
+
+ (repo_url,), = cursor.fetchall()
+
+ with tempfile.TemporaryDirectory() as tmpdir_str:
+ tmpdir = Path(tmpdir_str)
+ sync_remote_repo_definitions(repo_url, tmpdir)
+ new_iteration_id = _operations.load_packages(
+ cursor,
+ tmpdir,
+ int(self.id),
+ RemoteFileResolver(repo_url)
+ )
+
+ self.state.recompute_dependencies()
+
+ return ConcreteRepoIterationRef(str(new_iteration_id), self.state)
def get_display_info(self) -> st.RepoDisplayInfo:
with self.state.cursor() as cursor:
@@ -199,7 +309,7 @@ class ConcreteRepoStore(st.RepoStore):
if repo_name_regex.match(name) is None:
raise st.RepoNameInvalid()
- validate_repo_url(url)
+ url = sanitize_repo_url(url)
with self.state.cursor(transaction=True) as cursor:
cursor.execute(
@@ -272,3 +382,8 @@ class ConcreteRepoStore(st.RepoStore):
result.append(make_repo_display_info(ref, *rest))
return result
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteRepoIterationRef(st.RepoIterationRef):
+ state: base.HaketiloStateWithFields
diff --git a/src/hydrilla/proxy/tables.sql b/src/hydrilla/proxy/tables.sql
index fc7c65c..0c604fc 100644
--- a/src/hydrilla/proxy/tables.sql
+++ b/src/hydrilla/proxy/tables.sql
@@ -132,7 +132,9 @@ CREATE TABLE items(
);
CREATE TABLE mapping_statuses(
- -- The item with this id shall be a mapping ("type" = 'M').
+ -- The item with this id shall be a mapping ("type" = 'M'). For each
+ -- mapping row in "items" there must be an accompanying row in this
+ -- table.
item_id INTEGER PRIMARY KEY,
-- "enabled" determines whether mapping's status is ENABLED,
@@ -142,19 +144,32 @@ CREATE TABLE mapping_statuses(
-- EXACT_VERSION, is to be updated only with versions from the same
-- REPOSITORY or is NOT_FROZEN at all.
frozen CHAR(1) NULL,
+ -- The last 2 fields defined below shall be updated when dependency tree
+ -- is recomputed.
+ -- When "required" is TRUE, the mapping is assumed to either be enabled
+ -- or be (directly or indirectly) required by another mapping which is
+ -- enabled (i.e. has "enabled" set to 'E').
+ required BOOLEAN NOT NULL,
+ -- Only one version of a mapping is allowed to be active at any time.
+ -- "active_version_id" indicates which version it is. Only a mapping
+ -- version referenced by "active_version_id" is allowed to have rows
+ -- in the "payloads" table reference it.
active_version_id INTEGER NULL,
CHECK (enabled IN ('E', 'D', 'N')),
CHECK ((frozen IS NULL) = (enabled != 'E')),
CHECK (frozen IS NULL OR frozen in ('E', 'R', 'N')),
- CHECK (enabled != 'E' OR active_version_id IS NOT NULL)
- CHECK (enabled != 'D' OR active_version_id IS NULL)
+ CHECK (enabled != 'E' OR required),
+ CHECK (enabled != 'D' OR NOT required),
+ CHECK (not required OR active_version_id IS NOT NULL),
+ CHECK (enabled != 'D' OR active_version_id IS NULL),
FOREIGN KEY (item_id)
REFERENCES items (item_id)
ON DELETE CASCADE,
FOREIGN KEY (active_version_id, item_id)
REFERENCES item_versions (item_version_id, item_id)
+ ON DELETE SET NULL
);
CREATE TABLE item_versions(
@@ -163,7 +178,7 @@ CREATE TABLE item_versions(
item_id INTEGER NOT NULL,
version VARCHAR NOT NULL,
repo_iteration_id INTEGER NOT NULL,
- definition TEXT NOT NULL,
+ definition BLOB NOT NULL,
UNIQUE (item_id, version, repo_iteration_id),
-- Constraint below needed to allow foreign key from "mapping_statuses".
@@ -189,8 +204,10 @@ FROM
GROUP BY
r.repo_id, r.name, r.url, r.deleted, r.last_refreshed;
--- Every time a repository gets refreshed, or a mapping gets enabled/disabled,
--- all dependencies the "payloads" table and those that reference it are
+-- Every time a repository gets refreshed or a mapping gets enabled/disabled,
+-- the dependency tree is recomputed. In the process the "payloads" table gets
+-- cleare and repopulated together with the "resolved_depended_resources" that
+-- depends on it.
CREATE TABLE payloads(
payload_id INTEGER PRIMARY KEY,
diff --git a/src/hydrilla/server/malcontent.py b/src/hydrilla/server/malcontent.py
index 49c0fb4..ce24330 100644
--- a/src/hydrilla/server/malcontent.py
+++ b/src/hydrilla/server/malcontent.py
@@ -79,7 +79,7 @@ class VersionedItemInfo(
Find and return info of the specified version of the item (or None if
absent).
"""
- return self.items.get(versions.normalize_version(ver))
+ return self.items.get(versions.normalize(ver))
def get_all(self) -> t.Iterable[VersionedType]:
"""Generate item info for all its versions, from oldest to newest."""
diff --git a/src/hydrilla/versions.py b/src/hydrilla/versions.py
index 93f395d..12d1c18 100644
--- a/src/hydrilla/versions.py
+++ b/src/hydrilla/versions.py
@@ -36,7 +36,7 @@ import typing as t
VerTuple = t.NewType('VerTuple', 'tuple[int, ...]')
-def normalize_version(ver: t.Sequence[int]) -> VerTuple:
+def normalize(ver: t.Sequence[int]) -> VerTuple:
"""Strip rightmost zeroes from 'ver'."""
new_len = 0
for i, num in enumerate(ver):
@@ -57,7 +57,7 @@ def parse_normalize(ver_str: str) -> VerTuple:
Convert 'ver_str' into a VerTuple representation, e.g. for
ver_str="4.6.13.0" return (4, 6, 13).
"""
- return normalize_version(parse(ver_str))
+ return normalize(parse(ver_str))
def version_string(ver: VerTuple, rev: t.Optional[int] = None) -> str:
"""