aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/state_impl/concrete_state.py
diff options
context:
space:
mode:
authorWojtek Kosior <koszko@koszko.org>2022-08-11 13:33:06 +0200
committerWojtek Kosior <koszko@koszko.org>2022-08-11 13:33:06 +0200
commit67c58a14f4f356117f42fea368a32359496d46c4 (patch)
treebb8d1019dc4547a215404a40b60a031ca4f2e21d /src/hydrilla/proxy/state_impl/concrete_state.py
parent3f3ba519ae3c3346945928b21ab36f7238e5387e (diff)
downloadhaketilo-hydrilla-67c58a14f4f356117f42fea368a32359496d46c4.tar.gz
haketilo-hydrilla-67c58a14f4f356117f42fea368a32359496d46c4.zip
populate data structures based on payloads data loaded from sqlite db
Diffstat (limited to 'src/hydrilla/proxy/state_impl/concrete_state.py')
-rw-r--r--src/hydrilla/proxy/state_impl/concrete_state.py269
1 files changed, 188 insertions, 81 deletions
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py
index ccb1269..53f30ae 100644
--- a/src/hydrilla/proxy/state_impl/concrete_state.py
+++ b/src/hydrilla/proxy/state_impl/concrete_state.py
@@ -106,19 +106,52 @@ class ConcreteResourceVersionRef(st.ResourceVersionRef):
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcretePayloadRef(st.PayloadRef):
- computed_payload: ComputedPayload = dc.field(hash=False, compare=False)
-
def get_data(self, state: st.HaketiloState) -> st.PayloadData:
- return t.cast(ConcreteHaketiloState, state).payloads_data[self.id]
+ return t.cast(ConcreteHaketiloState, state).payloads_data[self]
def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
return 'to implement'
def get_script_paths(self, state: st.HaketiloState) \
- -> t.Iterator[t.Sequence[str]]:
- for resource_info in self.computed_payload.resources:
- for file_spec in resource_info.scripts:
- yield (resource_info.identifier, *file_spec.name.split('/'))
+ -> t.Iterable[t.Sequence[str]]:
+ with t.cast(ConcreteHaketiloState, state).cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ i.identifier, fu.name
+ FROM
+ payloads AS p
+ LEFT JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ LEFT JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ LEFT JOIN items AS i
+ USING (item_id)
+ LEFT JOIN file_uses AS fu
+ USING (item_version_id)
+ WHERE
+ fu.type = 'W' AND
+ p.payload_id = ? AND
+ (fu.idx IS NOT NULL OR rdd.idx IS NULL)
+ ORDER BY
+ rdd.idx, fu.idx;
+ ''',
+ (self.id,)
+ )
+
+ paths: list[t.Sequence[str]] = []
+ for resource_identifier, file_name in cursor.fetchall():
+ if resource_identifier is None:
+ # payload found but it had no script files
+ return ()
+
+ paths.append((resource_identifier, *file_name.split('/')))
+
+ if paths == []:
+ # payload not found
+ raise st.MissingItemError()
+
+ return paths
def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
-> t.Optional[st.FileData]:
@@ -129,53 +162,109 @@ class ConcretePayloadRef(st.PayloadRef):
file_name = '/'.join(file_name_segments)
- script_sha256 = ''
+ with t.cast(ConcreteHaketiloState, state).cursor() as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ f.data, fu.mime_type
+ FROM
+ payloads AS p
+ JOIN resolved_depended_resources AS rdd
+ USING (payload_id)
+ JOIN item_versions AS iv
+ ON rdd.resource_item_id = iv.item_version_id
+ JOIN items AS i
+ USING (item_id)
+ JOIN file_uses AS fu
+ USING (item_version_id)
+ JOIN files AS f
+ USING (file_id)
+ WHERE
+ p.payload_id = ? AND
+ i.identifier = ? AND
+ fu.name = ? AND
+ fu.type = 'W';
+ ''',
+ (self.id, resource_identifier, file_name)
+ )
+
+ result = cursor.fetchall()
- matched_resource_info = False
+ if result == []:
+ return None
- for resource_info in self.computed_payload.resources:
- if resource_info.identifier == resource_identifier:
- matched_resource_info = True
+ (data, mime_type), = result
- for script_spec in resource_info.scripts:
- if script_spec.name == file_name:
- script_sha256 = script_spec.sha256
+ return st.FileData(type=mime_type, name=file_name, contents=data)
- break
+# @dc.dataclass(frozen=True, unsafe_hash=True)
+# class ConcretePayloadRef(st.PayloadRef):
+# computed_payload: ComputedPayload = dc.field(hash=False, compare=False)
- if not matched_resource_info:
- raise st.MissingItemError(resource_identifier)
+# def get_data(self, state: st.HaketiloState) -> st.PayloadData:
+# return t.cast(ConcreteHaketiloState, state).payloads_data[self.id]
- if script_sha256 == '':
- return None
+# def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
+# return 'to implement'
- store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir
- files_dir_path = store_dir_path / 'temporary_malcontent' / 'file'
- file_path = files_dir_path / 'sha256' / script_sha256
+# def get_script_paths(self, state: st.HaketiloState) \
+# -> t.Iterator[t.Sequence[str]]:
+# for resource_info in self.computed_payload.resources:
+# for file_spec in resource_info.scripts:
+# yield (resource_info.identifier, *file_spec.name.split('/'))
- return st.FileData(
- type = 'application/javascript',
- name = file_name,
- contents = file_path.read_bytes()
- )
+# def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
+# -> t.Optional[st.FileData]:
+# if len(path) == 0:
+# raise st.MissingItemError()
+
+# resource_identifier, *file_name_segments = path
+
+# file_name = '/'.join(file_name_segments)
+
+# script_sha256 = ''
-PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
+# matched_resource_info = False
+
+# for resource_info in self.computed_payload.resources:
+# if resource_info.identifier == resource_identifier:
+# matched_resource_info = True
+
+# for script_spec in resource_info.scripts:
+# if script_spec.name == file_name:
+# script_sha256 = script_spec.sha256
+
+# break
+
+# if not matched_resource_info:
+# raise st.MissingItemError(resource_identifier)
+
+# if script_sha256 == '':
+# return None
+
+# store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir
+# files_dir_path = store_dir_path / 'temporary_malcontent' / 'file'
+# file_path = files_dir_path / 'sha256' / script_sha256
+
+# return st.FileData(
+# type = 'application/javascript',
+# name = file_name,
+# contents = file_path.read_bytes()
+# )
def register_payload(
- policy_tree: PolicyTree,
+ policy_tree: base.PolicyTree,
+ pattern: url_patterns.ParsedPattern,
payload_key: st.PayloadKey,
token: str
-) -> PolicyTree:
+) -> base.PolicyTree:
"""...."""
payload_policy_factory = policies.PayloadPolicyFactory(
builtin = False,
payload_key = payload_key
)
- policy_tree = policy_tree.register(
- payload_key.pattern,
- payload_policy_factory
- )
+ policy_tree = policy_tree.register(pattern, payload_policy_factory)
resource_policy_factory = policies.PayloadResourcePolicyFactory(
builtin = False,
@@ -183,7 +272,7 @@ def register_payload(
)
policy_tree = policy_tree.register(
- payload_key.pattern.path_append(token, '***'),
+ pattern.path_append(token, '***'),
resource_policy_factory
)
@@ -197,24 +286,6 @@ AnyInfoVar = t.TypeVar(
item_infos.MappingInfo
)
-# def newest_item_path(item_dir: Path) -> t.Optional[Path]:
-# available_versions = tuple(
-# versions.parse_normalize_version(ver_path.name)
-# for ver_path in item_dir.iterdir()
-# if ver_path.is_file()
-# )
-
-# if available_versions == ():
-# return None
-
-# newest_version = max(available_versions)
-
-# version_path = item_dir / versions.version_string(newest_version)
-
-# assert version_path.is_file()
-
-# return version_path
-
def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
-> t.Iterator[tuple[AnyInfoVar, str]]:
item_type_path = malcontent_path / item_class.type_name
@@ -440,7 +511,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
self._populate_database_with_stuff_from_temporary_malcontent_dir()
with self.cursor(transaction=True) as cursor:
- self.rebuild_structures(cursor)
+ self.recompute_payloads(cursor)
def _prepare_database(self) -> None:
"""...."""
@@ -568,7 +639,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
idx = idx
)
- def rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
+ def recompute_payloads(self, cursor: sqlite3.Cursor) -> None:
assert self.connection.in_transaction
resources = get_infos_of_type(cursor, item_infos.ResourceInfo)
@@ -576,13 +647,12 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
payloads = compute_payloads(resources.keys(), mappings.keys())
- payloads_data = {}
+ payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
cursor.execute('DELETE FROM payloads;')
for mapping_info, by_pattern in payloads.items():
for num, (pattern, payload) in enumerate(by_pattern.items()):
- print('adding payload')
cursor.execute(
'''
INSERT INTO payloads(
@@ -628,35 +698,67 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
(payload_id_int, resources[resource_info], res_num)
)
- payload_id = str(payload_id_int)
+ self._rebuild_structures(cursor)
+
+ def _rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
+ cursor.execute(
+ '''
+ SELECT
+ p.payload_id, p.pattern, p.eval_allowed,
+ p.cors_bypass_allowed,
+ ms.enabled,
+ i.identifier
+ FROM
+ payloads AS p
+ JOIN item_versions AS iv
+ ON p.mapping_item_id = iv.item_version_id
+ JOIN items AS i USING (item_id)
+ JOIN mapping_statuses AS ms USING (item_id);
+ '''
+ )
- ref = ConcretePayloadRef(payload_id, payload)
+ new_policy_tree = base.PolicyTree()
+ new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}
- data = st.PayloadData(
- payload_ref = ref,
- mapping_installed = True,
- explicitly_enabled = True,
- unique_token = secrets.token_urlsafe(16),
- pattern = pattern,
- eval_allowed = payload.allows_eval,
- cors_bypass_allowed = payload.allows_cors_bypass
- )
+ for row in cursor.fetchall():
+ (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
+ enabled_status,
+ identifier) = row
- payloads_data[payload_id] = data
+ payload_ref = ConcretePayloadRef(str(payload_id_int))
- key = st.PayloadKey(
- payload_ref = ref,
- mapping_identifier = mapping_info.identifier,
- pattern = pattern
- )
+ previous_data = self.payloads_data.get(payload_ref)
+ if previous_data is not None:
+ token = previous_data.unique_token
+ else:
+ token = secrets.token_urlsafe(8)
- self.policy_tree = register_payload(
- self.policy_tree,
- key,
- data.unique_token
+ payload_key = st.PayloadKey(payload_ref, identifier)
+
+ for parsed_pattern in url_patterns.parse_pattern(pattern):
+ new_policy_tree = register_payload(
+ new_policy_tree,
+ parsed_pattern,
+ payload_key,
+ token
)
- self.payloads_data = payloads_data
+ pattern_path_segments = parsed_pattern.path_segments
+
+ payload_data = st.PayloadData(
+ payload_ref = payload_ref,
+ #explicitly_enabled = enabled_status == 'E',
+ explicitly_enabled = True,
+ unique_token = token,
+ pattern_path_segments = pattern_path_segments,
+ eval_allowed = eval_allowed,
+ cors_bypass_allowed = cors_bypass_allowed
+ )
+
+ new_payloads_data[payload_ref] = payload_data
+
+ self.policy_tree = new_policy_tree
+ self.payloads_data = new_payloads_data
def get_repo(self, repo_id: str) -> st.RepoRef:
return ConcreteRepoRef(repo_id)
@@ -735,7 +837,12 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields):
@staticmethod
def make(store_dir: Path) -> 'ConcreteHaketiloState':
+ connection = sqlite3.connect(
+ str(store_dir / 'sqlite3.db'),
+ isolation_level = None,
+ check_same_thread = False
+ )
return ConcreteHaketiloState(
store_dir = store_dir,
- connection = sqlite3.connect(str(store_dir / 'sqlite3.db'))
+ connection = connection
)