author    Wojtek Kosior <koszko@koszko.org>    2022-08-11 13:33:06 +0200
committer Wojtek Kosior <koszko@koszko.org>    2022-08-11 13:33:06 +0200
commit    67c58a14f4f356117f42fea368a32359496d46c4 (patch)
tree      bb8d1019dc4547a215404a40b60a031ca4f2e21d
parent    3f3ba519ae3c3346945928b21ab36f7238e5387e (diff)
download  haketilo-hydrilla-67c58a14f4f356117f42fea368a32359496d46c4.tar.gz
          haketilo-hydrilla-67c58a14f4f356117f42fea368a32359496d46c4.zip
populate data structures based on payloads data loaded from sqlite db
-rw-r--r--  src/hydrilla/json_instances.py                     4
-rw-r--r--  src/hydrilla/proxy/policies/payload.py             2
-rw-r--r--  src/hydrilla/proxy/policies/payload_resource.py    2
-rw-r--r--  src/hydrilla/proxy/state.py                       31
-rw-r--r--  src/hydrilla/proxy/state_impl/base.py             17
-rw-r--r--  src/hydrilla/proxy/state_impl/concrete_state.py  269
-rw-r--r--  src/hydrilla/proxy/tables.sql                      7
-rw-r--r--  src/hydrilla/versions.py                           6
8 files changed, 208 insertions, 130 deletions
diff --git a/src/hydrilla/json_instances.py b/src/hydrilla/json_instances.py
index 33a3785..fc8f975 100644
--- a/src/hydrilla/json_instances.py
+++ b/src/hydrilla/json_instances.py
@@ -43,7 +43,7 @@ from jsonschema import RefResolver, Draft7Validator # type: ignore
from .translations import smart_gettext as _
from .exceptions import HaketiloException
-from .versions import parse_version
+from . import versions
here = Path(__file__).resolve().parent
@@ -193,7 +193,7 @@ def get_schema_version(instance: object) -> tuple[int, ...]:
ver_str = match.group('ver') if match else None
if ver_str is not None:
- return parse_version(ver_str)
+ return versions.parse(ver_str)
else:
raise HaketiloException(_('no_schema_number_in_instance'))
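
The hunk above only switches the call site to the module-level import; the behaviour of the renamed helpers comes from the versions.py hunk at the end of this diff. A minimal usage sketch, with example version strings assumed:

    from hydrilla import versions

    # parse() splits a dotted version string into a tuple of ints.
    assert versions.parse('4.6.13.0') == (4, 6, 13, 0)
    # parse_normalize() additionally strips trailing zeros, yielding a VerTuple.
    assert versions.parse_normalize('4.6.13.0') == (4, 6, 13)
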
diff --git a/src/hydrilla/proxy/policies/payload.py b/src/hydrilla/proxy/policies/payload.py
index d616f1b..1a88ea1 100644
--- a/src/hydrilla/proxy/policies/payload.py
+++ b/src/hydrilla/proxy/policies/payload.py
@@ -52,7 +52,7 @@ class PayloadAwarePolicy(base.Policy):
"""...."""
token = self.payload_data.unique_token
- base_path_segments = (*self.payload_data.pattern.path_segments, token)
+ base_path_segments = (*self.payload_data.pattern_path_segments, token)
return f'{request_url.url_without_path}/{"/".join(base_path_segments)}/'
diff --git a/src/hydrilla/proxy/policies/payload_resource.py b/src/hydrilla/proxy/policies/payload_resource.py
index 84d0919..b255d4e 100644
--- a/src/hydrilla/proxy/policies/payload_resource.py
+++ b/src/hydrilla/proxy/policies/payload_resource.py
@@ -106,7 +106,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy):
# "/some/arbitrary/segments/<per-session_token>/actual/resource/path"
#
# Here we need to extract the "/actual/resource/path" part.
- segments_to_drop = len(self.payload_data.pattern.path_segments) + 1
+ segments_to_drop = len(self.payload_data.pattern_path_segments) + 1
resource_path = request_info.url.path_segments[segments_to_drop:]
if resource_path == ():
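
The two hunks above are symmetric: payload.py appends the per-session token to the new pattern_path_segments field to build the payload's base path, and payload_resource.py drops the same number of leading segments to recover the requested resource path. A sketch with made-up values:

    # Illustrative values only (not taken from the codebase).
    pattern_path_segments = ('some', 'arbitrary', 'segments')
    token = 'per-session-token'

    # payload.py: base path under which the payload's resources are served.
    base_path_segments = (*pattern_path_segments, token)
    # -> ('some', 'arbitrary', 'segments', 'per-session-token')

    # payload_resource.py: drop the pattern segments plus the token to get
    # the actual resource path from the request's path segments.
    request_path_segments = (*base_path_segments, 'actual', 'resource', 'path')
    segments_to_drop = len(pattern_path_segments) + 1
    resource_path = request_path_segments[segments_to_drop:]
    # -> ('actual', 'resource', 'path')
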
diff --git a/src/hydrilla/proxy/state.py b/src/hydrilla/proxy/state.py
index e22c9fe..f511056 100644
--- a/src/hydrilla/proxy/state.py
+++ b/src/hydrilla/proxy/state.py
@@ -140,39 +140,22 @@ class PayloadKey:
"""...."""
payload_ref: 'PayloadRef'
- mapping_identifier: str
- # mapping_version: VerTuple
- # mapping_repo: str
- # mapping_repo_iteration: int
- pattern: ParsedPattern
+ mapping_identifier: str
def __lt__(self, other: 'PayloadKey') -> bool:
"""...."""
- return (
- self.mapping_identifier,
- # other.mapping_version,
- # self.mapping_repo,
- # other.mapping_repo_iteration,
- self.pattern
- ) < (
- other.mapping_identifier,
- # self.mapping_version,
- # other.mapping_repo,
- # self.mapping_repo_iteration,
- other.pattern
- )
+ return self.mapping_identifier < other.mapping_identifier
@dc.dataclass(frozen=True)
class PayloadData:
"""...."""
payload_ref: 'PayloadRef'
- mapping_installed: bool
- explicitly_enabled: bool
- unique_token: str
- pattern: ParsedPattern
- eval_allowed: bool
- cors_bypass_allowed: bool
+ explicitly_enabled: bool
+ unique_token: str
+ pattern_path_segments: tuple[str, ...]
+ eval_allowed: bool
+ cors_bypass_allowed: bool
@dc.dataclass(frozen=True)
class FileData:
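
For reference, a hypothetical construction of the slimmed-down PayloadData; the field names match the dataclass above, all values are invented and the payload_ref placeholder stands in for a real PayloadRef:

    import typing as t

    data = PayloadData(
        payload_ref           = t.cast('PayloadRef', object()),  # placeholder ref
        explicitly_enabled    = True,
        unique_token          = 'per-session-token',
        pattern_path_segments = ('some', 'arbitrary', 'segments'),
        eval_allowed          = False,
        cors_bypass_allowed   = False
    )
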
diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py
index 788a93d..1ae08cb 100644
--- a/src/hydrilla/proxy/state_impl/base.py
+++ b/src/hydrilla/proxy/state_impl/base.py
@@ -51,9 +51,7 @@ from .. import policies
PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
-
-#PayloadsDataMap = Map[state.PayloadRef, state.PayloadData]
-DataById = t.Mapping[str, state.PayloadData]
+PayloadsData = t.Mapping[state.PayloadRef, state.PayloadData]
# mypy needs to be corrected:
# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
@@ -65,9 +63,8 @@ class HaketiloStateWithFields(state.HaketiloState):
current_cursor: t.Optional[sqlite3.Cursor] = None
#settings: state.HaketiloGlobalSettings
- policy_tree: PolicyTree = PolicyTree()
- #payloads_data: PayloadsDataMap = Map()
- payloads_data: DataById = dc.field(default_factory=dict)
+ policy_tree: PolicyTree = PolicyTree()
+ payloads_data: PayloadsData = dc.field(default_factory=dict)
lock: threading.RLock = dc.field(default_factory=threading.RLock)
@@ -77,9 +74,7 @@ class HaketiloStateWithFields(state.HaketiloState):
"""...."""
start_transaction = transaction and not self.connection.in_transaction
- try:
- self.lock.acquire()
-
+ with self.lock:
if self.current_cursor is not None:
yield self.current_cursor
return
@@ -102,9 +97,7 @@ class HaketiloStateWithFields(state.HaketiloState):
raise
finally:
self.current_cursor = None
- finally:
- self.lock.release()
- def rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
+ def recompute_payloads(self, cursor: sqlite3.Cursor) -> None:
"""...."""
...
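
A minimal sketch of the locking pattern the hunk above switches to, assuming only an RLock and a sqlite3 connection on the object; the real method's transaction handling is omitted:

    import contextlib
    import sqlite3
    import threading
    import typing as t

    class CursorHolder:
        # Simplified stand-in for HaketiloStateWithFields.
        def __init__(self, connection: sqlite3.Connection) -> None:
            self.connection = connection
            self.current_cursor: t.Optional[sqlite3.Cursor] = None
            self.lock = threading.RLock()

        @contextlib.contextmanager
        def cursor(self) -> t.Iterator[sqlite3.Cursor]:
            with self.lock:                     # RLock, so nested calls re-enter
                if self.current_cursor is not None:
                    yield self.current_cursor   # reuse the already-open cursor
                    return
                try:
                    self.current_cursor = self.connection.cursor()
                    yield self.current_cursor
                finally:
                    self.current_cursor = None  # lock released by the with block
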
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py
index ccb1269..53f30ae 100644
--- a/src/hydrilla/proxy/state_impl/concrete_state.py
+++ b/src/hydrilla/proxy/state_impl/concrete_state.py
@@ -106,19 +106,52 @@ class ConcreteResourceVersionRef(st.ResourceVersionRef):
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcretePayloadRef(st.PayloadRef):
- computed_payload: ComputedPayload = dc.field(hash=False, compare=False)
-
def get_data(self, state: st.HaketiloState) -> st.PayloadData:
- return t.cast(ConcreteHaketiloState, state).payloads_data[self.id]
+ return t.cast(ConcreteHaketiloState, state).payloads_data[self]
def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
return 'to implement'
def get_script_paths(self, state: st.HaketiloState) \
- -> t.Iterator[t.Sequence[str]]:
- for resource_info in self.computed_payload.resources:
- for file_spec in resource_info.scripts:
- yield (resource_info.identifier, *file_spec.name.split('/'))
+ -> t.Iterable[t.Sequence[str]]:
+ with t.cast(ConcreteHaketiloState, state).cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ i.identifier, fu.name
+ FROM
+ payloads AS p
+ LEFT JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ LEFT JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ LEFT JOIN items AS i
+ USING (item_id)
+ LEFT JOIN file_uses AS fu
+ USING (item_version_id)
+ WHERE
+ fu.type = 'W' AND
+ p.payload_id = ? AND
+ (fu.idx IS NOT NULL OR rdd.idx IS NULL)
+ ORDER BY
+ rdd.idx, fu.idx;
+ ''',
+ (self.id,)
+ )
+
+ paths: list[t.Sequence[str]] = []
+ for resource_identifier, file_name in cursor.fetchall():
+ if resource_identifier is None:
+ # payload found but it had no script files
+ return ()
+
+ paths.append((resource_identifier, *file_name.split('/')))
+
+ if paths == []:
+ # payload not found
+ raise st.MissingItemError()
+
+ return paths
def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
-> t.Optional[st.FileData]:
@@ -129,53 +162,109 @@ class ConcretePayloadRef(st.PayloadRef):
file_name = '/'.join(file_name_segments)
- script_sha256 = ''
+ with t.cast(ConcreteHaketiloState, state).cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ f.data, fu.mime_type
+ FROM
+ payloads AS p
+ JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ JOIN items AS i
+ USING (item_id)
+ JOIN file_uses AS fu
+ USING (item_version_id)
+ JOIN files AS f
+ USING (file_id)
+ WHERE
+ p.payload_id = ? AND
+ i.identifier = ? AND
+ fu.name = ? AND
+ fu.type = 'W';
+ ''',
+ (self.id, resource_identifier, file_name)
+ )
+
+ result = cursor.fetchall()
- matched_resource_info = False
+ if result == []:
+ return None
- for resource_info in self.computed_payload.resources:
- if resource_info.identifier == resource_identifier:
- matched_resource_info = True
+ (data, mime_type), = result
- for script_spec in resource_info.scripts:
- if script_spec.name == file_name:
- script_sha256 = script_spec.sha256
+ return st.FileData(type=mime_type, name=file_name, contents=data)
- break
+# @dc.dataclass(frozen=True, unsafe_hash=True)
+# class ConcretePayloadRef(st.PayloadRef):
+# computed_payload: ComputedPayload = dc.field(hash=False, compare=False)
- if not matched_resource_info:
- raise st.MissingItemError(resource_identifier)
+# def get_data(self, state: st.HaketiloState) -> st.PayloadData:
+# return t.cast(ConcreteHaketiloState, state).payloads_data[self.id]
- if script_sha256 == '':
- return None
+# def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
+# return 'to implement'
- store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir
- files_dir_path = store_dir_path / 'temporary_malcontent' / 'file'
- file_path = files_dir_path / 'sha256' / script_sha256
+# def get_script_paths(self, state: st.HaketiloState) \
+# -> t.Iterator[t.Sequence[str]]:
+# for resource_info in self.computed_payload.resources:
+# for file_spec in resource_info.scripts:
+# yield (resource_info.identifier, *file_spec.name.split('/'))
- return st.FileData(
- type = 'application/javascript',
- name = file_name,
- contents = file_path.read_bytes()
- )
+# def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
+# -> t.Optional[st.FileData]:
+# if len(path) == 0:
+# raise st.MissingItemError()
+
+# resource_identifier, *file_name_segments = path
+
+# file_name = '/'.join(file_name_segments)
+
+# script_sha256 = ''
-PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
+# matched_resource_info = False
+
+# for resource_info in self.computed_payload.resources:
+# if resource_info.identifier == resource_identifier:
+# matched_resource_info = True
+
+# for script_spec in resource_info.scripts:
+# if script_spec.name == file_name:
+# script_sha256 = script_spec.sha256
+
+# break
+
+# if not matched_resource_info:
+# raise st.MissingItemError(resource_identifier)
+
+# if script_sha256 == '':
+# return None
+
+# store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir
+# files_dir_path = store_dir_path / 'temporary_malcontent' / 'file'
+# file_path = files_dir_path / 'sha256' / script_sha256
+
+# return st.FileData(
+# type = 'application/javascript',
+# name = file_name,
+# contents = file_path.read_bytes()
+# )
def register_payload(
- policy_tree: PolicyTree,
+ policy_tree: base.PolicyTree,
+ pattern: url_patterns.ParsedPattern,
payload_key: st.PayloadKey,
token: str
-) -> PolicyTree:
+) -> base.PolicyTree:
"""...."""
payload_policy_factory = policies.PayloadPolicyFactory(
builtin = False,
payload_key = payload_key
)
- policy_tree = policy_tree.register(
- payload_key.pattern,
- payload_policy_factory
- )
+ policy_tree = policy_tree.register(pattern, payload_policy_factory)
resource_policy_factory = policies.PayloadResourcePolicyFactory(
builtin = False,
@@ -183,7 +272,7 @@ def register_payload(
)
policy_tree = policy_tree.register(
- payload_key.pattern.path_append(token, '***'),
+ pattern.path_append(token, '***'),
resource_policy_factory
)
@@ -197,24 +286,6 @@ AnyInfoVar = t.TypeVar(
item_infos.MappingInfo
)
-# def newest_item_path(item_dir: Path) -> t.Optional[Path]:
-# available_versions = tuple(
-# versions.parse_normalize_version(ver_path.name)
-# for ver_path in item_dir.iterdir()
-# if ver_path.is_file()
-# )
-
-# if available_versions == ():
-# return None
-
-# newest_version = max(available_versions)
-
-# version_path = item_dir / versions.version_string(newest_version)
-
-# assert version_path.is_file()
-
-# return version_path
-
def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
-> t.Iterator[tuple[AnyInfoVar, str]]:
item_type_path = malcontent_path / item_class.type_name
@@ -440,7 +511,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
self._populate_database_with_stuff_from_temporary_malcontent_dir()
with self.cursor(transaction=True) as cursor:
- self.rebuild_structures(cursor)
+ self.recompute_payloads(cursor)
def _prepare_database(self) -> None:
"""...."""
@@ -568,7 +639,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
idx = idx
)
- def rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
+ def recompute_payloads(self, cursor: sqlite3.Cursor) -> None:
assert self.connection.in_transaction
resources = get_infos_of_type(cursor, item_infos.ResourceInfo)
@@ -576,13 +647,12 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
payloads = compute_payloads(resources.keys(), mappings.keys())
- payloads_data = {}
+ payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
cursor.execute('DELETE FROM payloads;')
for mapping_info, by_pattern in payloads.items():
for num, (pattern, payload) in enumerate(by_pattern.items()):
- print('adding payload')
cursor.execute(
'''
INSERT INTO payloads(
@@ -628,35 +698,67 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
(payload_id_int, resources[resource_info], res_num)
)
- payload_id = str(payload_id_int)
+ self._rebuild_structures(cursor)
+
+ def _rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
+ cursor.execute(
+ '''
+ SELECT
+ p.payload_id, p.pattern, p.eval_allowed,
+ p.cors_bypass_allowed,
+ ms.enabled,
+ i.identifier
+ FROM
+ payloads AS p
+ JOIN item_versions AS iv
+ ON p.mapping_item_id = iv.item_version_id
+ JOIN items AS i USING (item_id)
+ JOIN mapping_statuses AS ms USING (item_id);
+ '''
+ )
- ref = ConcretePayloadRef(payload_id, payload)
+ new_policy_tree = base.PolicyTree()
+ new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
- data = st.PayloadData(
- payload_ref = ref,
- mapping_installed = True,
- explicitly_enabled = True,
- unique_token = secrets.token_urlsafe(16),
- pattern = pattern,
- eval_allowed = payload.allows_eval,
- cors_bypass_allowed = payload.allows_cors_bypass
- )
+ for row in cursor.fetchall():
+ (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
+ enabled_status,
+ identifier) = row
- payloads_data[payload_id] = data
+ payload_ref = ConcretePayloadRef(str(payload_id_int))
- key = st.PayloadKey(
- payload_ref = ref,
- mapping_identifier = mapping_info.identifier,
- pattern = pattern
- )
+ previous_data = self.payloads_data.get(payload_ref)
+ if previous_data is not None:
+ token = previous_data.unique_token
+ else:
+ token = secrets.token_urlsafe(8)
- self.policy_tree = register_payload(
- self.policy_tree,
- key,
- data.unique_token
+ payload_key = st.PayloadKey(payload_ref, identifier)
+
+ for parsed_pattern in url_patterns.parse_pattern(pattern):
+ new_policy_tree = register_payload(
+ new_policy_tree,
+ parsed_pattern,
+ payload_key,
+ token
)
- self.payloads_data = payloads_data
+ pattern_path_segments = parsed_pattern.path_segments
+
+ payload_data = st.PayloadData(
+ payload_ref = payload_ref,
+ #explicitly_enabled = enabled_status == 'E',
+ explicitly_enabled = True,
+ unique_token = token,
+ pattern_path_segments = pattern_path_segments,
+ eval_allowed = eval_allowed,
+ cors_bypass_allowed = cors_bypass_allowed
+ )
+
+ new_payloads_data[payload_ref] = payload_data
+
+ self.policy_tree = new_policy_tree
+ self.payloads_data = new_payloads_data
def get_repo(self, repo_id: str) -> st.RepoRef:
return ConcreteRepoRef(repo_id)
@@ -735,7 +837,12 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
@staticmethod
def make(store_dir: Path) -> 'ConcreteHaketiloState':
+ connection = sqlite3.connect(
+ str(store_dir / 'sqlite3.db'),
+ isolation_level = None,
+ check_same_thread = False
+ )
return ConcreteHaketiloState(
store_dir = store_dir,
- connection = sqlite3.connect(str(store_dir / 'sqlite3.db'))
+ connection = connection
)
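
The connection arguments added above matter for the proxy's threading model: isolation_level=None disables the sqlite3 module's implicit transaction handling (the state object issues its own transactions), and check_same_thread=False allows the connection to be used from other threads, which the state serializes with its RLock. A minimal sketch of the setup, with an illustrative store path:

    import sqlite3
    from pathlib import Path

    store_dir = Path('/tmp/haketilo-store')    # illustrative location
    store_dir.mkdir(parents=True, exist_ok=True)

    connection = sqlite3.connect(
        str(store_dir / 'sqlite3.db'),
        isolation_level   = None,   # autocommit mode; transactions issued explicitly
        check_same_thread = False   # shared across threads, guarded by a lock
    )
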
diff --git a/src/hydrilla/proxy/tables.sql b/src/hydrilla/proxy/tables.sql
index 25493d3..2a6cac6 100644
--- a/src/hydrilla/proxy/tables.sql
+++ b/src/hydrilla/proxy/tables.sql
@@ -121,15 +121,11 @@ CREATE TABLE mapping_statuses(
-- "enabled" determines whether mapping's status is ENABLED,
-- DISABLED or NO_MARK.
enabled CHAR(1) NOT NULL,
- enabled_version_id INTEGER NULL,
-- "frozen" determines whether an enabled mapping is to be kept in its
-- EXACT_VERSION, is to be updated only with versions from the same
-- REPOSITORY or is NOT_FROZEN at all.
frozen CHAR(1) NULL,
- CHECK (NOT (enabled = 'D' AND enabled_version_id IS NOT NULL)),
- CHECK (NOT (enabled = 'E' AND enabled_version_id IS NULL)),
-
CHECK ((frozen IS NULL) = (enabled != 'E')),
CHECK (frozen IS NULL OR frozen in ('E', 'R', 'N'))
);
@@ -143,8 +139,6 @@ CREATE TABLE item_versions(
definition TEXT NOT NULL,
UNIQUE (item_id, version, repo_iteration_id),
- -- Allow foreign key from "mapping_statuses".
- UNIQUE (item_version_id, item_id),
FOREIGN KEY (item_id)
REFERENCES items (item_id),
@@ -226,6 +220,7 @@ CREATE TABLE file_uses(
CHECK (type IN ('L', 'W')),
UNIQUE(item_version_id, type, idx),
+ UNIQUE(item_version_id, type, name),
FOREIGN KEY (item_version_id)
REFERENCES item_versions(item_version_id)
diff --git a/src/hydrilla/versions.py b/src/hydrilla/versions.py
index 7474d98..93f395d 100644
--- a/src/hydrilla/versions.py
+++ b/src/hydrilla/versions.py
@@ -45,19 +45,19 @@ def normalize_version(ver: t.Sequence[int]) -> VerTuple:
return VerTuple(tuple(ver[:new_len]))
-def parse_version(ver_str: str) -> tuple[int, ...]:
+def parse(ver_str: str) -> tuple[int, ...]:
"""
Convert 'ver_str' into an array representation, e.g. for ver_str="4.6.13.0"
return [4, 6, 13, 0].
"""
return tuple(int(num) for num in ver_str.split('.'))
-def parse_normalize_version(ver_str: str) -> VerTuple:
+def parse_normalize(ver_str: str) -> VerTuple:
"""
Convert 'ver_str' into a VerTuple representation, e.g. for
ver_str="4.6.13.0" return (4, 6, 13).
"""
- return normalize_version(parse_version(ver_str))
+ return normalize_version(parse(ver_str))
def version_string(ver: VerTuple, rev: t.Optional[int] = None) -> str:
"""