aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/state_impl/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/hydrilla/proxy/state_impl/base.py')
-rw-r--r--src/hydrilla/proxy/state_impl/base.py267
1 files changed, 74 insertions, 193 deletions
diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py
index 92833dd..25fd4c5 100644
--- a/src/hydrilla/proxy/state_impl/base.py
+++ b/src/hydrilla/proxy/state_impl/base.py
@@ -34,7 +34,6 @@ subtype.
from __future__ import annotations
import sqlite3
-import secrets
import threading
import dataclasses as dc
import typing as t
@@ -43,8 +42,6 @@ from pathlib import Path
from contextlib import contextmanager
from abc import abstractmethod
-from immutables import Map
-
from ... import url_patterns
from ... import pattern_tree
from .. import simple_dependency_satisfying as sds
@@ -52,132 +49,36 @@ from .. import state as st
from .. import policies
-PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
-PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData]
-
-@dc.dataclass(frozen=True, unsafe_hash=True)
-class ConcretePayloadRef(st.PayloadRef):
- state: 'HaketiloStateWithFields' = dc.field(hash=False, compare=False)
-
- def get_data(self) -> st.PayloadData:
- try:
- return self.state.payloads_data[self]
- except KeyError:
- raise st.MissingItemError()
-
- def get_mapping(self) -> st.MappingVersionRef:
- raise NotImplementedError()
-
- def get_script_paths(self) \
- -> t.Iterable[t.Sequence[str]]:
- with self.state.cursor() as cursor:
- cursor.execute(
- '''
- SELECT
- i.identifier, fu.name
- FROM
- payloads AS p
- LEFT JOIN resolved_depended_resources AS rdd
- USING (payload_id)
- LEFT JOIN item_versions AS iv
- ON rdd.resource_item_id = iv.item_version_id
- LEFT JOIN items AS i
- USING (item_id)
- LEFT JOIN file_uses AS fu
- USING (item_version_id)
- WHERE
- fu.type = 'W' AND
- p.payload_id = ? AND
- (fu.idx IS NOT NULL OR rdd.idx IS NULL)
- ORDER BY
- rdd.idx, fu.idx;
- ''',
- (self.id,)
- )
-
- paths: list[t.Sequence[str]] = []
- for resource_identifier, file_name in cursor.fetchall():
- if resource_identifier is None:
- # payload found but it had no script files
- return ()
-
- paths.append((resource_identifier, *file_name.split('/')))
-
- if paths == []:
- # payload not found
- raise st.MissingItemError()
-
- return paths
-
- def get_file_data(self, path: t.Sequence[str]) \
- -> t.Optional[st.FileData]:
- if len(path) == 0:
- raise st.MissingItemError()
-
- resource_identifier, *file_name_segments = path
-
- file_name = '/'.join(file_name_segments)
-
- with self.state.cursor() as cursor:
- cursor.execute(
- '''
- SELECT
- f.data, fu.mime_type
- FROM
- payloads AS p
- JOIN resolved_depended_resources AS rdd
- USING (payload_id)
- JOIN item_versions AS iv
- ON rdd.resource_item_id = iv.item_version_id
- JOIN items AS i
- USING (item_id)
- JOIN file_uses AS fu
- USING (item_version_id)
- JOIN files AS f
- USING (file_id)
- WHERE
- p.payload_id = ? AND
- i.identifier = ? AND
- fu.name = ? AND
- fu.type = 'W';
- ''',
- (self.id, resource_identifier, file_name)
- )
-
- result = cursor.fetchall()
-
- if result == []:
- return None
-
- (data, mime_type), = result
+@dc.dataclass(frozen=True)
+class PolicyTree(pattern_tree.PatternTree[policies.PolicyFactory]):
+ SelfType = t.TypeVar('SelfType', bound='PolicyTree')
- return st.FileData(type=mime_type, name=file_name, contents=data)
+ def register_payload(
+ self: 'SelfType',
+ pattern: url_patterns.ParsedPattern,
+ payload_key: st.PayloadKey,
+ token: str
+ ) -> 'SelfType':
+ payload_policy_factory = policies.PayloadPolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
-def register_payload(
- policy_tree: PolicyTree,
- pattern: url_patterns.ParsedPattern,
- payload_key: st.PayloadKey,
- token: str
-) -> PolicyTree:
- """...."""
- payload_policy_factory = policies.PayloadPolicyFactory(
- builtin = False,
- payload_key = payload_key
- )
+ policy_tree = self.register(pattern, payload_policy_factory)
- policy_tree = policy_tree.register(pattern, payload_policy_factory)
+ resource_policy_factory = policies.PayloadResourcePolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
- resource_policy_factory = policies.PayloadResourcePolicyFactory(
- builtin = False,
- payload_key = payload_key
- )
+ policy_tree = policy_tree.register(
+ pattern.path_append(token, '***'),
+ resource_policy_factory
+ )
- policy_tree = policy_tree.register(
- pattern.path_append(token, '***'),
- resource_policy_factory
- )
+ return policy_tree
- return policy_tree
+PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData]
# mypy needs to be corrected:
# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
@@ -187,6 +88,7 @@ class HaketiloStateWithFields(st.HaketiloState):
store_dir: Path
connection: sqlite3.Connection
current_cursor: t.Optional[sqlite3.Cursor] = None
+
#settings: st.HaketiloGlobalSettings
policy_tree: PolicyTree = PolicyTree()
@@ -224,85 +126,64 @@ class HaketiloStateWithFields(st.HaketiloState):
finally:
self.current_cursor = None
- def rebuild_structures(self) -> None:
- """
- Recreation of data structures as done after every recomputation of
- dependencies as well as at startup.
- """
- with self.cursor(transaction=True) as cursor:
- cursor.execute(
- '''
- SELECT
- p.payload_id, p.pattern, p.eval_allowed,
- p.cors_bypass_allowed,
- ms.enabled,
- i.identifier
- FROM
- payloads AS p
- JOIN item_versions AS iv
- ON p.mapping_item_id = iv.item_version_id
- JOIN items AS i USING (item_id)
- JOIN mapping_statuses AS ms USING (item_id);
- '''
- )
-
- rows = cursor.fetchall()
-
- new_policy_tree = PolicyTree()
+ def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy:
+ """...."""
+ with self.lock:
+ policy_tree = self.policy_tree
- ui_factory = policies.WebUIPolicyFactory(builtin=True)
- web_ui_pattern = 'http*://hkt.mitm.it/***'
- for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern):
- new_policy_tree = new_policy_tree.register(
- parsed_pattern,
- ui_factory
+ try:
+ best_priority: int = 0
+ best_policy: t.Optional[policies.Policy] = None
+
+ for factories_set in policy_tree.search(url):
+ for stored_factory in sorted(factories_set):
+ factory = stored_factory.item
+
+ policy = factory.make_policy(self)
+
+ if policy.priority > best_priority:
+ best_priority = policy.priority
+ best_policy = policy
+ except Exception as e:
+ return policies.ErrorBlockPolicy(
+ builtin = True,
+ error = e
)
- new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
-
- for row in rows:
- (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
- enabled_status,
- identifier) = row
-
- payload_ref = ConcretePayloadRef(str(payload_id_int), self)
-
- previous_data = self.payloads_data.get(payload_ref)
- if previous_data is not None:
- token = previous_data.unique_token
- else:
- token = secrets.token_urlsafe(8)
-
- payload_key = st.PayloadKey(payload_ref, identifier)
-
- for parsed_pattern in url_patterns.parse_pattern(pattern):
- new_policy_tree = register_payload(
- new_policy_tree,
- parsed_pattern,
- payload_key,
- token
- )
+ if best_policy is not None:
+ return best_policy
- pattern_path_segments = parsed_pattern.path_segments
+ if self.get_settings().default_allow_scripts:
+ return policies.FallbackAllowPolicy()
+ else:
+ return policies.FallbackBlockPolicy()
- payload_data = st.PayloadData(
- payload_ref = payload_ref,
- explicitly_enabled = enabled_status == 'E',
- unique_token = token,
- pattern_path_segments = pattern_path_segments,
- eval_allowed = eval_allowed,
- cors_bypass_allowed = cors_bypass_allowed
- )
-
- new_payloads_data[payload_ref] = payload_data
-
- self.policy_tree = new_policy_tree
- self.payloads_data = new_payloads_data
+ @abstractmethod
+ def import_items(self, malcontent_path: Path, repo_id: int = 1) -> None:
+ ...
@abstractmethod
def recompute_dependencies(
self,
- requirements: t.Iterable[sds.MappingRequirement] = []
+ requirements: t.Iterable[sds.MappingRequirement] = [],
+ prune_orphans: bool = False
) -> None:
"""...."""
...
+
+ @abstractmethod
+ def pull_missing_files(self) -> None:
+ """
+ This function checks which packages marked as installed are missing
+ files in the database. It attempts to restore integrity by downloading
+ the files from their respective repositories.
+ """
+ ...
+
+ @abstractmethod
+ def rebuild_structures(self) -> None:
+ """
+ Recreation of data structures as done after every recomputation of
+ dependencies as well as at startup.
+ """
+ ...