From 67c58a14f4f356117f42fea368a32359496d46c4 Mon Sep 17 00:00:00 2001 From: Wojtek Kosior Date: Thu, 11 Aug 2022 13:33:06 +0200 Subject: populate data structures based on payloads data loaded from sqlite db --- src/hydrilla/json_instances.py | 4 +- src/hydrilla/proxy/policies/payload.py | 2 +- src/hydrilla/proxy/policies/payload_resource.py | 2 +- src/hydrilla/proxy/state.py | 31 +-- src/hydrilla/proxy/state_impl/base.py | 17 +- src/hydrilla/proxy/state_impl/concrete_state.py | 269 +++++++++++++++++------- src/hydrilla/proxy/tables.sql | 7 +- src/hydrilla/versions.py | 6 +- 8 files changed, 208 insertions(+), 130 deletions(-) (limited to 'src') diff --git a/src/hydrilla/json_instances.py b/src/hydrilla/json_instances.py index 33a3785..fc8f975 100644 --- a/src/hydrilla/json_instances.py +++ b/src/hydrilla/json_instances.py @@ -43,7 +43,7 @@ from jsonschema import RefResolver, Draft7Validator # type: ignore from .translations import smart_gettext as _ from .exceptions import HaketiloException -from .versions import parse_version +from . import versions here = Path(__file__).resolve().parent @@ -193,7 +193,7 @@ def get_schema_version(instance: object) -> tuple[int, ...]: ver_str = match.group('ver') if match else None if ver_str is not None: - return parse_version(ver_str) + return versions.parse(ver_str) else: raise HaketiloException(_('no_schema_number_in_instance')) diff --git a/src/hydrilla/proxy/policies/payload.py b/src/hydrilla/proxy/policies/payload.py index d616f1b..1a88ea1 100644 --- a/src/hydrilla/proxy/policies/payload.py +++ b/src/hydrilla/proxy/policies/payload.py @@ -52,7 +52,7 @@ class PayloadAwarePolicy(base.Policy): """....""" token = self.payload_data.unique_token - base_path_segments = (*self.payload_data.pattern.path_segments, token) + base_path_segments = (*self.payload_data.pattern_path_segments, token) return f'{request_url.url_without_path}/{"/".join(base_path_segments)}/' diff --git a/src/hydrilla/proxy/policies/payload_resource.py b/src/hydrilla/proxy/policies/payload_resource.py index 84d0919..b255d4e 100644 --- a/src/hydrilla/proxy/policies/payload_resource.py +++ b/src/hydrilla/proxy/policies/payload_resource.py @@ -106,7 +106,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy): # "/some/arbitrary/segments//actual/resource/path" # # Here we need to extract the "/actual/resource/path" part. - segments_to_drop = len(self.payload_data.pattern.path_segments) + 1 + segments_to_drop = len(self.payload_data.pattern_path_segments) + 1 resource_path = request_info.url.path_segments[segments_to_drop:] if resource_path == (): diff --git a/src/hydrilla/proxy/state.py b/src/hydrilla/proxy/state.py index e22c9fe..f511056 100644 --- a/src/hydrilla/proxy/state.py +++ b/src/hydrilla/proxy/state.py @@ -140,39 +140,22 @@ class PayloadKey: """....""" payload_ref: 'PayloadRef' - mapping_identifier: str - # mapping_version: VerTuple - # mapping_repo: str - # mapping_repo_iteration: int - pattern: ParsedPattern + mapping_identifier: str def __lt__(self, other: 'PayloadKey') -> bool: """....""" - return ( - self.mapping_identifier, - # other.mapping_version, - # self.mapping_repo, - # other.mapping_repo_iteration, - self.pattern - ) < ( - other.mapping_identifier, - # self.mapping_version, - # other.mapping_repo, - # self.mapping_repo_iteration, - other.pattern - ) + return self.mapping_identifier < other.mapping_identifier @dc.dataclass(frozen=True) class PayloadData: """....""" payload_ref: 'PayloadRef' - mapping_installed: bool - explicitly_enabled: bool - unique_token: str - pattern: ParsedPattern - eval_allowed: bool - cors_bypass_allowed: bool + explicitly_enabled: bool + unique_token: str + pattern_path_segments: tuple[str, ...] + eval_allowed: bool + cors_bypass_allowed: bool @dc.dataclass(frozen=True) class FileData: diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py index 788a93d..1ae08cb 100644 --- a/src/hydrilla/proxy/state_impl/base.py +++ b/src/hydrilla/proxy/state_impl/base.py @@ -51,9 +51,7 @@ from .. import policies PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory] - -#PayloadsDataMap = Map[state.PayloadRef, state.PayloadData] -DataById = t.Mapping[str, state.PayloadData] +PayloadsData = t.Mapping[state.PayloadRef, state.PayloadData] # mypy needs to be corrected: # https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 @@ -65,9 +63,8 @@ class HaketiloStateWithFields(state.HaketiloState): current_cursor: t.Optional[sqlite3.Cursor] = None #settings: state.HaketiloGlobalSettings - policy_tree: PolicyTree = PolicyTree() - #payloads_data: PayloadsDataMap = Map() - payloads_data: DataById = dc.field(default_factory=dict) + policy_tree: PolicyTree = PolicyTree() + payloads_data: PayloadsData = dc.field(default_factory=dict) lock: threading.RLock = dc.field(default_factory=threading.RLock) @@ -77,9 +74,7 @@ class HaketiloStateWithFields(state.HaketiloState): """....""" start_transaction = transaction and not self.connection.in_transaction - try: - self.lock.acquire() - + with self.lock: if self.current_cursor is not None: yield self.current_cursor return @@ -102,9 +97,7 @@ class HaketiloStateWithFields(state.HaketiloState): raise finally: self.current_cursor = None - finally: - self.lock.release() - def rebuild_structures(self, cursor: sqlite3.Cursor) -> None: + def recompute_payloads(self, cursor: sqlite3.Cursor) -> None: """....""" ... diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py index ccb1269..53f30ae 100644 --- a/src/hydrilla/proxy/state_impl/concrete_state.py +++ b/src/hydrilla/proxy/state_impl/concrete_state.py @@ -106,19 +106,52 @@ class ConcreteResourceVersionRef(st.ResourceVersionRef): @dc.dataclass(frozen=True, unsafe_hash=True) class ConcretePayloadRef(st.PayloadRef): - computed_payload: ComputedPayload = dc.field(hash=False, compare=False) - def get_data(self, state: st.HaketiloState) -> st.PayloadData: - return t.cast(ConcreteHaketiloState, state).payloads_data[self.id] + return t.cast(ConcreteHaketiloState, state).payloads_data[self] def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: return 'to implement' def get_script_paths(self, state: st.HaketiloState) \ - -> t.Iterator[t.Sequence[str]]: - for resource_info in self.computed_payload.resources: - for file_spec in resource_info.scripts: - yield (resource_info.identifier, *file_spec.name.split('/')) + -> t.Iterable[t.Sequence[str]]: + with t.cast(ConcreteHaketiloState, state).cursor() as cursor: + cursor.execute( + ''' + SELECT + i.identifier, fu.name + FROM + payloads AS p + LEFT JOIN resolved_depended_resources AS rdd + USING (payload_id) + LEFT JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + LEFT JOIN items AS i + USING (item_id) + LEFT JOIN file_uses AS fu + USING (item_version_id) + WHERE + fu.type = 'W' AND + p.payload_id = ? AND + (fu.idx IS NOT NULL OR rdd.idx IS NULL) + ORDER BY + rdd.idx, fu.idx; + ''', + (self.id,) + ) + + paths: list[t.Sequence[str]] = [] + for resource_identifier, file_name in cursor.fetchall(): + if resource_identifier is None: + # payload found but it had no script files + return () + + paths.append((resource_identifier, *file_name.split('/'))) + + if paths == []: + # payload not found + raise st.MissingItemError() + + return paths def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ -> t.Optional[st.FileData]: @@ -129,53 +162,109 @@ class ConcretePayloadRef(st.PayloadRef): file_name = '/'.join(file_name_segments) - script_sha256 = '' + with t.cast(ConcreteHaketiloState, state).cursor() as cursor: + cursor.execute( + ''' + SELECT + f.data, fu.mime_type + FROM + payloads AS p + JOIN resolved_depended_resources AS rdd + USING (payload_id) + JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + JOIN items AS i + USING (item_id) + JOIN file_uses AS fu + USING (item_version_id) + JOIN files AS f + USING (file_id) + WHERE + p.payload_id = ? AND + i.identifier = ? AND + fu.name = ? AND + fu.type = 'W'; + ''', + (self.id, resource_identifier, file_name) + ) + + result = cursor.fetchall() - matched_resource_info = False + if result == []: + return None - for resource_info in self.computed_payload.resources: - if resource_info.identifier == resource_identifier: - matched_resource_info = True + (data, mime_type), = result - for script_spec in resource_info.scripts: - if script_spec.name == file_name: - script_sha256 = script_spec.sha256 + return st.FileData(type=mime_type, name=file_name, contents=data) - break +# @dc.dataclass(frozen=True, unsafe_hash=True) +# class ConcretePayloadRef(st.PayloadRef): +# computed_payload: ComputedPayload = dc.field(hash=False, compare=False) - if not matched_resource_info: - raise st.MissingItemError(resource_identifier) +# def get_data(self, state: st.HaketiloState) -> st.PayloadData: +# return t.cast(ConcreteHaketiloState, state).payloads_data[self.id] - if script_sha256 == '': - return None +# def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: +# return 'to implement' - store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir - files_dir_path = store_dir_path / 'temporary_malcontent' / 'file' - file_path = files_dir_path / 'sha256' / script_sha256 +# def get_script_paths(self, state: st.HaketiloState) \ +# -> t.Iterator[t.Sequence[str]]: +# for resource_info in self.computed_payload.resources: +# for file_spec in resource_info.scripts: +# yield (resource_info.identifier, *file_spec.name.split('/')) - return st.FileData( - type = 'application/javascript', - name = file_name, - contents = file_path.read_bytes() - ) +# def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ +# -> t.Optional[st.FileData]: +# if len(path) == 0: +# raise st.MissingItemError() + +# resource_identifier, *file_name_segments = path + +# file_name = '/'.join(file_name_segments) + +# script_sha256 = '' -PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory] +# matched_resource_info = False + +# for resource_info in self.computed_payload.resources: +# if resource_info.identifier == resource_identifier: +# matched_resource_info = True + +# for script_spec in resource_info.scripts: +# if script_spec.name == file_name: +# script_sha256 = script_spec.sha256 + +# break + +# if not matched_resource_info: +# raise st.MissingItemError(resource_identifier) + +# if script_sha256 == '': +# return None + +# store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir +# files_dir_path = store_dir_path / 'temporary_malcontent' / 'file' +# file_path = files_dir_path / 'sha256' / script_sha256 + +# return st.FileData( +# type = 'application/javascript', +# name = file_name, +# contents = file_path.read_bytes() +# ) def register_payload( - policy_tree: PolicyTree, + policy_tree: base.PolicyTree, + pattern: url_patterns.ParsedPattern, payload_key: st.PayloadKey, token: str -) -> PolicyTree: +) -> base.PolicyTree: """....""" payload_policy_factory = policies.PayloadPolicyFactory( builtin = False, payload_key = payload_key ) - policy_tree = policy_tree.register( - payload_key.pattern, - payload_policy_factory - ) + policy_tree = policy_tree.register(pattern, payload_policy_factory) resource_policy_factory = policies.PayloadResourcePolicyFactory( builtin = False, @@ -183,7 +272,7 @@ def register_payload( ) policy_tree = policy_tree.register( - payload_key.pattern.path_append(token, '***'), + pattern.path_append(token, '***'), resource_policy_factory ) @@ -197,24 +286,6 @@ AnyInfoVar = t.TypeVar( item_infos.MappingInfo ) -# def newest_item_path(item_dir: Path) -> t.Optional[Path]: -# available_versions = tuple( -# versions.parse_normalize_version(ver_path.name) -# for ver_path in item_dir.iterdir() -# if ver_path.is_file() -# ) - -# if available_versions == (): -# return None - -# newest_version = max(available_versions) - -# version_path = item_dir / versions.version_string(newest_version) - -# assert version_path.is_file() - -# return version_path - def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \ -> t.Iterator[tuple[AnyInfoVar, str]]: item_type_path = malcontent_path / item_class.type_name @@ -440,7 +511,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): self._populate_database_with_stuff_from_temporary_malcontent_dir() with self.cursor(transaction=True) as cursor: - self.rebuild_structures(cursor) + self.recompute_payloads(cursor) def _prepare_database(self) -> None: """....""" @@ -568,7 +639,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): idx = idx ) - def rebuild_structures(self, cursor: sqlite3.Cursor) -> None: + def recompute_payloads(self, cursor: sqlite3.Cursor) -> None: assert self.connection.in_transaction resources = get_infos_of_type(cursor, item_infos.ResourceInfo) @@ -576,13 +647,12 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): payloads = compute_payloads(resources.keys(), mappings.keys()) - payloads_data = {} + payloads_data: dict[st.PayloadRef, st.PayloadData] = {} cursor.execute('DELETE FROM payloads;') for mapping_info, by_pattern in payloads.items(): for num, (pattern, payload) in enumerate(by_pattern.items()): - print('adding payload') cursor.execute( ''' INSERT INTO payloads( @@ -628,35 +698,67 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): (payload_id_int, resources[resource_info], res_num) ) - payload_id = str(payload_id_int) + self._rebuild_structures(cursor) + + def _rebuild_structures(self, cursor: sqlite3.Cursor) -> None: + cursor.execute( + ''' + SELECT + p.payload_id, p.pattern, p.eval_allowed, + p.cors_bypass_allowed, + ms.enabled, + i.identifier + FROM + payloads AS p + JOIN item_versions AS iv + ON p.mapping_item_id = iv.item_version_id + JOIN items AS i USING (item_id) + JOIN mapping_statuses AS ms USING (item_id); + ''' + ) - ref = ConcretePayloadRef(payload_id, payload) + new_policy_tree = base.PolicyTree() + new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} - data = st.PayloadData( - payload_ref = ref, - mapping_installed = True, - explicitly_enabled = True, - unique_token = secrets.token_urlsafe(16), - pattern = pattern, - eval_allowed = payload.allows_eval, - cors_bypass_allowed = payload.allows_cors_bypass - ) + for row in cursor.fetchall(): + (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, + enabled_status, + identifier) = row - payloads_data[payload_id] = data + payload_ref = ConcretePayloadRef(str(payload_id_int)) - key = st.PayloadKey( - payload_ref = ref, - mapping_identifier = mapping_info.identifier, - pattern = pattern - ) + previous_data = self.payloads_data.get(payload_ref) + if previous_data is not None: + token = previous_data.unique_token + else: + token = secrets.token_urlsafe(8) - self.policy_tree = register_payload( - self.policy_tree, - key, - data.unique_token + payload_key = st.PayloadKey(payload_ref, identifier) + + for parsed_pattern in url_patterns.parse_pattern(pattern): + new_policy_tree = register_payload( + new_policy_tree, + parsed_pattern, + payload_key, + token ) - self.payloads_data = payloads_data + pattern_path_segments = parsed_pattern.path_segments + + payload_data = st.PayloadData( + payload_ref = payload_ref, + #explicitly_enabled = enabled_status == 'E', + explicitly_enabled = True, + unique_token = token, + pattern_path_segments = pattern_path_segments, + eval_allowed = eval_allowed, + cors_bypass_allowed = cors_bypass_allowed + ) + + new_payloads_data[payload_ref] = payload_data + + self.policy_tree = new_policy_tree + self.payloads_data = new_payloads_data def get_repo(self, repo_id: str) -> st.RepoRef: return ConcreteRepoRef(repo_id) @@ -735,7 +837,12 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): @staticmethod def make(store_dir: Path) -> 'ConcreteHaketiloState': + connection = sqlite3.connect( + str(store_dir / 'sqlite3.db'), + isolation_level = None, + check_same_thread = False + ) return ConcreteHaketiloState( store_dir = store_dir, - connection = sqlite3.connect(str(store_dir / 'sqlite3.db')) + connection = connection ) diff --git a/src/hydrilla/proxy/tables.sql b/src/hydrilla/proxy/tables.sql index 25493d3..2a6cac6 100644 --- a/src/hydrilla/proxy/tables.sql +++ b/src/hydrilla/proxy/tables.sql @@ -121,15 +121,11 @@ CREATE TABLE mapping_statuses( -- "enabled" determines whether mapping's status is ENABLED, -- DISABLED or NO_MARK. enabled CHAR(1) NOT NULL, - enabled_version_id INTEGER NULL, -- "frozen" determines whether an enabled mapping is to be kept in its -- EXACT_VERSION, is to be updated only with versions from the same -- REPOSITORY or is NOT_FROZEN at all. frozen CHAR(1) NULL, - CHECK (NOT (enabled = 'D' AND enabled_version_id IS NOT NULL)), - CHECK (NOT (enabled = 'E' AND enabled_version_id IS NULL)), - CHECK ((frozen IS NULL) = (enabled != 'E')), CHECK (frozen IS NULL OR frozen in ('E', 'R', 'N')) ); @@ -143,8 +139,6 @@ CREATE TABLE item_versions( definition TEXT NOT NULL, UNIQUE (item_id, version, repo_iteration_id), - -- Allow foreign key from "mapping_statuses". - UNIQUE (item_version_id, item_id), FOREIGN KEY (item_id) REFERENCES items (item_id), @@ -226,6 +220,7 @@ CREATE TABLE file_uses( CHECK (type IN ('L', 'W')), UNIQUE(item_version_id, type, idx), + UNIQUE(item_version_id, type, name), FOREIGN KEY (item_version_id) REFERENCES item_versions(item_version_id) diff --git a/src/hydrilla/versions.py b/src/hydrilla/versions.py index 7474d98..93f395d 100644 --- a/src/hydrilla/versions.py +++ b/src/hydrilla/versions.py @@ -45,19 +45,19 @@ def normalize_version(ver: t.Sequence[int]) -> VerTuple: return VerTuple(tuple(ver[:new_len])) -def parse_version(ver_str: str) -> tuple[int, ...]: +def parse(ver_str: str) -> tuple[int, ...]: """ Convert 'ver_str' into an array representation, e.g. for ver_str="4.6.13.0" return [4, 6, 13, 0]. """ return tuple(int(num) for num in ver_str.split('.')) -def parse_normalize_version(ver_str: str) -> VerTuple: +def parse_normalize(ver_str: str) -> VerTuple: """ Convert 'ver_str' into a VerTuple representation, e.g. for ver_str="4.6.13.0" return (4, 6, 13). """ - return normalize_version(parse_version(ver_str)) + return normalize_version(parse(ver_str)) def version_string(ver: VerTuple, rev: t.Optional[int] = None) -> str: """ -- cgit v1.2.3