From 72fcc76cc75ccb7e180886170db01dae637e250e Mon Sep 17 00:00:00 2001 From: Wojtek Kosior Date: Wed, 10 Aug 2022 21:07:54 +0200 Subject: small clean up for item definitions handling before dependency resolution happens --- src/hydrilla/proxy/state_impl/concrete_state.py | 72 +++++++++++-------------- 1 file changed, 32 insertions(+), 40 deletions(-) (limited to 'src/hydrilla/proxy/state_impl') diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py index cd6698c..1b46ae9 100644 --- a/src/hydrilla/proxy/state_impl/concrete_state.py +++ b/src/hydrilla/proxy/state_impl/concrete_state.py @@ -48,7 +48,7 @@ from ... import pattern_tree from ... import url_patterns from ... import versions from ... import item_infos -from ..simple_dependency_satisfying import ItemsCollection, ComputedPayload +from ..simple_dependency_satisfying import compute_payloads, ComputedPayload from .. import state as st from .. import policies from . import base @@ -56,8 +56,6 @@ from . import base here = Path(__file__).resolve().parent -AnyInfo = t.Union[item_infos.ResourceInfo, item_infos.MappingInfo] - @dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] class ConcreteRepoRef(st.RepoRef): def remove(self, state: st.HaketiloState) -> None: @@ -316,7 +314,7 @@ def get_or_make_item_version( item_id: int, repo_iteration_id: int, definition: str, - info: AnyInfo + info: item_infos.AnyInfo ) -> int: ver_str = versions.version_string(info.version) @@ -419,6 +417,31 @@ def make_file_use( (item_version_id, file_id, name, type, mime_type, idx) ) +def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar]) \ + -> t.Iterable[AnyInfoVar]: + cursor.execute( + ''' + SELECT + iv.definition, r.name, ri.iteration + FROM + item_versions AS iv + JOIN items AS i USING (item_id) + JOIN repo_iterations AS ri USING (repo_iteration_id) + JOIN repos AS r USING (repo_id) + WHERE + i.type = ?; + ''', + (info_type.type_name[0].upper(),) + ) + + result: list[AnyInfoVar] = [] + + for definition, repo_name, repo_iteration in cursor.fetchall(): + definition_io = io.StringIO(definition) + result.append(info_type.load(definition_io, repo_name, repo_iteration)) + + return result + @dc.dataclass class ConcreteHaketiloState(base.HaketiloStateWithFields): def __post_init__(self) -> None: @@ -474,7 +497,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): with self.cursor(lock=True, transaction=True) as cursor: for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]: - info: AnyInfo + info: item_infos.AnyInfo for info, definition in read_items( malcontent_dir_path, info_type # type: ignore @@ -551,45 +574,14 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): ) def rebuild_structures(self, cursor: sqlite3.Cursor) -> None: - cursor.execute( - ''' - SELECT - item_id, type, version, definition - FROM - item_versions JOIN items USING (item_id); - ''' - ) - - best_versions: dict[int, versions.VerTuple] = {} - definitions = {} - types = {} - - for item_id, item_type, ver_str, definition in cursor.fetchall(): - # TODO: what we're doing in this loop does not yet take different - # repos and different repo iterations into account. - ver = versions.parse_normalize_version(ver_str) - if best_versions.get(item_id, (0,)) < ver: - best_versions[item_id] = ver - definitions[item_id] = definition - types[item_id] = item_type - - resources = {} - mappings = {} - - for item_id, definition in definitions.items(): - if types[item_id] == 'R': - r_info = item_infos.ResourceInfo.load(io.StringIO(definition)) - resources[r_info.identifier] = r_info - else: - m_info = item_infos.MappingInfo.load(io.StringIO(definition)) - mappings[m_info.identifier] = m_info + resources = get_infos_of_type(cursor, item_infos.ResourceInfo) + mappings = get_infos_of_type(cursor, item_infos.MappingInfo) - items_collection = ItemsCollection(resources, mappings) - computed_payloads = items_collection.compute_payloads() + payloads = compute_payloads(resources, mappings) payloads_data = {} - for mapping_info, by_pattern in computed_payloads.items(): + for mapping_info, by_pattern in payloads.items(): for num, (pattern, payload) in enumerate(by_pattern.items()): payload_id = f'{num}@{mapping_info.identifier}' -- cgit v1.2.3