aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/state_impl
diff options
context:
space:
mode:
Diffstat (limited to 'src/hydrilla/proxy/state_impl')
-rw-r--r--src/hydrilla/proxy/state_impl/_operations/__init__.py9
-rw-r--r--src/hydrilla/proxy/state_impl/_operations/load_packages.py (renamed from src/hydrilla/proxy/state_impl/load_packages.py)124
-rw-r--r--src/hydrilla/proxy/state_impl/_operations/prune_packages.py (renamed from src/hydrilla/proxy/state_impl/prune_packages.py)65
-rw-r--r--src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py223
-rw-r--r--src/hydrilla/proxy/state_impl/base.py219
-rw-r--r--src/hydrilla/proxy/state_impl/concrete_state.py322
-rw-r--r--src/hydrilla/proxy/state_impl/mappings.py4
-rw-r--r--src/hydrilla/proxy/state_impl/repos.py135
8 files changed, 711 insertions, 390 deletions
diff --git a/src/hydrilla/proxy/state_impl/_operations/__init__.py b/src/hydrilla/proxy/state_impl/_operations/__init__.py
new file mode 100644
index 0000000..c147be4
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl/_operations/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-License-Identifier: CC0-1.0
+
+# Copyright (C) 2022 Wojtek Kosior <koszko@koszko.org>
+#
+# Available under the terms of Creative Commons Zero v1.0 Universal.
+
+from .load_packages import load_packages, FileResolver
+from .prune_packages import prune_packages
+from .recompute_dependencies import _recompute_dependencies_no_state_update
diff --git a/src/hydrilla/proxy/state_impl/load_packages.py b/src/hydrilla/proxy/state_impl/_operations/load_packages.py
index 6983c3e..c294ef0 100644
--- a/src/hydrilla/proxy/state_impl/load_packages.py
+++ b/src/hydrilla/proxy/state_impl/_operations/load_packages.py
@@ -37,37 +37,49 @@ import dataclasses as dc
import typing as t
from pathlib import Path
+from abc import ABC, abstractmethod
import sqlite3
-from ...exceptions import HaketiloException
-from ...translations import smart_gettext as _
-from ... import versions
-from ... import item_infos
-from . import base
+from ....exceptions import HaketiloException
+from ....translations import smart_gettext as _
+from .... import versions
+from .... import item_infos
-def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int:
+def make_repo_iteration(cursor: sqlite3.Cursor, repo_id: int) -> int:
cursor.execute(
'''
SELECT
- repo_id, next_iteration - 1
+ next_iteration
FROM
repos
WHERE
- name = ?;
+ repo_id = ?;
''',
- (repo_name,)
+ (repo_id,)
)
- (repo_id, last_iteration), = cursor.fetchall()
+ (next_iteration,), = cursor.fetchall()
cursor.execute(
'''
- INSERT OR IGNORE INTO repo_iterations(repo_id, iteration)
+ UPDATE
+ repos
+ SET
+ next_iteration = ?
+ WHERE
+ repo_id = ?;
+ ''',
+ (next_iteration + 1, repo_id)
+ )
+
+ cursor.execute(
+ '''
+ INSERT INTO repo_iterations(repo_id, iteration)
VALUES(?, ?);
''',
- (repo_id, last_iteration)
+ (repo_id, next_iteration)
)
cursor.execute(
@@ -79,7 +91,7 @@ def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int:
WHERE
repo_id = ? AND iteration = ?;
''',
- (repo_id, last_iteration)
+ (repo_id, next_iteration)
)
(repo_iteration_id,), = cursor.fetchall()
@@ -118,7 +130,7 @@ def get_or_make_item_version(
item_id: int,
repo_iteration_id: int,
version: versions.VerTuple,
- definition: str
+ definition: bytes
) -> int:
ver_str = versions.version_string(version)
@@ -154,8 +166,8 @@ def get_or_make_item_version(
def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None:
cursor.execute(
'''
- INSERT OR IGNORE INTO mapping_statuses(item_id, enabled, frozen)
- VALUES(?, 'E', 'R');
+ INSERT OR IGNORE INTO mapping_statuses(item_id, enabled, required)
+ VALUES(?, 'N', FALSE);
''',
(item_id,)
)
@@ -215,14 +227,18 @@ class _FileInfo:
id: int
is_ascii: bool
+class FileResolver(ABC):
+ @abstractmethod
+ def by_sha256(self, sha256: str) -> bytes:
+ ...
+
def _add_item(
- cursor: sqlite3.Cursor,
- files_by_sha256_path: Path,
- info: item_infos.AnyInfo,
- definition: str
+ cursor: sqlite3.Cursor,
+ package_file_resolver: FileResolver,
+ info: item_infos.AnyInfo,
+ definition: bytes,
+ repo_iteration_id: int
) -> None:
- repo_iteration_id = get_or_make_repo_iteration(cursor, '<local>')
-
item_id = get_or_make_item(cursor, info.type_name, info.identifier)
item_version_id = get_or_make_item_version(
@@ -243,17 +259,7 @@ def _add_item(
file_specifiers.extend(info.scripts)
for file_spec in file_specifiers:
- file_path = files_by_sha256_path / file_spec.sha256
- if not file_path.is_file():
- fmt = _('err.proxy.file_missing_{item_identifier}_{file_name}_{sha256}')
- msg = fmt.format(
- item_identifier = info.identifier,
- file_name = file_spec.name,
- sha256 = file_spec.sha256
- )
- raise HaketiloException(msg)
-
- file_bytes = file_path.read_bytes()
+ file_bytes = package_file_resolver.by_sha256(file_spec.sha256)
sha256 = hashlib.sha256(file_bytes).digest().hex()
if sha256 != file_spec.sha256:
@@ -309,7 +315,7 @@ AnyInfoVar = t.TypeVar(
)
def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
- -> t.Iterator[tuple[AnyInfoVar, str]]:
+ -> t.Iterator[tuple[AnyInfoVar, bytes]]:
item_type_path = malcontent_path / item_class.type_name
if not item_type_path.is_dir():
return
@@ -319,8 +325,8 @@ def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
continue
for item_version_path in item_path.iterdir():
- definition = item_version_path.read_text()
- item_info = item_class.load(io.StringIO(definition))
+ definition = item_version_path.read_bytes()
+ item_info = item_class.load(definition)
assert item_info.identifier == item_path.name
assert versions.version_string(item_info.version) == \
@@ -328,17 +334,45 @@ def _read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
yield item_info, definition
+@dc.dataclass(frozen=True)
+class MalcontentFileResolver(FileResolver):
+ malcontent_dir_path: Path
+
+ def by_sha256(self, sha256: str) -> bytes:
+ file_path = self.malcontent_dir_path / 'file' / 'sha256' / sha256
+ if not file_path.is_file():
+ fmt = _('err.proxy.file_missing_{sha256}')
+ raise HaketiloException(fmt.format(sha256=sha256))
+
+ return file_path.read_bytes()
+
def load_packages(
- state: base.HaketiloStateWithFields,
- cursor: sqlite3.Cursor,
- malcontent_path: Path
-) -> None:
- files_by_sha256_path = malcontent_path / 'file' / 'sha256'
+ cursor: sqlite3.Cursor,
+ malcontent_path: Path,
+ repo_id: int,
+ package_file_resolver: t.Optional[FileResolver] = None
+) -> int:
+ if package_file_resolver is None:
+ package_file_resolver = MalcontentFileResolver(malcontent_path)
- for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]:
+ repo_iteration_id = make_repo_iteration(cursor, repo_id)
+
+ types: t.Iterable[t.Type[item_infos.AnyInfo]] = \
+ [item_infos.ResourceInfo, item_infos.MappingInfo]
+
+ for info_type in types:
info: item_infos.AnyInfo
- for info, definition in _read_items(
+
+ for info, definition in _read_items( # type: ignore
malcontent_path,
- info_type # type: ignore
+ info_type
):
- _add_item(cursor, files_by_sha256_path, info, definition)
+ _add_item(
+ cursor,
+ package_file_resolver,
+ info,
+ definition,
+ repo_iteration_id
+ )
+
+ return repo_iteration_id
diff --git a/src/hydrilla/proxy/state_impl/prune_packages.py b/src/hydrilla/proxy/state_impl/_operations/prune_packages.py
index 1857188..9c2b1d7 100644
--- a/src/hydrilla/proxy/state_impl/prune_packages.py
+++ b/src/hydrilla/proxy/state_impl/_operations/prune_packages.py
@@ -33,27 +33,42 @@ from __future__ import annotations
import sqlite3
+from pathlib import Path
-_remove_mapping_versions_sql = '''
-WITH removed_mappings AS (
- SELECT
- iv.item_version_id
- FROM
- item_versions AS iv
- JOIN items AS i
- USING (item_id)
- JOIN orphan_iterations AS oi
- USING (repo_iteration_id)
- LEFT JOIN payloads AS p
- ON p.mapping_item_id = iv.item_version_id
- WHERE
- i.type = 'M' AND p.payload_id IS NULL
-)
-DELETE FROM
- item_versions
-WHERE
- item_version_id IN removed_mappings;
-'''
+
+_remove_mapping_versions_sqls = [
+ '''
+ CREATE TEMPORARY TABLE removed_mappings(
+ item_version_id INTEGER PRIMARY KEY
+ );
+ ''', '''
+ INSERT INTO
+ removed_mappings
+ SELECT
+ iv.item_version_id
+ FROM
+ item_versions AS iv
+ JOIN items AS i USING (item_id)
+ JOIN mapping_statuses AS ms USING (item_id)
+ JOIN orphan_iterations AS oi USING (repo_iteration_id)
+ WHERE
+ NOT ms.required;
+ ''', '''
+ UPDATE
+ mapping_statuses
+ SET
+ active_version_id = NULL
+ WHERE
+ active_version_id IN removed_mappings;
+ ''', '''
+ DELETE FROM
+ item_versions
+ WHERE
+ item_version_id IN removed_mappings;
+ ''', '''
+ DROP TABLE removed_mappings;
+ '''
+]
_remove_resource_versions_sql = '''
WITH removed_resources AS (
@@ -134,7 +149,7 @@ WITH removed_repos AS (
repos AS r
LEFT JOIN repo_iterations AS ri USING (repo_id)
WHERE
- r.deleted AND ri.repo_iteration_id IS NULL
+ r.deleted AND ri.repo_iteration_id IS NULL AND r.repo_id != 1
)
DELETE FROM
repos
@@ -142,9 +157,11 @@ WHERE
repo_id IN removed_repos;
'''
-def prune(cursor: sqlite3.Cursor) -> None:
- """...."""
- cursor.execute(_remove_mapping_versions_sql)
+def prune_packages(cursor: sqlite3.Cursor) -> None:
+ assert cursor.connection.in_transaction
+
+ for sql in _remove_mapping_versions_sqls:
+ cursor.execute(sql)
cursor.execute(_remove_resource_versions_sql)
cursor.execute(_remove_items_sql)
cursor.execute(_remove_files_sql)
diff --git a/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py b/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py
new file mode 100644
index 0000000..4093f12
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py
@@ -0,0 +1,223 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy data and configuration (update of dependency tree in the db).
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+....
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import typing as t
+
+import sqlite3
+
+from .... import item_infos
+from ... import simple_dependency_satisfying as sds
+
+
+AnyInfoVar = t.TypeVar(
+ 'AnyInfoVar',
+ item_infos.ResourceInfo,
+ item_infos.MappingInfo
+)
+
+def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \
+ -> t.Mapping[int, AnyInfoVar]:
+ cursor.execute(
+ '''
+ SELECT
+ i.item_id, iv.definition, r.name, ri.iteration
+ FROM
+ item_versions AS iv
+ JOIN items AS i USING (item_id)
+ JOIN repo_iterations AS ri USING (repo_iteration_id)
+ JOIN repos AS r USING (repo_id)
+ WHERE
+ i.type = ?;
+ ''',
+ (info_type.type_name[0].upper(),)
+ )
+
+ result: dict[int, AnyInfoVar] = {}
+
+ for item_id, definition, repo_name, repo_iteration in cursor.fetchall():
+ info = info_type.load(definition, repo_name, repo_iteration)
+ result[item_id] = info
+
+ return result
+
+def _recompute_dependencies_no_state_update(
+ cursor: sqlite3.Cursor,
+ extra_requirements: t.Iterable[sds.MappingRequirement]
+) -> None:
+ cursor.execute('DELETE FROM payloads;')
+
+ ids_to_resources = get_infos_of_type(cursor, item_infos.ResourceInfo)
+ ids_to_mappings = get_infos_of_type(cursor, item_infos.MappingInfo)
+
+ resources = ids_to_resources.items()
+ resources_to_ids = dict((info.identifier, id) for id, info in resources)
+
+ mappings = ids_to_mappings.items()
+ mappings_to_ids = dict((info.identifier, id) for id, info in mappings)
+
+ requirements = [*extra_requirements]
+
+ cursor.execute(
+ '''
+ SELECT
+ i.identifier
+ FROM
+ mapping_statuses AS ms
+ JOIN items AS i USING(item_id)
+ WHERE
+ ms.enabled = 'E' AND ms.frozen = 'N';
+ '''
+ )
+
+ for mapping_identifier, in cursor.fetchall():
+ requirements.append(sds.MappingRequirement(mapping_identifier))
+
+ cursor.execute(
+ '''
+ SELECT
+ active_version_id, frozen
+ FROM
+ mapping_statuses
+ WHERE
+ enabled = 'E' AND frozen IN ('R', 'E');
+ '''
+ )
+
+ for active_version_id, frozen in cursor.fetchall():
+ info = ids_to_mappings[active_version_id]
+
+ requirement: sds.MappingRequirement
+
+ if frozen == 'R':
+ requirement = sds.MappingRepoRequirement(info.identifier, info.repo)
+ else:
+ requirement = sds.MappingVersionRequirement(info.identifier, info)
+
+ requirements.append(requirement)
+
+ mapping_choices = sds.compute_payloads(
+ ids_to_resources.values(),
+ ids_to_mappings.values(),
+ requirements
+ )
+
+ cursor.execute(
+ '''
+ UPDATE
+ mapping_statuses
+ SET
+ required = FALSE,
+ active_version_id = NULL
+ WHERE
+ enabled != 'E';
+ '''
+ )
+
+ cursor.execute('DELETE FROM payloads;')
+
+ for choice in mapping_choices.values():
+ mapping_ver_id = mappings_to_ids[choice.info.identifier]
+
+ cursor.execute(
+ '''
+ SELECT
+ item_id
+ FROM
+ item_versions
+ WHERE
+ item_version_id = ?;
+ ''',
+ (mapping_ver_id,)
+ )
+
+ (mapping_item_id,), = cursor.fetchall()
+
+ cursor.execute(
+ '''
+ UPDATE
+ mapping_statuses
+ SET
+ required = ?,
+ active_version_id = ?
+ WHERE
+ item_id = ?;
+ ''',
+ (choice.required, mapping_ver_id, mapping_item_id)
+ )
+
+ for num, (pattern, payload) in enumerate(choice.payloads.items()):
+ cursor.execute(
+ '''
+ INSERT INTO payloads(
+ mapping_item_id,
+ pattern,
+ eval_allowed,
+ cors_bypass_allowed
+ )
+ VALUES (?, ?, ?, ?);
+ ''',
+ (
+ mapping_ver_id,
+ pattern.orig_url,
+ payload.allows_eval,
+ payload.allows_cors_bypass
+ )
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ payload_id
+ FROM
+ payloads
+ WHERE
+ mapping_item_id = ? AND pattern = ?;
+ ''',
+ (mapping_ver_id, pattern.orig_url)
+ )
+
+ (payload_id,), = cursor.fetchall()
+
+ for res_num, resource_info in enumerate(payload.resources):
+ resource_ver_id = resources_to_ids[resource_info.identifier]
+ cursor.execute(
+ '''
+ INSERT INTO resolved_depended_resources(
+ payload_id,
+ resource_item_id,
+ idx
+ )
+ VALUES(?, ?, ?);
+ ''',
+ (payload_id, resource_ver_id, res_num)
+ )
diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py
index 1ae08cb..92833dd 100644
--- a/src/hydrilla/proxy/state_impl/base.py
+++ b/src/hydrilla/proxy/state_impl/base.py
@@ -34,34 +34,160 @@ subtype.
from __future__ import annotations
import sqlite3
+import secrets
import threading
import dataclasses as dc
import typing as t
from pathlib import Path
from contextlib import contextmanager
-
-import sqlite3
+from abc import abstractmethod
from immutables import Map
+from ... import url_patterns
from ... import pattern_tree
-from .. import state
+from .. import simple_dependency_satisfying as sds
+from .. import state as st
from .. import policies
PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
-PayloadsData = t.Mapping[state.PayloadRef, state.PayloadData]
+PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData]
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcretePayloadRef(st.PayloadRef):
+ state: 'HaketiloStateWithFields' = dc.field(hash=False, compare=False)
+
+ def get_data(self) -> st.PayloadData:
+ try:
+ return self.state.payloads_data[self]
+ except KeyError:
+ raise st.MissingItemError()
+
+ def get_mapping(self) -> st.MappingVersionRef:
+ raise NotImplementedError()
+
+ def get_script_paths(self) \
+ -> t.Iterable[t.Sequence[str]]:
+ with self.state.cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ i.identifier, fu.name
+ FROM
+ payloads AS p
+ LEFT JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ LEFT JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ LEFT JOIN items AS i
+ USING (item_id)
+ LEFT JOIN file_uses AS fu
+ USING (item_version_id)
+ WHERE
+ fu.type = 'W' AND
+ p.payload_id = ? AND
+ (fu.idx IS NOT NULL OR rdd.idx IS NULL)
+ ORDER BY
+ rdd.idx, fu.idx;
+ ''',
+ (self.id,)
+ )
+
+ paths: list[t.Sequence[str]] = []
+ for resource_identifier, file_name in cursor.fetchall():
+ if resource_identifier is None:
+ # payload found but it had no script files
+ return ()
+
+ paths.append((resource_identifier, *file_name.split('/')))
+
+ if paths == []:
+ # payload not found
+ raise st.MissingItemError()
+
+ return paths
+
+ def get_file_data(self, path: t.Sequence[str]) \
+ -> t.Optional[st.FileData]:
+ if len(path) == 0:
+ raise st.MissingItemError()
+
+ resource_identifier, *file_name_segments = path
+
+ file_name = '/'.join(file_name_segments)
+
+ with self.state.cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ f.data, fu.mime_type
+ FROM
+ payloads AS p
+ JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ JOIN items AS i
+ USING (item_id)
+ JOIN file_uses AS fu
+ USING (item_version_id)
+ JOIN files AS f
+ USING (file_id)
+ WHERE
+ p.payload_id = ? AND
+ i.identifier = ? AND
+ fu.name = ? AND
+ fu.type = 'W';
+ ''',
+ (self.id, resource_identifier, file_name)
+ )
+
+ result = cursor.fetchall()
+
+ if result == []:
+ return None
+
+ (data, mime_type), = result
+
+ return st.FileData(type=mime_type, name=file_name, contents=data)
+
+def register_payload(
+ policy_tree: PolicyTree,
+ pattern: url_patterns.ParsedPattern,
+ payload_key: st.PayloadKey,
+ token: str
+) -> PolicyTree:
+ """...."""
+ payload_policy_factory = policies.PayloadPolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(pattern, payload_policy_factory)
+
+ resource_policy_factory = policies.PayloadResourcePolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(
+ pattern.path_append(token, '***'),
+ resource_policy_factory
+ )
+
+ return policy_tree
# mypy needs to be corrected:
# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
@dc.dataclass # type: ignore[misc]
-class HaketiloStateWithFields(state.HaketiloState):
+class HaketiloStateWithFields(st.HaketiloState):
"""...."""
store_dir: Path
connection: sqlite3.Connection
current_cursor: t.Optional[sqlite3.Cursor] = None
- #settings: state.HaketiloGlobalSettings
+ #settings: st.HaketiloGlobalSettings
policy_tree: PolicyTree = PolicyTree()
payloads_data: PayloadsData = dc.field(default_factory=dict)
@@ -98,6 +224,85 @@ class HaketiloStateWithFields(state.HaketiloState):
finally:
self.current_cursor = None
- def recompute_payloads(self, cursor: sqlite3.Cursor) -> None:
+ def rebuild_structures(self) -> None:
+ """
+ Recreation of data structures as done after every recomputation of
+ dependencies as well as at startup.
+ """
+ with self.cursor(transaction=True) as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ p.payload_id, p.pattern, p.eval_allowed,
+ p.cors_bypass_allowed,
+ ms.enabled,
+ i.identifier
+ FROM
+ payloads AS p
+ JOIN item_versions AS iv
+ ON p.mapping_item_id = iv.item_version_id
+ JOIN items AS i USING (item_id)
+ JOIN mapping_statuses AS ms USING (item_id);
+ '''
+ )
+
+ rows = cursor.fetchall()
+
+ new_policy_tree = PolicyTree()
+
+ ui_factory = policies.WebUIPolicyFactory(builtin=True)
+ web_ui_pattern = 'http*://hkt.mitm.it/***'
+ for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern):
+ new_policy_tree = new_policy_tree.register(
+ parsed_pattern,
+ ui_factory
+ )
+
+ new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
+
+ for row in rows:
+ (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
+ enabled_status,
+ identifier) = row
+
+ payload_ref = ConcretePayloadRef(str(payload_id_int), self)
+
+ previous_data = self.payloads_data.get(payload_ref)
+ if previous_data is not None:
+ token = previous_data.unique_token
+ else:
+ token = secrets.token_urlsafe(8)
+
+ payload_key = st.PayloadKey(payload_ref, identifier)
+
+ for parsed_pattern in url_patterns.parse_pattern(pattern):
+ new_policy_tree = register_payload(
+ new_policy_tree,
+ parsed_pattern,
+ payload_key,
+ token
+ )
+
+ pattern_path_segments = parsed_pattern.path_segments
+
+ payload_data = st.PayloadData(
+ payload_ref = payload_ref,
+ explicitly_enabled = enabled_status == 'E',
+ unique_token = token,
+ pattern_path_segments = pattern_path_segments,
+ eval_allowed = eval_allowed,
+ cors_bypass_allowed = cors_bypass_allowed
+ )
+
+ new_payloads_data[payload_ref] = payload_data
+
+ self.policy_tree = new_policy_tree
+ self.payloads_data = new_payloads_data
+
+ @abstractmethod
+ def recompute_dependencies(
+ self,
+ requirements: t.Iterable[sds.MappingRequirement] = []
+ ) -> None:
"""...."""
...
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py
index 525a702..6a59a75 100644
--- a/src/hydrilla/proxy/state_impl/concrete_state.py
+++ b/src/hydrilla/proxy/state_impl/concrete_state.py
@@ -32,28 +32,23 @@ and resources.
# Enable using with Python 3.7.
from __future__ import annotations
-import secrets
-import io
-import hashlib
+import sqlite3
import typing as t
import dataclasses as dc
from pathlib import Path
-import sqlite3
-
from ...exceptions import HaketiloException
from ...translations import smart_gettext as _
-from ... import pattern_tree
from ... import url_patterns
from ... import item_infos
-from ..simple_dependency_satisfying import compute_payloads, ComputedPayload
from .. import state as st
from .. import policies
+from .. import simple_dependency_satisfying as sds
from . import base
from . import mappings
from . import repos
-from .load_packages import load_packages
+from . import _operations
here = Path(__file__).resolve().parent
@@ -73,169 +68,14 @@ class ConcreteResourceRef(st.ResourceRef):
class ConcreteResourceVersionRef(st.ResourceVersionRef):
pass
-
-@dc.dataclass(frozen=True, unsafe_hash=True)
-class ConcretePayloadRef(st.PayloadRef):
- state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False)
-
- def get_data(self) -> st.PayloadData:
- try:
- return self.state.payloads_data[self]
- except KeyError:
- raise st.MissingItemError()
-
- def get_mapping(self) -> st.MappingVersionRef:
- raise NotImplementedError()
-
- def get_script_paths(self) \
- -> t.Iterable[t.Sequence[str]]:
- with self.state.cursor() as cursor:
- cursor.execute(
- '''
- SELECT
- i.identifier, fu.name
- FROM
- payloads AS p
- LEFT JOIN resolved_depended_resources AS rdd
- USING (payload_id)
- LEFT JOIN item_versions AS iv
- ON rdd.resource_item_id = iv.item_version_id
- LEFT JOIN items AS i
- USING (item_id)
- LEFT JOIN file_uses AS fu
- USING (item_version_id)
- WHERE
- fu.type = 'W' AND
- p.payload_id = ? AND
- (fu.idx IS NOT NULL OR rdd.idx IS NULL)
- ORDER BY
- rdd.idx, fu.idx;
- ''',
- (self.id,)
- )
-
- paths: list[t.Sequence[str]] = []
- for resource_identifier, file_name in cursor.fetchall():
- if resource_identifier is None:
- # payload found but it had no script files
- return ()
-
- paths.append((resource_identifier, *file_name.split('/')))
-
- if paths == []:
- # payload not found
- raise st.MissingItemError()
-
- return paths
-
- def get_file_data(self, path: t.Sequence[str]) \
- -> t.Optional[st.FileData]:
- if len(path) == 0:
- raise st.MissingItemError()
-
- resource_identifier, *file_name_segments = path
-
- file_name = '/'.join(file_name_segments)
-
- with self.state.cursor() as cursor:
- cursor.execute(
- '''
- SELECT
- f.data, fu.mime_type
- FROM
- payloads AS p
- JOIN resolved_depended_resources AS rdd
- USING (payload_id)
- JOIN item_versions AS iv
- ON rdd.resource_item_id = iv.item_version_id
- JOIN items AS i
- USING (item_id)
- JOIN file_uses AS fu
- USING (item_version_id)
- JOIN files AS f
- USING (file_id)
- WHERE
- p.payload_id = ? AND
- i.identifier = ? AND
- fu.name = ? AND
- fu.type = 'W';
- ''',
- (self.id, resource_identifier, file_name)
- )
-
- result = cursor.fetchall()
-
- if result == []:
- return None
-
- (data, mime_type), = result
-
- return st.FileData(type=mime_type, name=file_name, contents=data)
-
-def register_payload(
- policy_tree: base.PolicyTree,
- pattern: url_patterns.ParsedPattern,
- payload_key: st.PayloadKey,
- token: str
-) -> base.PolicyTree:
- """...."""
- payload_policy_factory = policies.PayloadPolicyFactory(
- builtin = False,
- payload_key = payload_key
- )
-
- policy_tree = policy_tree.register(pattern, payload_policy_factory)
-
- resource_policy_factory = policies.PayloadResourcePolicyFactory(
- builtin = False,
- payload_key = payload_key
- )
-
- policy_tree = policy_tree.register(
- pattern.path_append(token, '***'),
- resource_policy_factory
- )
-
- return policy_tree
-
-AnyInfoVar = t.TypeVar(
- 'AnyInfoVar',
- item_infos.ResourceInfo,
- item_infos.MappingInfo
-)
-
-def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \
- -> t.Mapping[AnyInfoVar, int]:
- cursor.execute(
- '''
- SELECT
- i.item_id, iv.definition, r.name, ri.iteration
- FROM
- item_versions AS iv
- JOIN items AS i USING (item_id)
- JOIN repo_iterations AS ri USING (repo_iteration_id)
- JOIN repos AS r USING (repo_id)
- WHERE
- i.type = ?;
- ''',
- (info_type.type_name[0].upper(),)
- )
-
- result: dict[AnyInfoVar, int] = {}
-
- for item_id, definition, repo_name, repo_iteration in cursor.fetchall():
- definition_io = io.StringIO(definition)
- info = info_type.load(definition_io, repo_name, repo_iteration)
- result[info] = item_id
-
- return result
-
@dc.dataclass
class ConcreteHaketiloState(base.HaketiloStateWithFields):
def __post_init__(self) -> None:
+ sqlite3.enable_callback_tracebacks(True)
+
self._prepare_database()
- self._rebuild_structures()
+ self.rebuild_structures()
def _prepare_database(self) -> None:
"""...."""
@@ -282,147 +122,25 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
def import_packages(self, malcontent_path: Path) -> None:
with self.cursor(transaction=True) as cursor:
- load_packages(self, cursor, malcontent_path)
- self.recompute_payloads(cursor)
-
- def recompute_payloads(self, cursor: sqlite3.Cursor) -> None:
- assert self.connection.in_transaction
-
- resources = get_infos_of_type(cursor, item_infos.ResourceInfo)
- mappings = get_infos_of_type(cursor, item_infos.MappingInfo)
-
- payloads = compute_payloads(resources.keys(), mappings.keys())
-
- payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
-
- cursor.execute('DELETE FROM payloads;')
-
- for mapping_info, by_pattern in payloads.items():
- for num, (pattern, payload) in enumerate(by_pattern.items()):
- cursor.execute(
- '''
- INSERT INTO payloads(
- mapping_item_id,
- pattern,
- eval_allowed,
- cors_bypass_allowed
- )
- VALUES (?, ?, ?, ?);
- ''',
- (
- mappings[mapping_info],
- pattern.orig_url,
- payload.allows_eval,
- payload.allows_cors_bypass
- )
- )
+ _operations.load_packages(cursor, malcontent_path, 1)
+ raise NotImplementedError()
+ _operations.prune_packages(cursor)
- cursor.execute(
- '''
- SELECT
- payload_id
- FROM
- payloads
- WHERE
- mapping_item_id = ? AND pattern = ?;
- ''',
- (mappings[mapping_info], pattern.orig_url)
- )
+ self.recompute_dependencies()
- (payload_id_int,), = cursor.fetchall()
-
- for res_num, resource_info in enumerate(payload.resources):
- cursor.execute(
- '''
- INSERT INTO resolved_depended_resources(
- payload_id,
- resource_item_id,
- idx
- )
- VALUES(?, ?, ?);
- ''',
- (payload_id_int, resources[resource_info], res_num)
- )
-
- self._rebuild_structures(cursor)
-
- def _rebuild_structures(self, cursor: t.Optional[sqlite3.Cursor] = None) \
- -> None:
- """
- Recreation of data structures as done after every recomputation of
- dependencies as well as at startup.
- """
- if cursor is None:
- with self.cursor() as new_cursor:
- return self._rebuild_structures(new_cursor)
-
- cursor.execute(
- '''
- SELECT
- p.payload_id, p.pattern, p.eval_allowed,
- p.cors_bypass_allowed,
- ms.enabled,
- i.identifier
- FROM
- payloads AS p
- JOIN item_versions AS iv
- ON p.mapping_item_id = iv.item_version_id
- JOIN items AS i USING (item_id)
- JOIN mapping_statuses AS ms USING (item_id);
- '''
- )
-
- new_policy_tree = base.PolicyTree()
-
- ui_factory = policies.WebUIPolicyFactory(builtin=True)
- web_ui_pattern = 'http*://hkt.mitm.it/***'
- for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern):
- new_policy_tree = new_policy_tree.register(
- parsed_pattern,
- ui_factory
- )
-
- new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
-
- for row in cursor.fetchall():
- (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
- enabled_status,
- identifier) = row
-
- payload_ref = ConcretePayloadRef(str(payload_id_int), self)
-
- previous_data = self.payloads_data.get(payload_ref)
- if previous_data is not None:
- token = previous_data.unique_token
- else:
- token = secrets.token_urlsafe(8)
-
- payload_key = st.PayloadKey(payload_ref, identifier)
-
- for parsed_pattern in url_patterns.parse_pattern(pattern):
- new_policy_tree = register_payload(
- new_policy_tree,
- parsed_pattern,
- payload_key,
- token
- )
-
- pattern_path_segments = parsed_pattern.path_segments
+ def recompute_dependencies(
+ self,
+ extra_requirements: t.Iterable[sds.MappingRequirement] = []
+ ) -> None:
+ with self.cursor() as cursor:
+ assert self.connection.in_transaction
- payload_data = st.PayloadData(
- payload_ref = payload_ref,
- #explicitly_enabled = enabled_status == 'E',
- explicitly_enabled = True,
- unique_token = token,
- pattern_path_segments = pattern_path_segments,
- eval_allowed = eval_allowed,
- cors_bypass_allowed = cors_bypass_allowed
+ _operations._recompute_dependencies_no_state_update(
+ cursor,
+ extra_requirements
)
- new_payloads_data[payload_ref] = payload_data
-
- self.policy_tree = new_policy_tree
- self.payloads_data = new_payloads_data
+ self.rebuild_structures()
def repo_store(self) -> st.RepoStore:
return repos.ConcreteRepoStore(self)
diff --git a/src/hydrilla/proxy/state_impl/mappings.py b/src/hydrilla/proxy/state_impl/mappings.py
index 3668784..cce2a36 100644
--- a/src/hydrilla/proxy/state_impl/mappings.py
+++ b/src/hydrilla/proxy/state_impl/mappings.py
@@ -76,7 +76,7 @@ class ConcreteMappingVersionRef(st.MappingVersionRef):
(status_letter, definition, repo, repo_iteration, is_orphan), = rows
item_info = item_infos.MappingInfo.load(
- io.StringIO(definition),
+ definition,
repo,
repo_iteration
)
@@ -120,7 +120,7 @@ class ConcreteMappingVersionStore(st.MappingVersionStore):
ref = ConcreteMappingVersionRef(str(item_version_id), self.state)
item_info = item_infos.MappingInfo.load(
- io.StringIO(definition),
+ definition,
repo,
repo_iteration
)
diff --git a/src/hydrilla/proxy/state_impl/repos.py b/src/hydrilla/proxy/state_impl/repos.py
index 5553ec2..f4c7c71 100644
--- a/src/hydrilla/proxy/state_impl/repos.py
+++ b/src/hydrilla/proxy/state_impl/repos.py
@@ -33,20 +33,27 @@ inside Haketilo.
from __future__ import annotations
import re
+import json
+import tempfile
+import requests
+import sqlite3
import typing as t
import dataclasses as dc
-from urllib.parse import urlparse
+from urllib.parse import urlparse, urljoin
from datetime import datetime
+from pathlib import Path
-import sqlite3
-
+from ... import json_instances
+from ... import item_infos
+from ... import versions
from .. import state as st
+from .. import simple_dependency_satisfying as sds
from . import base
-from . import prune_packages
+from . import _operations
-def validate_repo_url(url: str) -> None:
+def sanitize_repo_url(url: str) -> str:
try:
parsed = urlparse(url)
except:
@@ -55,6 +62,11 @@ def validate_repo_url(url: str) -> None:
if parsed.scheme not in ('http', 'https'):
raise st.RepoUrlInvalid()
+ if url[-1] != '/':
+ url = url + '/'
+
+ return url
+
def ensure_repo_not_deleted(cursor: sqlite3.Cursor, repo_id: str) -> None:
cursor.execute(
@@ -73,6 +85,53 @@ def ensure_repo_not_deleted(cursor: sqlite3.Cursor, repo_id: str) -> None:
raise st.MissingItemError()
+def sync_remote_repo_definitions(repo_url: str, dest: Path) -> None:
+ try:
+ list_all_response = requests.get(urljoin(repo_url, 'list_all'))
+ assert list_all_response.ok
+
+ list_instance = list_all_response.json()
+ except:
+ raise st.RepoCommunicationError()
+
+ try:
+ json_instances.validate_instance(
+ list_instance,
+ 'api_package_list-{}.schema.json'
+ )
+ except json_instances.UnknownSchemaError:
+ raise st.RepoApiVersionUnsupported()
+ except:
+ raise st.RepoCommunicationError()
+
+ ref: dict[str, t.Any]
+
+ for item_type_name in ('resource', 'mapping'):
+ for ref in list_instance[item_type_name + 's']:
+ ver = versions.version_string(versions.normalize(ref['version']))
+ item_rel_path = f'{item_type_name}/{ref["identifier"]}/{ver}'
+
+ try:
+ item_response = requests.get(urljoin(repo_url, item_rel_path))
+ assert item_response.ok
+ except:
+ raise st.RepoCommunicationError()
+
+ item_path = dest / item_rel_path
+ item_path.parent.mkdir(parents=True, exist_ok=True)
+ item_path.write_bytes(item_response.content)
+
+
+@dc.dataclass(frozen=True)
+class RemoteFileResolver(_operations.FileResolver):
+ repo_url: str
+
+ def by_sha256(self, sha256: str) -> bytes:
+ response = requests.get(urljoin(self.repo_url, f'file/sha256/{sha256}'))
+ assert response.ok
+ return response.content
+
+
def make_repo_display_info(
ref: st.RepoRef,
name: str,
@@ -122,6 +181,37 @@ class ConcreteRepoRef(st.RepoRef):
(self.id,)
)
+ _operations.prune_packages(cursor)
+
+ # For mappings explicitly enabled by the user (+ all mappings they
+ # recursively depend on) let's make sure that their exact same
+ # versions will be enabled after the change.
+ cursor.execute(
+ '''
+ SELECT
+ iv.definition, r.name, ri.iteration
+ FROM
+ mapping_statuses AS ms
+ JOIN item_versions AS iv
+ ON ms.active_version_id = iv.item_version_id
+ JOIN repo_iterations AS ri
+ USING (repo_iteration_id)
+ JOIN repos AS r
+ USING (repo_id)
+ WHERE
+ ms.required
+ '''
+ )
+
+ requirements = []
+
+ for definition, repo, iteration in cursor.fetchall():
+ info = item_infos.MappingInfo.load(definition, repo, iteration)
+ req = sds.MappingVersionRequirement(info.identifier, info)
+ requirements.append(req)
+
+ self.state.recompute_dependencies(requirements)
+
def update(
self,
*,
@@ -134,7 +224,7 @@ class ConcreteRepoRef(st.RepoRef):
if url is None:
return
- validate_repo_url(url)
+ url = sanitize_repo_url(url)
with self.state.cursor(transaction=True) as cursor:
ensure_repo_not_deleted(cursor, self.id)
@@ -144,10 +234,30 @@ class ConcreteRepoRef(st.RepoRef):
(url, self.id)
)
- prune_packages.prune(cursor)
-
def refresh(self) -> st.RepoIterationRef:
- raise NotImplementedError()
+ with self.state.cursor(transaction=True) as cursor:
+ ensure_repo_not_deleted(cursor, self.id)
+
+ cursor.execute(
+ 'SELECT url from repos where repo_id = ?;',
+ (self.id,)
+ )
+
+ (repo_url,), = cursor.fetchall()
+
+ with tempfile.TemporaryDirectory() as tmpdir_str:
+ tmpdir = Path(tmpdir_str)
+ sync_remote_repo_definitions(repo_url, tmpdir)
+ new_iteration_id = _operations.load_packages(
+ cursor,
+ tmpdir,
+ int(self.id),
+ RemoteFileResolver(repo_url)
+ )
+
+ self.state.recompute_dependencies()
+
+ return ConcreteRepoIterationRef(str(new_iteration_id), self.state)
def get_display_info(self) -> st.RepoDisplayInfo:
with self.state.cursor() as cursor:
@@ -199,7 +309,7 @@ class ConcreteRepoStore(st.RepoStore):
if repo_name_regex.match(name) is None:
raise st.RepoNameInvalid()
- validate_repo_url(url)
+ url = sanitize_repo_url(url)
with self.state.cursor(transaction=True) as cursor:
cursor.execute(
@@ -272,3 +382,8 @@ class ConcreteRepoStore(st.RepoStore):
result.append(make_repo_display_info(ref, *rest))
return result
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteRepoIterationRef(st.RepoIterationRef):
+ state: base.HaketiloStateWithFields