aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/state_impl/concrete_state.py
diff options
context:
space:
mode:
authorWojtek Kosior <koszko@koszko.org>2022-08-22 12:52:59 +0200
committerWojtek Kosior <koszko@koszko.org>2022-09-28 12:54:51 +0200
commit8238435825d01ad2ec1a11b6bcaf6d9a9aad5ab5 (patch)
tree4c4956e45701460bedaa0d8b0be808052152777f /src/hydrilla/proxy/state_impl/concrete_state.py
parente1344ae7017b28a54d7714895bd54c8431a20bc6 (diff)
downloadhaketilo-hydrilla-8238435825d01ad2ec1a11b6bcaf6d9a9aad5ab5.tar.gz
haketilo-hydrilla-8238435825d01ad2ec1a11b6bcaf6d9a9aad5ab5.zip
allow pulling packages from remote repository
Diffstat (limited to 'src/hydrilla/proxy/state_impl/concrete_state.py')
-rw-r--r--src/hydrilla/proxy/state_impl/concrete_state.py322
1 files changed, 20 insertions, 302 deletions
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py
index 525a702..6a59a75 100644
--- a/src/hydrilla/proxy/state_impl/concrete_state.py
+++ b/src/hydrilla/proxy/state_impl/concrete_state.py
@@ -32,28 +32,23 @@ and resources.
# Enable using with Python 3.7.
from __future__ import annotations
-import secrets
-import io
-import hashlib
+import sqlite3
import typing as t
import dataclasses as dc
from pathlib import Path
-import sqlite3
-
from ...exceptions import HaketiloException
from ...translations import smart_gettext as _
-from ... import pattern_tree
from ... import url_patterns
from ... import item_infos
-from ..simple_dependency_satisfying import compute_payloads, ComputedPayload
from .. import state as st
from .. import policies
+from .. import simple_dependency_satisfying as sds
from . import base
from . import mappings
from . import repos
-from .load_packages import load_packages
+from . import _operations
here = Path(__file__).resolve().parent
@@ -73,169 +68,14 @@ class ConcreteResourceRef(st.ResourceRef):
class ConcreteResourceVersionRef(st.ResourceVersionRef):
pass
-
-@dc.dataclass(frozen=True, unsafe_hash=True)
-class ConcretePayloadRef(st.PayloadRef):
- state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False)
-
- def get_data(self) -> st.PayloadData:
- try:
- return self.state.payloads_data[self]
- except KeyError:
- raise st.MissingItemError()
-
- def get_mapping(self) -> st.MappingVersionRef:
- raise NotImplementedError()
-
- def get_script_paths(self) \
- -> t.Iterable[t.Sequence[str]]:
- with self.state.cursor() as cursor:
- cursor.execute(
- '''
- SELECT
- i.identifier, fu.name
- FROM
- payloads AS p
- LEFT JOIN resolved_depended_resources AS rdd
- USING (payload_id)
- LEFT JOIN item_versions AS iv
- ON rdd.resource_item_id = iv.item_version_id
- LEFT JOIN items AS i
- USING (item_id)
- LEFT JOIN file_uses AS fu
- USING (item_version_id)
- WHERE
- fu.type = 'W' AND
- p.payload_id = ? AND
- (fu.idx IS NOT NULL OR rdd.idx IS NULL)
- ORDER BY
- rdd.idx, fu.idx;
- ''',
- (self.id,)
- )
-
- paths: list[t.Sequence[str]] = []
- for resource_identifier, file_name in cursor.fetchall():
- if resource_identifier is None:
- # payload found but it had no script files
- return ()
-
- paths.append((resource_identifier, *file_name.split('/')))
-
- if paths == []:
- # payload not found
- raise st.MissingItemError()
-
- return paths
-
- def get_file_data(self, path: t.Sequence[str]) \
- -> t.Optional[st.FileData]:
- if len(path) == 0:
- raise st.MissingItemError()
-
- resource_identifier, *file_name_segments = path
-
- file_name = '/'.join(file_name_segments)
-
- with self.state.cursor() as cursor:
- cursor.execute(
- '''
- SELECT
- f.data, fu.mime_type
- FROM
- payloads AS p
- JOIN resolved_depended_resources AS rdd
- USING (payload_id)
- JOIN item_versions AS iv
- ON rdd.resource_item_id = iv.item_version_id
- JOIN items AS i
- USING (item_id)
- JOIN file_uses AS fu
- USING (item_version_id)
- JOIN files AS f
- USING (file_id)
- WHERE
- p.payload_id = ? AND
- i.identifier = ? AND
- fu.name = ? AND
- fu.type = 'W';
- ''',
- (self.id, resource_identifier, file_name)
- )
-
- result = cursor.fetchall()
-
- if result == []:
- return None
-
- (data, mime_type), = result
-
- return st.FileData(type=mime_type, name=file_name, contents=data)
-
-def register_payload(
- policy_tree: base.PolicyTree,
- pattern: url_patterns.ParsedPattern,
- payload_key: st.PayloadKey,
- token: str
-) -> base.PolicyTree:
- """...."""
- payload_policy_factory = policies.PayloadPolicyFactory(
- builtin = False,
- payload_key = payload_key
- )
-
- policy_tree = policy_tree.register(pattern, payload_policy_factory)
-
- resource_policy_factory = policies.PayloadResourcePolicyFactory(
- builtin = False,
- payload_key = payload_key
- )
-
- policy_tree = policy_tree.register(
- pattern.path_append(token, '***'),
- resource_policy_factory
- )
-
- return policy_tree
-
-AnyInfoVar = t.TypeVar(
- 'AnyInfoVar',
- item_infos.ResourceInfo,
- item_infos.MappingInfo
-)
-
-def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \
- -> t.Mapping[AnyInfoVar, int]:
- cursor.execute(
- '''
- SELECT
- i.item_id, iv.definition, r.name, ri.iteration
- FROM
- item_versions AS iv
- JOIN items AS i USING (item_id)
- JOIN repo_iterations AS ri USING (repo_iteration_id)
- JOIN repos AS r USING (repo_id)
- WHERE
- i.type = ?;
- ''',
- (info_type.type_name[0].upper(),)
- )
-
- result: dict[AnyInfoVar, int] = {}
-
- for item_id, definition, repo_name, repo_iteration in cursor.fetchall():
- definition_io = io.StringIO(definition)
- info = info_type.load(definition_io, repo_name, repo_iteration)
- result[info] = item_id
-
- return result
-
@dc.dataclass
class ConcreteHaketiloState(base.HaketiloStateWithFields):
def __post_init__(self) -> None:
+ sqlite3.enable_callback_tracebacks(True)
+
self._prepare_database()
- self._rebuild_structures()
+ self.rebuild_structures()
def _prepare_database(self) -> None:
"""...."""
@@ -282,147 +122,25 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
def import_packages(self, malcontent_path: Path) -> None:
with self.cursor(transaction=True) as cursor:
- load_packages(self, cursor, malcontent_path)
- self.recompute_payloads(cursor)
-
- def recompute_payloads(self, cursor: sqlite3.Cursor) -> None:
- assert self.connection.in_transaction
-
- resources = get_infos_of_type(cursor, item_infos.ResourceInfo)
- mappings = get_infos_of_type(cursor, item_infos.MappingInfo)
-
- payloads = compute_payloads(resources.keys(), mappings.keys())
-
- payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
-
- cursor.execute('DELETE FROM payloads;')
-
- for mapping_info, by_pattern in payloads.items():
- for num, (pattern, payload) in enumerate(by_pattern.items()):
- cursor.execute(
- '''
- INSERT INTO payloads(
- mapping_item_id,
- pattern,
- eval_allowed,
- cors_bypass_allowed
- )
- VALUES (?, ?, ?, ?);
- ''',
- (
- mappings[mapping_info],
- pattern.orig_url,
- payload.allows_eval,
- payload.allows_cors_bypass
- )
- )
+ _operations.load_packages(cursor, malcontent_path, 1)
+ raise NotImplementedError()
+ _operations.prune_packages(cursor)
- cursor.execute(
- '''
- SELECT
- payload_id
- FROM
- payloads
- WHERE
- mapping_item_id = ? AND pattern = ?;
- ''',
- (mappings[mapping_info], pattern.orig_url)
- )
+ self.recompute_dependencies()
- (payload_id_int,), = cursor.fetchall()
-
- for res_num, resource_info in enumerate(payload.resources):
- cursor.execute(
- '''
- INSERT INTO resolved_depended_resources(
- payload_id,
- resource_item_id,
- idx
- )
- VALUES(?, ?, ?);
- ''',
- (payload_id_int, resources[resource_info], res_num)
- )
-
- self._rebuild_structures(cursor)
-
- def _rebuild_structures(self, cursor: t.Optional[sqlite3.Cursor] = None) \
- -> None:
- """
- Recreation of data structures as done after every recomputation of
- dependencies as well as at startup.
- """
- if cursor is None:
- with self.cursor() as new_cursor:
- return self._rebuild_structures(new_cursor)
-
- cursor.execute(
- '''
- SELECT
- p.payload_id, p.pattern, p.eval_allowed,
- p.cors_bypass_allowed,
- ms.enabled,
- i.identifier
- FROM
- payloads AS p
- JOIN item_versions AS iv
- ON p.mapping_item_id = iv.item_version_id
- JOIN items AS i USING (item_id)
- JOIN mapping_statuses AS ms USING (item_id);
- '''
- )
-
- new_policy_tree = base.PolicyTree()
-
- ui_factory = policies.WebUIPolicyFactory(builtin=True)
- web_ui_pattern = 'http*://hkt.mitm.it/***'
- for parsed_pattern in url_patterns.parse_pattern(web_ui_pattern):
- new_policy_tree = new_policy_tree.register(
- parsed_pattern,
- ui_factory
- )
-
- new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
-
- for row in cursor.fetchall():
- (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
- enabled_status,
- identifier) = row
-
- payload_ref = ConcretePayloadRef(str(payload_id_int), self)
-
- previous_data = self.payloads_data.get(payload_ref)
- if previous_data is not None:
- token = previous_data.unique_token
- else:
- token = secrets.token_urlsafe(8)
-
- payload_key = st.PayloadKey(payload_ref, identifier)
-
- for parsed_pattern in url_patterns.parse_pattern(pattern):
- new_policy_tree = register_payload(
- new_policy_tree,
- parsed_pattern,
- payload_key,
- token
- )
-
- pattern_path_segments = parsed_pattern.path_segments
+ def recompute_dependencies(
+ self,
+ extra_requirements: t.Iterable[sds.MappingRequirement] = []
+ ) -> None:
+ with self.cursor() as cursor:
+ assert self.connection.in_transaction
- payload_data = st.PayloadData(
- payload_ref = payload_ref,
- #explicitly_enabled = enabled_status == 'E',
- explicitly_enabled = True,
- unique_token = token,
- pattern_path_segments = pattern_path_segments,
- eval_allowed = eval_allowed,
- cors_bypass_allowed = cors_bypass_allowed
+ _operations._recompute_dependencies_no_state_update(
+ cursor,
+ extra_requirements
)
- new_payloads_data[payload_ref] = payload_data
-
- self.policy_tree = new_policy_tree
- self.payloads_data = new_payloads_data
+ self.rebuild_structures()
def repo_store(self) -> st.RepoStore:
return repos.ConcreteRepoStore(self)