From f0044a21ea7bbabb633057804e83df884196012b Mon Sep 17 00:00:00 2001 From: Wojtek Kosior Date: Wed, 31 Aug 2022 09:25:40 +0200 Subject: [proxy] make sure that dependency tree recomputation by default activates the same resources that were marked as required before --- .../_operations/recompute_dependencies.py | 74 +++++++++++++++++----- 1 file changed, 57 insertions(+), 17 deletions(-) (limited to 'src/hydrilla/proxy/state_impl/_operations') diff --git a/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py b/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py index 2ec3600..5403ec3 100644 --- a/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py +++ b/src/hydrilla/proxy/state_impl/_operations/recompute_dependencies.py @@ -78,7 +78,7 @@ def _get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \ def _get_current_required_state( cursor: sqlite3.Cursor, unlocked_required_mappings: t.Sequence[int] -) -> list[sds.MappingRequirement]: +) -> tuple[list[sds.MappingRequirement], list[sds.ResourceVersionRequirement]]: # For mappings explicitly enabled by the user (+ all mappings they # recursively depend on) let's make sure that their exact same versions will # be enabled after the change. Make exception for mappings specified by the @@ -88,6 +88,7 @@ def _get_current_required_state( ids = unlocked_required_mappings, table_name = '__unlocked_ids' ): + # Describe all required mappings using requirement objects. cursor.execute( ''' SELECT @@ -101,14 +102,52 @@ def _get_current_required_state( rows = cursor.fetchall() - requirements: list[sds.MappingRequirement] = [] + mapping_requirements: list[sds.MappingRequirement] = [] - for definition, repo, iteration in rows: - info = item_infos.MappingInfo.load(definition, repo, iteration) - req = sds.MappingVersionRequirement(info.identifier, info) - requirements.append(req) + for definition, repo, iteration in rows: + mapping_info = \ + item_infos.MappingInfo.load(definition, repo, iteration) + mapping_req = sds.MappingVersionRequirement( + identifier = mapping_info.identifier, + version_info = mapping_info + ) + mapping_requirements.append(mapping_req) + + # Describe all required resources using requirement objects. + cursor.execute( + ''' + SELECT + i_m.identifier, + ive_r.definition, ive_r.repo, ive_r.repo_iteration + FROM + resolved_depended_resources AS rdd + JOIN item_versions_extra AS ive_r + ON rdd.resource_item_id = ive_r.item_version_id + JOIN payloads AS p + USING (payload_id) + JOIN item_versions AS iv_m + ON p.mapping_item_id = iv_m.item_version_id + JOIN items AS i_m + ON iv_m.item_id = i_m.item_id + WHERE + i_m.item_id NOT IN __unlocked_ids AND iv_m.active = 'R'; + ''', + ) + + rows = cursor.fetchall() + + resource_requirements: list[sds.ResourceVersionRequirement] = [] - return requirements + for mapping_identifier, definition, repo, iteration in rows: + resource_info = \ + item_infos.ResourceInfo.load(definition, repo, iteration) + resource_req = sds.ResourceVersionRequirement( + mapping_identifier = mapping_identifier, + version_info = resource_info + ) + resource_requirements.append(resource_req) + + return (mapping_requirements, resource_requirements) def _mark_version_installed(cursor: sqlite3.Cursor, version_id: int) -> None: cursor.execute( @@ -135,13 +174,13 @@ def _recompute_dependencies_no_state_update_no_pull_files( resources_to_ids = dict((info, id) for id, info in ids_to_resources.items()) mappings_to_ids = dict((info, id) for id, info in ids_to_mappings.items()) - if unlocked_required_mappings == 'all_mappings_unlocked': - requirements = [] - else: - requirements = _get_current_required_state( + if unlocked_required_mappings != 'all_mappings_unlocked': + mapping_reqs, resource_reqs = _get_current_required_state( cursor = cursor, unlocked_required_mappings = unlocked_required_mappings ) + else: + mapping_reqs, resource_reqs = [], [] cursor.execute( ''' @@ -156,7 +195,7 @@ def _recompute_dependencies_no_state_update_no_pull_files( ) for mapping_identifier, in cursor.fetchall(): - requirements.append(sds.MappingRequirement(mapping_identifier)) + mapping_reqs.append(sds.MappingRequirement(mapping_identifier)) cursor.execute( ''' @@ -179,12 +218,13 @@ def _recompute_dependencies_no_state_update_no_pull_files( else: requirement = sds.MappingVersionRequirement(info.identifier, info) - requirements.append(requirement) + mapping_reqs.append(requirement) mapping_choices = sds.compute_payloads( - ids_to_resources.values(), - ids_to_mappings.values(), - requirements + resources = ids_to_resources.values(), + mappings = ids_to_mappings.values(), + mapping_requirements = mapping_reqs, + resource_requirements = resource_reqs ) cursor.execute( @@ -243,7 +283,7 @@ def _recompute_dependencies_no_state_update_no_pull_files( WHERE item_version_id = ?; ''', - (mapping_ver_id, 'R' if choice.required else 'A') + ('R' if choice.required else 'A', mapping_ver_id) ) for num, (pattern, payload) in enumerate(choice.payloads.items()): -- cgit v1.2.3