aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/state_impl/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/hydrilla/proxy/state_impl/base.py')
-rw-r--r--src/hydrilla/proxy/state_impl/base.py219
1 files changed, 212 insertions, 7 deletions
diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py
index 1ae08cb..92833dd 100644
--- a/src/hydrilla/proxy/state_impl/base.py
+++ b/src/hydrilla/proxy/state_impl/base.py
@@ -34,34 +34,160 @@ subtype.
from __future__ import annotations
import sqlite3
+import secrets
import threading
import dataclasses as dc
import typing as t
from pathlib import Path
from contextlib import contextmanager
-
-import sqlite3
+from abc import abstractmethod
from immutables import Map
+from ... import url_patterns
from ... import pattern_tree
-from .. import state
+from .. import simple_dependency_satisfying as sds
+from .. import state as st
from .. import policies
PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
-PayloadsData = t.Mapping[state.PayloadRef, state.PayloadData]
+PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData]
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcretePayloadRef(st.PayloadRef):
+ state: 'HaketiloStateWithFields' = dc.field(hash=False, compare=False)
+
+ def get_data(self) -> st.PayloadData:
+ try:
+ return self.state.payloads_data[self]
+ except KeyError:
+ raise st.MissingItemError()
+
+ def get_mapping(self) -> st.MappingVersionRef:
+ raise NotImplementedError()
+
+ def get_script_paths(self) \
+ -> t.Iterable[t.Sequence[str]]:
+ with self.state.cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ i.identifier, fu.name
+ FROM
+ payloads AS p
+ LEFT JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ LEFT JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ LEFT JOIN items AS i
+ USING (item_id)
+ LEFT JOIN file_uses AS fu
+ USING (item_version_id)
+ WHERE
+ fu.type = 'W' AND
+ p.payload_id = ? AND
+ (fu.idx IS NOT NULL OR rdd.idx IS NULL)
+ ORDER BY
+ rdd.idx, fu.idx;
+ ''',
+ (self.id,)
+ )
+
+ paths: list[t.Sequence[str]] = []
+ for resource_identifier, file_name in cursor.fetchall():
+ if resource_identifier is None:
+ # payload found but it had no script files
+ return ()
+
+ paths.append((resource_identifier, *file_name.split('/')))
+
+ if paths == []:
+ # payload not found
+ raise st.MissingItemError()
+
+ return paths
+
+ def get_file_data(self, path: t.Sequence[str]) \
+ -> t.Optional[st.FileData]:
+ if len(path) == 0:
+ raise st.MissingItemError()
+
+ resource_identifier, *file_name_segments = path
+
+ file_name = '/'.join(file_name_segments)
+
+ with self.state.cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ f.data, fu.mime_type
+ FROM
+ payloads AS p
+ JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ JOIN items AS i
+ USING (item_id)
+ JOIN file_uses AS fu
+ USING (item_version_id)
+ JOIN files AS f
+ USING (file_id)
+ WHERE
+ p.payload_id = ? AND
+ i.identifier = ? AND
+ fu.name = ? AND
+ fu.type = 'W';
+ ''',
+ (self.id, resource_identifier, file_name)
+ )
+
+ result = cursor.fetchall()
+
+ if result == []:
+ return None
+
+ (data, mime_type), = result
+
+ return st.FileData(type=mime_type, name=file_name, contents=data)
+
+def register_payload(
+ policy_tree: PolicyTree,
+ pattern: url_patterns.ParsedPattern,
+ payload_key: st.PayloadKey,
+ token: str
+) -> PolicyTree:
+ """...."""
+ payload_policy_factory = policies.PayloadPolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(pattern, payload_policy_factory)
+
+ resource_policy_factory = policies.PayloadResourcePolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(
+ pattern.path_append(token, '***'),
+ resource_policy_factory
+ )
+
+ return policy_tree
# mypy needs to be corrected:
# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
@dc.dataclass # type: ignore[misc]
-class HaketiloStateWithFields(state.HaketiloState):
+class HaketiloStateWithFields(st.HaketiloState):
"""...."""
store_dir: Path
connection: sqlite3.Connection
current_cursor: t.Optional[sqlite3.Cursor] = None
- #settings: state.HaketiloGlobalSettings
+ #settings: st.HaketiloGlobalSettings
policy_tree: PolicyTree = PolicyTree()
payloads_data: PayloadsData = dc.field(default_factory=dict)
@@ -98,6 +224,85 @@ class HaketiloStateWithFields(state.HaketiloState):
finally:
self.current_cursor = None
- def recompute_payloads(self, cursor: sqlite3.Cursor) -> None:
+ def rebuild_structures(self) -> None:
+ """
+ Recreation of data structures as done after every recomputation of
+ dependencies as well as at startup.
+ """
+ with self.cursor(transaction=True) as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ p.payload_id, p.pattern, p.eval_allowed,
+ p.cors_bypass_allowed,
+ ms.enabled,
+ i.identifier
+ FROM
+ payloads AS p
+ JOIN item_versions AS iv
+ ON p.mapping_item_id = iv.item_version_id
+ JOIN items AS i USING (item_id)
+ JOIN mapping_statuses AS ms USING (item_id);
+ '''
+ )
+
+ rows = cursor.fetchall()
+
+ new_policy_tree = PolicyTree()
+
+ ui_factory = policies.WebUIPolicyFactory(builtin=True)
+ web_ui_pattern = 'http*://hkt.mitm.it/***'
+ for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern):
+ new_policy_tree = new_policy_tree.register(
+ parsed_pattern,
+ ui_factory
+ )
+
+ new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
+
+ for row in rows:
+ (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
+ enabled_status,
+ identifier) = row
+
+ payload_ref = ConcretePayloadRef(str(payload_id_int), self)
+
+ previous_data = self.payloads_data.get(payload_ref)
+ if previous_data is not None:
+ token = previous_data.unique_token
+ else:
+ token = secrets.token_urlsafe(8)
+
+ payload_key = st.PayloadKey(payload_ref, identifier)
+
+ for parsed_pattern in url_patterns.parse_pattern(pattern):
+ new_policy_tree = register_payload(
+ new_policy_tree,
+ parsed_pattern,
+ payload_key,
+ token
+ )
+
+ pattern_path_segments = parsed_pattern.path_segments
+
+ payload_data = st.PayloadData(
+ payload_ref = payload_ref,
+ explicitly_enabled = enabled_status == 'E',
+ unique_token = token,
+ pattern_path_segments = pattern_path_segments,
+ eval_allowed = eval_allowed,
+ cors_bypass_allowed = cors_bypass_allowed
+ )
+
+ new_payloads_data[payload_ref] = payload_data
+
+ self.policy_tree = new_policy_tree
+ self.payloads_data = new_payloads_data
+
+ @abstractmethod
+ def recompute_dependencies(
+ self,
+ requirements: t.Iterable[sds.MappingRequirement] = []
+ ) -> None:
"""...."""
...