From 67c58a14f4f356117f42fea368a32359496d46c4 Mon Sep 17 00:00:00 2001 From: Wojtek Kosior Date: Thu, 11 Aug 2022 13:33:06 +0200 Subject: populate data structures based on payloads data loaded from sqlite db --- src/hydrilla/proxy/state_impl/base.py | 17 +- src/hydrilla/proxy/state_impl/concrete_state.py | 269 +++++++++++++++++------- 2 files changed, 193 insertions(+), 93 deletions(-) (limited to 'src/hydrilla/proxy/state_impl') diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py index 788a93d..1ae08cb 100644 --- a/src/hydrilla/proxy/state_impl/base.py +++ b/src/hydrilla/proxy/state_impl/base.py @@ -51,9 +51,7 @@ from .. import policies PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory] - -#PayloadsDataMap = Map[state.PayloadRef, state.PayloadData] -DataById = t.Mapping[str, state.PayloadData] +PayloadsData = t.Mapping[state.PayloadRef, state.PayloadData] # mypy needs to be corrected: # https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 @@ -65,9 +63,8 @@ class HaketiloStateWithFields(state.HaketiloState): current_cursor: t.Optional[sqlite3.Cursor] = None #settings: state.HaketiloGlobalSettings - policy_tree: PolicyTree = PolicyTree() - #payloads_data: PayloadsDataMap = Map() - payloads_data: DataById = dc.field(default_factory=dict) + policy_tree: PolicyTree = PolicyTree() + payloads_data: PayloadsData = dc.field(default_factory=dict) lock: threading.RLock = dc.field(default_factory=threading.RLock) @@ -77,9 +74,7 @@ class HaketiloStateWithFields(state.HaketiloState): """....""" start_transaction = transaction and not self.connection.in_transaction - try: - self.lock.acquire() - + with self.lock: if self.current_cursor is not None: yield self.current_cursor return @@ -102,9 +97,7 @@ class HaketiloStateWithFields(state.HaketiloState): raise finally: self.current_cursor = None - finally: - self.lock.release() - def rebuild_structures(self, cursor: sqlite3.Cursor) -> None: + def recompute_payloads(self, cursor: sqlite3.Cursor) -> None: """....""" ... diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py index ccb1269..53f30ae 100644 --- a/src/hydrilla/proxy/state_impl/concrete_state.py +++ b/src/hydrilla/proxy/state_impl/concrete_state.py @@ -106,19 +106,52 @@ class ConcreteResourceVersionRef(st.ResourceVersionRef): @dc.dataclass(frozen=True, unsafe_hash=True) class ConcretePayloadRef(st.PayloadRef): - computed_payload: ComputedPayload = dc.field(hash=False, compare=False) - def get_data(self, state: st.HaketiloState) -> st.PayloadData: - return t.cast(ConcreteHaketiloState, state).payloads_data[self.id] + return t.cast(ConcreteHaketiloState, state).payloads_data[self] def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: return 'to implement' def get_script_paths(self, state: st.HaketiloState) \ - -> t.Iterator[t.Sequence[str]]: - for resource_info in self.computed_payload.resources: - for file_spec in resource_info.scripts: - yield (resource_info.identifier, *file_spec.name.split('/')) + -> t.Iterable[t.Sequence[str]]: + with t.cast(ConcreteHaketiloState, state).cursor() as cursor: + cursor.execute( + ''' + SELECT + i.identifier, fu.name + FROM + payloads AS p + LEFT JOIN resolved_depended_resources AS rdd + USING (payload_id) + LEFT JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + LEFT JOIN items AS i + USING (item_id) + LEFT JOIN file_uses AS fu + USING (item_version_id) + WHERE + fu.type = 'W' AND + p.payload_id = ? AND + (fu.idx IS NOT NULL OR rdd.idx IS NULL) + ORDER BY + rdd.idx, fu.idx; + ''', + (self.id,) + ) + + paths: list[t.Sequence[str]] = [] + for resource_identifier, file_name in cursor.fetchall(): + if resource_identifier is None: + # payload found but it had no script files + return () + + paths.append((resource_identifier, *file_name.split('/'))) + + if paths == []: + # payload not found + raise st.MissingItemError() + + return paths def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ -> t.Optional[st.FileData]: @@ -129,53 +162,109 @@ class ConcretePayloadRef(st.PayloadRef): file_name = '/'.join(file_name_segments) - script_sha256 = '' + with t.cast(ConcreteHaketiloState, state).cursor() as cursor: + cursor.execute( + ''' + SELECT + f.data, fu.mime_type + FROM + payloads AS p + JOIN resolved_depended_resources AS rdd + USING (payload_id) + JOIN item_versions AS iv + ON rdd.resource_item_id = iv.item_version_id + JOIN items AS i + USING (item_id) + JOIN file_uses AS fu + USING (item_version_id) + JOIN files AS f + USING (file_id) + WHERE + p.payload_id = ? AND + i.identifier = ? AND + fu.name = ? AND + fu.type = 'W'; + ''', + (self.id, resource_identifier, file_name) + ) + + result = cursor.fetchall() - matched_resource_info = False + if result == []: + return None - for resource_info in self.computed_payload.resources: - if resource_info.identifier == resource_identifier: - matched_resource_info = True + (data, mime_type), = result - for script_spec in resource_info.scripts: - if script_spec.name == file_name: - script_sha256 = script_spec.sha256 + return st.FileData(type=mime_type, name=file_name, contents=data) - break +# @dc.dataclass(frozen=True, unsafe_hash=True) +# class ConcretePayloadRef(st.PayloadRef): +# computed_payload: ComputedPayload = dc.field(hash=False, compare=False) - if not matched_resource_info: - raise st.MissingItemError(resource_identifier) +# def get_data(self, state: st.HaketiloState) -> st.PayloadData: +# return t.cast(ConcreteHaketiloState, state).payloads_data[self.id] - if script_sha256 == '': - return None +# def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: +# return 'to implement' - store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir - files_dir_path = store_dir_path / 'temporary_malcontent' / 'file' - file_path = files_dir_path / 'sha256' / script_sha256 +# def get_script_paths(self, state: st.HaketiloState) \ +# -> t.Iterator[t.Sequence[str]]: +# for resource_info in self.computed_payload.resources: +# for file_spec in resource_info.scripts: +# yield (resource_info.identifier, *file_spec.name.split('/')) - return st.FileData( - type = 'application/javascript', - name = file_name, - contents = file_path.read_bytes() - ) +# def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ +# -> t.Optional[st.FileData]: +# if len(path) == 0: +# raise st.MissingItemError() + +# resource_identifier, *file_name_segments = path + +# file_name = '/'.join(file_name_segments) + +# script_sha256 = '' -PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory] +# matched_resource_info = False + +# for resource_info in self.computed_payload.resources: +# if resource_info.identifier == resource_identifier: +# matched_resource_info = True + +# for script_spec in resource_info.scripts: +# if script_spec.name == file_name: +# script_sha256 = script_spec.sha256 + +# break + +# if not matched_resource_info: +# raise st.MissingItemError(resource_identifier) + +# if script_sha256 == '': +# return None + +# store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir +# files_dir_path = store_dir_path / 'temporary_malcontent' / 'file' +# file_path = files_dir_path / 'sha256' / script_sha256 + +# return st.FileData( +# type = 'application/javascript', +# name = file_name, +# contents = file_path.read_bytes() +# ) def register_payload( - policy_tree: PolicyTree, + policy_tree: base.PolicyTree, + pattern: url_patterns.ParsedPattern, payload_key: st.PayloadKey, token: str -) -> PolicyTree: +) -> base.PolicyTree: """....""" payload_policy_factory = policies.PayloadPolicyFactory( builtin = False, payload_key = payload_key ) - policy_tree = policy_tree.register( - payload_key.pattern, - payload_policy_factory - ) + policy_tree = policy_tree.register(pattern, payload_policy_factory) resource_policy_factory = policies.PayloadResourcePolicyFactory( builtin = False, @@ -183,7 +272,7 @@ def register_payload( ) policy_tree = policy_tree.register( - payload_key.pattern.path_append(token, '***'), + pattern.path_append(token, '***'), resource_policy_factory ) @@ -197,24 +286,6 @@ AnyInfoVar = t.TypeVar( item_infos.MappingInfo ) -# def newest_item_path(item_dir: Path) -> t.Optional[Path]: -# available_versions = tuple( -# versions.parse_normalize_version(ver_path.name) -# for ver_path in item_dir.iterdir() -# if ver_path.is_file() -# ) - -# if available_versions == (): -# return None - -# newest_version = max(available_versions) - -# version_path = item_dir / versions.version_string(newest_version) - -# assert version_path.is_file() - -# return version_path - def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \ -> t.Iterator[tuple[AnyInfoVar, str]]: item_type_path = malcontent_path / item_class.type_name @@ -440,7 +511,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): self._populate_database_with_stuff_from_temporary_malcontent_dir() with self.cursor(transaction=True) as cursor: - self.rebuild_structures(cursor) + self.recompute_payloads(cursor) def _prepare_database(self) -> None: """....""" @@ -568,7 +639,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): idx = idx ) - def rebuild_structures(self, cursor: sqlite3.Cursor) -> None: + def recompute_payloads(self, cursor: sqlite3.Cursor) -> None: assert self.connection.in_transaction resources = get_infos_of_type(cursor, item_infos.ResourceInfo) @@ -576,13 +647,12 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): payloads = compute_payloads(resources.keys(), mappings.keys()) - payloads_data = {} + payloads_data: dict[st.PayloadRef, st.PayloadData] = {} cursor.execute('DELETE FROM payloads;') for mapping_info, by_pattern in payloads.items(): for num, (pattern, payload) in enumerate(by_pattern.items()): - print('adding payload') cursor.execute( ''' INSERT INTO payloads( @@ -628,35 +698,67 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): (payload_id_int, resources[resource_info], res_num) ) - payload_id = str(payload_id_int) + self._rebuild_structures(cursor) + + def _rebuild_structures(self, cursor: sqlite3.Cursor) -> None: + cursor.execute( + ''' + SELECT + p.payload_id, p.pattern, p.eval_allowed, + p.cors_bypass_allowed, + ms.enabled, + i.identifier + FROM + payloads AS p + JOIN item_versions AS iv + ON p.mapping_item_id = iv.item_version_id + JOIN items AS i USING (item_id) + JOIN mapping_statuses AS ms USING (item_id); + ''' + ) - ref = ConcretePayloadRef(payload_id, payload) + new_policy_tree = base.PolicyTree() + new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} - data = st.PayloadData( - payload_ref = ref, - mapping_installed = True, - explicitly_enabled = True, - unique_token = secrets.token_urlsafe(16), - pattern = pattern, - eval_allowed = payload.allows_eval, - cors_bypass_allowed = payload.allows_cors_bypass - ) + for row in cursor.fetchall(): + (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, + enabled_status, + identifier) = row - payloads_data[payload_id] = data + payload_ref = ConcretePayloadRef(str(payload_id_int)) - key = st.PayloadKey( - payload_ref = ref, - mapping_identifier = mapping_info.identifier, - pattern = pattern - ) + previous_data = self.payloads_data.get(payload_ref) + if previous_data is not None: + token = previous_data.unique_token + else: + token = secrets.token_urlsafe(8) - self.policy_tree = register_payload( - self.policy_tree, - key, - data.unique_token + payload_key = st.PayloadKey(payload_ref, identifier) + + for parsed_pattern in url_patterns.parse_pattern(pattern): + new_policy_tree = register_payload( + new_policy_tree, + parsed_pattern, + payload_key, + token ) - self.payloads_data = payloads_data + pattern_path_segments = parsed_pattern.path_segments + + payload_data = st.PayloadData( + payload_ref = payload_ref, + #explicitly_enabled = enabled_status == 'E', + explicitly_enabled = True, + unique_token = token, + pattern_path_segments = pattern_path_segments, + eval_allowed = eval_allowed, + cors_bypass_allowed = cors_bypass_allowed + ) + + new_payloads_data[payload_ref] = payload_data + + self.policy_tree = new_policy_tree + self.payloads_data = new_payloads_data def get_repo(self, repo_id: str) -> st.RepoRef: return ConcreteRepoRef(repo_id) @@ -735,7 +837,12 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): @staticmethod def make(store_dir: Path) -> 'ConcreteHaketiloState': + connection = sqlite3.connect( + str(store_dir / 'sqlite3.db'), + isolation_level = None, + check_same_thread = False + ) return ConcreteHaketiloState( store_dir = store_dir, - connection = sqlite3.connect(str(store_dir / 'sqlite3.db')) + connection = connection ) -- cgit v1.2.3