From 8238435825d01ad2ec1a11b6bcaf6d9a9aad5ab5 Mon Sep 17 00:00:00 2001 From: Wojtek Kosior Date: Mon, 22 Aug 2022 12:52:59 +0200 Subject: allow pulling packages from remote repository --- src/hydrilla/proxy/state_impl/repos.py | 135 ++++++++++++++++++++++++++++++--- 1 file changed, 125 insertions(+), 10 deletions(-) (limited to 'src/hydrilla/proxy/state_impl/repos.py') diff --git a/src/hydrilla/proxy/state_impl/repos.py b/src/hydrilla/proxy/state_impl/repos.py index 5553ec2..f4c7c71 100644 --- a/src/hydrilla/proxy/state_impl/repos.py +++ b/src/hydrilla/proxy/state_impl/repos.py @@ -33,20 +33,27 @@ inside Haketilo. from __future__ import annotations import re +import json +import tempfile +import requests +import sqlite3 import typing as t import dataclasses as dc -from urllib.parse import urlparse +from urllib.parse import urlparse, urljoin from datetime import datetime +from pathlib import Path -import sqlite3 - +from ... import json_instances +from ... import item_infos +from ... import versions from .. import state as st +from .. import simple_dependency_satisfying as sds from . import base -from . import prune_packages +from . import _operations -def validate_repo_url(url: str) -> None: +def sanitize_repo_url(url: str) -> str: try: parsed = urlparse(url) except: @@ -55,6 +62,11 @@ def validate_repo_url(url: str) -> None: if parsed.scheme not in ('http', 'https'): raise st.RepoUrlInvalid() + if url[-1] != '/': + url = url + '/' + + return url + def ensure_repo_not_deleted(cursor: sqlite3.Cursor, repo_id: str) -> None: cursor.execute( @@ -73,6 +85,53 @@ def ensure_repo_not_deleted(cursor: sqlite3.Cursor, repo_id: str) -> None: raise st.MissingItemError() +def sync_remote_repo_definitions(repo_url: str, dest: Path) -> None: + try: + list_all_response = requests.get(urljoin(repo_url, 'list_all')) + assert list_all_response.ok + + list_instance = list_all_response.json() + except: + raise st.RepoCommunicationError() + + try: + json_instances.validate_instance( + list_instance, + 'api_package_list-{}.schema.json' + ) + except json_instances.UnknownSchemaError: + raise st.RepoApiVersionUnsupported() + except: + raise st.RepoCommunicationError() + + ref: dict[str, t.Any] + + for item_type_name in ('resource', 'mapping'): + for ref in list_instance[item_type_name + 's']: + ver = versions.version_string(versions.normalize(ref['version'])) + item_rel_path = f'{item_type_name}/{ref["identifier"]}/{ver}' + + try: + item_response = requests.get(urljoin(repo_url, item_rel_path)) + assert item_response.ok + except: + raise st.RepoCommunicationError() + + item_path = dest / item_rel_path + item_path.parent.mkdir(parents=True, exist_ok=True) + item_path.write_bytes(item_response.content) + + +@dc.dataclass(frozen=True) +class RemoteFileResolver(_operations.FileResolver): + repo_url: str + + def by_sha256(self, sha256: str) -> bytes: + response = requests.get(urljoin(self.repo_url, f'file/sha256/{sha256}')) + assert response.ok + return response.content + + def make_repo_display_info( ref: st.RepoRef, name: str, @@ -122,6 +181,37 @@ class ConcreteRepoRef(st.RepoRef): (self.id,) ) + _operations.prune_packages(cursor) + + # For mappings explicitly enabled by the user (+ all mappings they + # recursively depend on) let's make sure that their exact same + # versions will be enabled after the change. + cursor.execute( + ''' + SELECT + iv.definition, r.name, ri.iteration + FROM + mapping_statuses AS ms + JOIN item_versions AS iv + ON ms.active_version_id = iv.item_version_id + JOIN repo_iterations AS ri + USING (repo_iteration_id) + JOIN repos AS r + USING (repo_id) + WHERE + ms.required + ''' + ) + + requirements = [] + + for definition, repo, iteration in cursor.fetchall(): + info = item_infos.MappingInfo.load(definition, repo, iteration) + req = sds.MappingVersionRequirement(info.identifier, info) + requirements.append(req) + + self.state.recompute_dependencies(requirements) + def update( self, *, @@ -134,7 +224,7 @@ class ConcreteRepoRef(st.RepoRef): if url is None: return - validate_repo_url(url) + url = sanitize_repo_url(url) with self.state.cursor(transaction=True) as cursor: ensure_repo_not_deleted(cursor, self.id) @@ -144,10 +234,30 @@ class ConcreteRepoRef(st.RepoRef): (url, self.id) ) - prune_packages.prune(cursor) - def refresh(self) -> st.RepoIterationRef: - raise NotImplementedError() + with self.state.cursor(transaction=True) as cursor: + ensure_repo_not_deleted(cursor, self.id) + + cursor.execute( + 'SELECT url from repos where repo_id = ?;', + (self.id,) + ) + + (repo_url,), = cursor.fetchall() + + with tempfile.TemporaryDirectory() as tmpdir_str: + tmpdir = Path(tmpdir_str) + sync_remote_repo_definitions(repo_url, tmpdir) + new_iteration_id = _operations.load_packages( + cursor, + tmpdir, + int(self.id), + RemoteFileResolver(repo_url) + ) + + self.state.recompute_dependencies() + + return ConcreteRepoIterationRef(str(new_iteration_id), self.state) def get_display_info(self) -> st.RepoDisplayInfo: with self.state.cursor() as cursor: @@ -199,7 +309,7 @@ class ConcreteRepoStore(st.RepoStore): if repo_name_regex.match(name) is None: raise st.RepoNameInvalid() - validate_repo_url(url) + url = sanitize_repo_url(url) with self.state.cursor(transaction=True) as cursor: cursor.execute( @@ -272,3 +382,8 @@ class ConcreteRepoStore(st.RepoStore): result.append(make_repo_display_info(ref, *rest)) return result + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteRepoIterationRef(st.RepoIterationRef): + state: base.HaketiloStateWithFields -- cgit v1.2.3