aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/state_impl
diff options
context:
space:
mode:
author Wojtek Kosior <koszko@koszko.org> 2022-08-25 11:53:14 +0200
committer Wojtek Kosior <koszko@koszko.org> 2022-09-28 12:54:53 +0200
commit 4dbbb2aec204a5cccc713e2e2098d6e0a47f8cf6 (patch)
tree a0e328e4200c78fbf2fbad621179ec3e436b97e3 /src/hydrilla/proxy/state_impl
parent 3a1da529bdba0c353b10f6fe2cf1024feb81f809 (diff)
download haketilo-hydrilla-4dbbb2aec204a5cccc713e2e2098d6e0a47f8cf6.tar.gz
haketilo-hydrilla-4dbbb2aec204a5cccc713e2e2098d6e0a47f8cf6.zip
[proxy] refactor state implementation
Diffstat (limited to 'src/hydrilla/proxy/state_impl')
-rw-r--r-- src/hydrilla/proxy/state_impl/base.py 267
-rw-r--r-- src/hydrilla/proxy/state_impl/concrete_state.py 130
-rw-r--r-- src/hydrilla/proxy/state_impl/mappings.py 3
-rw-r--r-- src/hydrilla/proxy/state_impl/payloads.py 137
-rw-r--r-- src/hydrilla/proxy/state_impl/repos.py 17
5 files changed, 309 insertions, 245 deletions
diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py
index 92833dd..25fd4c5 100644
--- a/src/hydrilla/proxy/state_impl/base.py
+++ b/src/hydrilla/proxy/state_impl/base.py
@@ -34,7 +34,6 @@ subtype.
from __future__ import annotations
import sqlite3
-import secrets
import threading
import dataclasses as dc
import typing as t
@@ -43,8 +42,6 @@ from pathlib import Path
from contextlib import contextmanager
from abc import abstractmethod
-from immutables import Map
-
from ... import url_patterns
from ... import pattern_tree
from .. import simple_dependency_satisfying as sds
@@ -52,132 +49,36 @@ from .. import state as st
from .. import policies
-PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
-PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData]
-
-@dc.dataclass(frozen=True, unsafe_hash=True)
-class ConcretePayloadRef(st.PayloadRef):
- state: 'HaketiloStateWithFields' = dc.field(hash=False, compare=False)
-
- def get_data(self) -> st.PayloadData:
- try:
- return self.state.payloads_data[self]
- except KeyError:
- raise st.MissingItemError()
-
- def get_mapping(self) -> st.MappingVersionRef:
- raise NotImplementedError()
-
- def get_script_paths(self) \
- -> t.Iterable[t.Sequence[str]]:
- with self.state.cursor() as cursor:
- cursor.execute(
- '''
- SELECT
- i.identifier, fu.name
- FROM
- payloads AS p
- LEFT JOIN resolved_depended_resources AS rdd
- USING (payload_id)
- LEFT JOIN item_versions AS iv
- ON rdd.resource_item_id = iv.item_version_id
- LEFT JOIN items AS i
- USING (item_id)
- LEFT JOIN file_uses AS fu
- USING (item_version_id)
- WHERE
- fu.type = 'W' AND
- p.payload_id = ? AND
- (fu.idx IS NOT NULL OR rdd.idx IS NULL)
- ORDER BY
- rdd.idx, fu.idx;
- ''',
- (self.id,)
- )
-
- paths: list[t.Sequence[str]] = []
- for resource_identifier, file_name in cursor.fetchall():
- if resource_identifier is None:
- # payload found but it had no script files
- return ()
-
- paths.append((resource_identifier, *file_name.split('/')))
-
- if paths == []:
- # payload not found
- raise st.MissingItemError()
-
- return paths
-
- def get_file_data(self, path: t.Sequence[str]) \
- -> t.Optional[st.FileData]:
- if len(path) == 0:
- raise st.MissingItemError()
-
- resource_identifier, *file_name_segments = path
-
- file_name = '/'.join(file_name_segments)
-
- with self.state.cursor() as cursor:
- cursor.execute(
- '''
- SELECT
- f.data, fu.mime_type
- FROM
- payloads AS p
- JOIN resolved_depended_resources AS rdd
- USING (payload_id)
- JOIN item_versions AS iv
- ON rdd.resource_item_id = iv.item_version_id
- JOIN items AS i
- USING (item_id)
- JOIN file_uses AS fu
- USING (item_version_id)
- JOIN files AS f
- USING (file_id)
- WHERE
- p.payload_id = ? AND
- i.identifier = ? AND
- fu.name = ? AND
- fu.type = 'W';
- ''',
- (self.id, resource_identifier, file_name)
- )
-
- result = cursor.fetchall()
-
- if result == []:
- return None
-
- (data, mime_type), = result
+@dc.dataclass(frozen=True)
+class PolicyTree(pattern_tree.PatternTree[policies.PolicyFactory]):
+ SelfType = t.TypeVar('SelfType', bound='PolicyTree')
- return st.FileData(type=mime_type, name=file_name, contents=data)
+ def register_payload(
+ self: 'SelfType',
+ pattern: url_patterns.ParsedPattern,
+ payload_key: st.PayloadKey,
+ token: str
+ ) -> 'SelfType':
+ payload_policy_factory = policies.PayloadPolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
-def register_payload(
- policy_tree: PolicyTree,
- pattern: url_patterns.ParsedPattern,
- payload_key: st.PayloadKey,
- token: str
-) -> PolicyTree:
- """...."""
- payload_policy_factory = policies.PayloadPolicyFactory(
- builtin = False,
- payload_key = payload_key
- )
+ policy_tree = self.register(pattern, payload_policy_factory)
- policy_tree = policy_tree.register(pattern, payload_policy_factory)
+ resource_policy_factory = policies.PayloadResourcePolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
- resource_policy_factory = policies.PayloadResourcePolicyFactory(
- builtin = False,
- payload_key = payload_key
- )
+ policy_tree = policy_tree.register(
+ pattern.path_append(token, '***'),
+ resource_policy_factory
+ )
- policy_tree = policy_tree.register(
- pattern.path_append(token, '***'),
- resource_policy_factory
- )
+ return policy_tree
- return policy_tree
+PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData]
# mypy needs to be corrected:
# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
@@ -187,6 +88,7 @@ class HaketiloStateWithFields(st.HaketiloState):
store_dir: Path
connection: sqlite3.Connection
current_cursor: t.Optional[sqlite3.Cursor] = None
+
#settings: st.HaketiloGlobalSettings
policy_tree: PolicyTree = PolicyTree()
@@ -224,85 +126,64 @@ class HaketiloStateWithFields(st.HaketiloState):
finally:
self.current_cursor = None
- def rebuild_structures(self) -> None:
- """
- Recreation of data structures as done after every recomputation of
- dependencies as well as at startup.
- """
- with self.cursor(transaction=True) as cursor:
- cursor.execute(
- '''
- SELECT
- p.payload_id, p.pattern, p.eval_allowed,
- p.cors_bypass_allowed,
- ms.enabled,
- i.identifier
- FROM
- payloads AS p
- JOIN item_versions AS iv
- ON p.mapping_item_id = iv.item_version_id
- JOIN items AS i USING (item_id)
- JOIN mapping_statuses AS ms USING (item_id);
- '''
- )
-
- rows = cursor.fetchall()
-
- new_policy_tree = PolicyTree()
+ def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy:
+ """...."""
+ with self.lock:
+ policy_tree = self.policy_tree
- ui_factory = policies.WebUIPolicyFactory(builtin=True)
- web_ui_pattern = 'http*://hkt.mitm.it/***'
- for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern):
- new_policy_tree = new_policy_tree.register(
- parsed_pattern,
- ui_factory
+ try:
+ best_priority: int = 0
+ best_policy: t.Optional[policies.Policy] = None
+
+ for factories_set in policy_tree.search(url):
+ for stored_factory in sorted(factories_set):
+ factory = stored_factory.item
+
+ policy = factory.make_policy(self)
+
+ if policy.priority > best_priority:
+ best_priority = policy.priority
+ best_policy = policy
+ except Exception as e:
+ return policies.ErrorBlockPolicy(
+ builtin = True,
+ error = e
)
- new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
-
- for row in rows:
- (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
- enabled_status,
- identifier) = row
-
- payload_ref = ConcretePayloadRef(str(payload_id_int), self)
-
- previous_data = self.payloads_data.get(payload_ref)
- if previous_data is not None:
- token = previous_data.unique_token
- else:
- token = secrets.token_urlsafe(8)
-
- payload_key = st.PayloadKey(payload_ref, identifier)
-
- for parsed_pattern in url_patterns.parse_pattern(pattern):
- new_policy_tree = register_payload(
- new_policy_tree,
- parsed_pattern,
- payload_key,
- token
- )
+ if best_policy is not None:
+ return best_policy
- pattern_path_segments = parsed_pattern.path_segments
+ if self.get_settings().default_allow_scripts:
+ return policies.FallbackAllowPolicy()
+ else:
+ return policies.FallbackBlockPolicy()
- payload_data = st.PayloadData(
- payload_ref = payload_ref,
- explicitly_enabled = enabled_status == 'E',
- unique_token = token,
- pattern_path_segments = pattern_path_segments,
- eval_allowed = eval_allowed,
- cors_bypass_allowed = cors_bypass_allowed
- )
-
- new_payloads_data[payload_ref] = payload_data
-
- self.policy_tree = new_policy_tree
- self.payloads_data = new_payloads_data
+ @abstractmethod
+ def import_items(self, malcontent_path: Path, repo_id: int = 1) -> None:
+ ...
@abstractmethod
def recompute_dependencies(
self,
- requirements: t.Iterable[sds.MappingRequirement] = []
+ requirements: t.Iterable[sds.MappingRequirement] = [],
+ prune_orphans: bool = False
) -> None:
"""...."""
...
+
+ @abstractmethod
+ def pull_missing_files(self) -> None:
+ """
+ This function checks which packages marked as installed are missing
+ files in the database. It attempts to restore integrity by downloading
+ the files from their respective repositories.
+ """
+ ...
+
+ @abstractmethod
+ def rebuild_structures(self) -> None:
+ """
+ Recreation of data structures as done after every recomputation of
+ dependencies as well as at startup.
+ """
+ ...
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py
index 9e56bff..0de67e0 100644
--- a/src/hydrilla/proxy/state_impl/concrete_state.py
+++ b/src/hydrilla/proxy/state_impl/concrete_state.py
@@ -33,6 +33,7 @@ and resources.
from __future__ import annotations
import sqlite3
+import secrets
import typing as t
import dataclasses as dc
@@ -48,6 +49,7 @@ from .. import simple_dependency_satisfying as sds
from . import base
from . import mappings
from . import repos
+from . import payloads
from . import _operations
@@ -120,23 +122,35 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
finally:
cursor.close()
- def import_packages(self, malcontent_path: Path) -> None:
- with self.cursor(transaction=True) as cursor:
+ def import_items(self, malcontent_path: Path, repo_id: int = 1) -> None:
+ with self.cursor(transaction=(repo_id == 1)) as cursor:
+ # This method without the repo_id argument exposed is part of the
+ # state API. As such, calls with repo_id = 1 (imports of local
+ # semirepo packages) create a new transaction. Calls with different
+ # values of repo_id are assumed to originate from within the state
+ # implementation code and expect an existing transaction. Here, we
+ # verify the transaction is indeed present.
+ assert self.connection.in_transaction
+
_operations._load_packages_no_state_update(
cursor = cursor,
malcontent_path = malcontent_path,
- repo_id = 1
+ repo_id = repo_id
)
self.rebuild_structures()
def recompute_dependencies(
self,
- extra_requirements: t.Iterable[sds.MappingRequirement] = []
+ extra_requirements: t.Iterable[sds.MappingRequirement] = [],
+ prune_orphans: bool = False,
) -> None:
with self.cursor() as cursor:
assert self.connection.in_transaction
+ if prune_orphans:
+ _operations.prune_packages(cursor)
+
_operations._recompute_dependencies_no_state_update(
cursor = cursor,
extra_requirements = extra_requirements
@@ -144,6 +158,82 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
self.rebuild_structures()
+ def pull_missing_files(self) -> None:
+ with self.cursor() as cursor:
+ assert self.connection.in_transaction
+
+ _operations.pull_missing_files(cursor)
+
+ def rebuild_structures(self) -> None:
+ with self.cursor(transaction=True) as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ p.payload_id, p.pattern, p.eval_allowed,
+ p.cors_bypass_allowed,
+ ms.enabled,
+ i.identifier
+ FROM
+ payloads AS p
+ JOIN item_versions AS iv
+ ON p.mapping_item_id = iv.item_version_id
+ JOIN items AS i USING (item_id)
+ JOIN mapping_statuses AS ms USING (item_id);
+ '''
+ )
+
+ rows = cursor.fetchall()
+
+ new_policy_tree = base.PolicyTree()
+
+ ui_factory = policies.WebUIPolicyFactory(builtin=True)
+ web_ui_pattern = 'http*://hkt.mitm.it/***'
+ for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern):
+ new_policy_tree = new_policy_tree.register(
+ parsed_pattern,
+ ui_factory
+ )
+
+ new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
+
+ for row in rows:
+ (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
+ enabled_status,
+ identifier) = row
+
+ payload_ref = payloads.ConcretePayloadRef(str(payload_id_int), self)
+
+ previous_data = self.payloads_data.get(payload_ref)
+ if previous_data is not None:
+ token = previous_data.unique_token
+ else:
+ token = secrets.token_urlsafe(8)
+
+ payload_key = st.PayloadKey(payload_ref, identifier)
+
+ for parsed_pattern in url_patterns.parse_pattern(pattern):
+ new_policy_tree = new_policy_tree.register_payload(
+ parsed_pattern,
+ payload_key,
+ token
+ )
+
+ pattern_path_segments = parsed_pattern.path_segments
+
+ payload_data = st.PayloadData(
+ payload_ref = payload_ref,
+ explicitly_enabled = enabled_status == 'E',
+ unique_token = token,
+ pattern_path_segments = pattern_path_segments,
+ eval_allowed = eval_allowed,
+ cors_bypass_allowed = cors_bypass_allowed
+ )
+
+ new_payloads_data[payload_ref] = payload_data
+
+ self.policy_tree = new_policy_tree
+ self.payloads_data = new_payloads_data
+
def repo_store(self) -> st.RepoStore:
return repos.ConcreteRepoStore(self)
@@ -182,38 +272,6 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
) -> None:
raise NotImplementedError()
- def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy:
- """...."""
- with self.lock:
- policy_tree = self.policy_tree
-
- try:
- best_priority: int = 0
- best_policy: t.Optional[policies.Policy] = None
-
- for factories_set in policy_tree.search(url):
- for stored_factory in sorted(factories_set):
- factory = stored_factory.item
-
- policy = factory.make_policy(self)
-
- if policy.priority > best_priority:
- best_priority = policy.priority
- best_policy = policy
- except Exception as e:
- return policies.ErrorBlockPolicy(
- builtin = True,
- error = e
- )
-
- if best_policy is not None:
- return best_policy
-
- if self.get_settings().default_allow_scripts:
- return policies.FallbackAllowPolicy()
- else:
- return policies.FallbackBlockPolicy()
-
@staticmethod
def make(store_dir: Path) -> 'ConcreteHaketiloState':
connection = sqlite3.connect(
diff --git a/src/hydrilla/proxy/state_impl/mappings.py b/src/hydrilla/proxy/state_impl/mappings.py
index 7d08e58..8a401b8 100644
--- a/src/hydrilla/proxy/state_impl/mappings.py
+++ b/src/hydrilla/proxy/state_impl/mappings.py
@@ -38,7 +38,6 @@ import dataclasses as dc
from ... import item_infos
from .. import state as st
from . import base
-from . import _operations
@dc.dataclass(frozen=True, unsafe_hash=True)
@@ -227,7 +226,7 @@ class ConcreteMappingVersionRef(st.MappingVersionRef):
self._set_installed_status(cursor, st.InstalledStatus.INSTALLED)
- _operations.pull_missing_files(cursor)
+ self.state.pull_missing_files()
def uninstall(self) -> None:
raise NotImplementedError()
diff --git a/src/hydrilla/proxy/state_impl/payloads.py b/src/hydrilla/proxy/state_impl/payloads.py
new file mode 100644
index 0000000..2bee11f
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl/payloads.py
@@ -0,0 +1,137 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy data and configuration (PayloadRef subtype).
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+This module provides an interface to interact with payloads inside Haketilo.
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import dataclasses as dc
+import typing as t
+
+from .. import state as st
+from . import base
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcretePayloadRef(st.PayloadRef):
+ state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False)
+
+ def get_data(self) -> st.PayloadData:
+ try:
+ return self.state.payloads_data[self]
+ except KeyError:
+ raise st.MissingItemError()
+
+ def get_mapping(self) -> st.MappingVersionRef:
+ raise NotImplementedError()
+
+ def get_script_paths(self) \
+ -> t.Iterable[t.Sequence[str]]:
+ with self.state.cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ i.identifier, fu.name
+ FROM
+ payloads AS p
+ LEFT JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ LEFT JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ LEFT JOIN items AS i
+ USING (item_id)
+ LEFT JOIN file_uses AS fu
+ USING (item_version_id)
+ WHERE
+ fu.type = 'W' AND
+ p.payload_id = ? AND
+ (fu.idx IS NOT NULL OR rdd.idx IS NULL)
+ ORDER BY
+ rdd.idx, fu.idx;
+ ''',
+ (self.id,)
+ )
+
+ paths: list[t.Sequence[str]] = []
+ for resource_identifier, file_name in cursor.fetchall():
+ if resource_identifier is None:
+ # payload found but it had no script files
+ return ()
+
+ paths.append((resource_identifier, *file_name.split('/')))
+
+ if paths == []:
+ # payload not found
+ raise st.MissingItemError()
+
+ return paths
+
+ def get_file_data(self, path: t.Sequence[str]) \
+ -> t.Optional[st.FileData]:
+ if len(path) == 0:
+ raise st.MissingItemError()
+
+ resource_identifier, *file_name_segments = path
+
+ file_name = '/'.join(file_name_segments)
+
+ with self.state.cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ f.data, fu.mime_type
+ FROM
+ payloads AS p
+ JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ JOIN items AS i
+ USING (item_id)
+ JOIN file_uses AS fu
+ USING (item_version_id)
+ JOIN files AS f
+ USING (file_id)
+ WHERE
+ p.payload_id = ? AND
+ i.identifier = ? AND
+ fu.name = ? AND
+ fu.type = 'W';
+ ''',
+ (self.id, resource_identifier, file_name)
+ )
+
+ result = cursor.fetchall()
+
+ if result == []:
+ return None
+
+ (data, mime_type), = result
+
+ return st.FileData(type=mime_type, name=file_name, contents=data)
diff --git a/src/hydrilla/proxy/state_impl/repos.py b/src/hydrilla/proxy/state_impl/repos.py
index 346e113..838698c 100644
--- a/src/hydrilla/proxy/state_impl/repos.py
+++ b/src/hydrilla/proxy/state_impl/repos.py
@@ -51,7 +51,6 @@ from ... import versions
from .. import state as st
from .. import simple_dependency_satisfying as sds
from . import base
-from . import _operations
repo_name_regex = re.compile(r'''
@@ -194,8 +193,6 @@ class ConcreteRepoRef(st.RepoRef):
(self.id,)
)
- _operations.prune_packages(cursor)
-
# For mappings explicitly enabled by the user (+ all mappings they
# recursively depend on) let's make sure that their exact same
# versions will be enabled after the change.
@@ -217,7 +214,7 @@ class ConcreteRepoRef(st.RepoRef):
req = sds.MappingVersionRequirement(info.identifier, info)
requirements.append(req)
- self.state.recompute_dependencies(requirements)
+ self.state.recompute_dependencies(requirements, prune_orphans=True)
def update(
self,
@@ -260,7 +257,7 @@ class ConcreteRepoRef(st.RepoRef):
self.state.recompute_dependencies()
- def refresh(self) -> st.RepoIterationRef:
+ def refresh(self) -> None:
with self.state.cursor(transaction=True) as cursor:
ensure_repo_not_deleted(cursor, self.id)
@@ -274,15 +271,7 @@ class ConcreteRepoRef(st.RepoRef):
with tempfile.TemporaryDirectory() as tmpdir_str:
tmpdir = Path(tmpdir_str)
sync_remote_repo_definitions(repo_url, tmpdir)
- new_iteration_id = _operations._load_packages_no_state_update(
- cursor = cursor,
- malcontent_path = tmpdir,
- repo_id = int(self.id)
- )
-
- self.state.rebuild_structures()
-
- return ConcreteRepoIterationRef(str(new_iteration_id), self.state)
+ self.state.import_items(tmpdir, int(self.id))
def get_display_info(self) -> st.RepoDisplayInfo:
with self.state.cursor() as cursor: