diff options
author | Wojtek Kosior <koszko@koszko.org> | 2022-08-17 13:50:34 +0200 |
---|---|---|
committer | Wojtek Kosior <koszko@koszko.org> | 2022-08-17 13:50:34 +0200 |
commit | c287cc01cf26eb9af1d34b03d0b57716d9976da4 (patch) | |
tree | c3d16d6c72ccf55d0dbbc239a88fa28c30ec7d0a /src/hydrilla/proxy/state_impl/concrete_state.py | |
parent | 2c98d04e4d4a344dc04a481b039a235678f7848e (diff) | |
download | haketilo-hydrilla-c287cc01cf26eb9af1d34b03d0b57716d9976da4.tar.gz haketilo-hydrilla-c287cc01cf26eb9af1d34b03d0b57716d9976da4.zip |
allow loading packages from zip files through web UI and listing installed mappings
Diffstat (limited to 'src/hydrilla/proxy/state_impl/concrete_state.py')
-rw-r--r-- | src/hydrilla/proxy/state_impl/concrete_state.py | 402 |
1 files changed, 36 insertions, 366 deletions
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py index bb14734..b2b1033 100644 --- a/src/hydrilla/proxy/state_impl/concrete_state.py +++ b/src/hydrilla/proxy/state_impl/concrete_state.py @@ -46,12 +46,13 @@ from ...exceptions import HaketiloException from ...translations import smart_gettext as _ from ... import pattern_tree from ... import url_patterns -from ... import versions from ... import item_infos from ..simple_dependency_satisfying import compute_payloads, ComputedPayload from .. import state as st from .. import policies from . import base +from . import mappings +from .load_packages import load_packages here = Path(__file__).resolve().parent @@ -80,21 +81,6 @@ class ConcreteRepoIterationRef(st.RepoIterationRef): @dc.dataclass(frozen=True, unsafe_hash=True) -class ConcreteMappingRef(st.MappingRef): - def disable(self, state: st.HaketiloState) -> None: - raise NotImplementedError() - - def forget_enabled(self, state: st.HaketiloState) -> None: - raise NotImplementedError() - - -@dc.dataclass(frozen=True, unsafe_hash=True) -class ConcreteMappingVersionRef(st.MappingVersionRef): - def enable(self, state: st.HaketiloState) -> None: - raise NotImplementedError() - - -@dc.dataclass(frozen=True, unsafe_hash=True) class ConcreteResourceRef(st.ResourceRef): pass @@ -106,15 +92,20 @@ class ConcreteResourceVersionRef(st.ResourceVersionRef): @dc.dataclass(frozen=True, unsafe_hash=True) class ConcretePayloadRef(st.PayloadRef): - def get_data(self, state: st.HaketiloState) -> st.PayloadData: - return t.cast(ConcreteHaketiloState, state).payloads_data[self] + state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False) - def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: - return 'to implement' + def get_data(self) -> st.PayloadData: + try: + return self.state.payloads_data[self] + except KeyError: + raise st.MissingItemError() + + def get_mapping(self) -> st.MappingVersionRef: + raise NotImplementedError() - def get_script_paths(self, state: st.HaketiloState) \ + def get_script_paths(self) \ -> t.Iterable[t.Sequence[str]]: - with t.cast(ConcreteHaketiloState, state).cursor() as cursor: + with self.state.cursor() as cursor: cursor.execute( ''' SELECT @@ -153,7 +144,7 @@ class ConcretePayloadRef(st.PayloadRef): return paths - def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ + def get_file_data(self, path: t.Sequence[str]) \ -> t.Optional[st.FileData]: if len(path) == 0: raise st.MissingItemError() @@ -162,7 +153,7 @@ class ConcretePayloadRef(st.PayloadRef): file_name = '/'.join(file_name_segments) - with t.cast(ConcreteHaketiloState, state).cursor() as cursor: + with self.state.cursor() as cursor: cursor.execute( ''' SELECT @@ -197,61 +188,6 @@ class ConcretePayloadRef(st.PayloadRef): return st.FileData(type=mime_type, name=file_name, contents=data) -# @dc.dataclass(frozen=True, unsafe_hash=True) -# class ConcretePayloadRef(st.PayloadRef): -# computed_payload: ComputedPayload = dc.field(hash=False, compare=False) - -# def get_data(self, state: st.HaketiloState) -> st.PayloadData: -# return t.cast(ConcreteHaketiloState, state).payloads_data[self.id] - -# def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: -# return 'to implement' - -# def get_script_paths(self, state: st.HaketiloState) \ -# -> t.Iterator[t.Sequence[str]]: -# for resource_info in self.computed_payload.resources: -# for file_spec in resource_info.scripts: -# yield (resource_info.identifier, *file_spec.name.split('/')) - -# def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ -# -> t.Optional[st.FileData]: -# if len(path) == 0: -# raise st.MissingItemError() - -# resource_identifier, *file_name_segments = path - -# file_name = '/'.join(file_name_segments) - -# script_sha256 = '' - -# matched_resource_info = False - -# for resource_info in self.computed_payload.resources: -# if resource_info.identifier == resource_identifier: -# matched_resource_info = True - -# for script_spec in resource_info.scripts: -# if script_spec.name == file_name: -# script_sha256 = script_spec.sha256 - -# break - -# if not matched_resource_info: -# raise st.MissingItemError(resource_identifier) - -# if script_sha256 == '': -# return None - -# store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir -# files_dir_path = store_dir_path / 'temporary_malcontent' / 'file' -# file_path = files_dir_path / 'sha256' / script_sha256 - -# return st.FileData( -# type = 'application/javascript', -# name = file_name, -# contents = file_path.read_bytes() -# ) - def register_payload( policy_tree: base.PolicyTree, pattern: url_patterns.ParsedPattern, @@ -278,205 +214,12 @@ def register_payload( return policy_tree -DataById = t.Mapping[str, st.PayloadData] - AnyInfoVar = t.TypeVar( 'AnyInfoVar', item_infos.ResourceInfo, item_infos.MappingInfo ) -def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \ - -> t.Iterator[tuple[AnyInfoVar, str]]: - item_type_path = malcontent_path / item_class.type_name - if not item_type_path.is_dir(): - return - - for item_path in item_type_path.iterdir(): - if not item_path.is_dir(): - continue - - for item_version_path in item_path.iterdir(): - definition = item_version_path.read_text() - item_info = item_class.load(io.StringIO(definition)) - - assert item_info.identifier == item_path.name - assert versions.version_string(item_info.version) == \ - item_version_path.name - - yield item_info, definition - -def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int: - cursor.execute( - ''' - INSERT OR IGNORE INTO repos(name, url, deleted, next_iteration) - VALUES(?, '<dummy_url>', TRUE, 2); - ''', - (repo_name,) - ) - - cursor.execute( - ''' - SELECT - repo_id, next_iteration - 1 - FROM - repos - WHERE - name = ?; - ''', - (repo_name,) - ) - - (repo_id, last_iteration), = cursor.fetchall() - - cursor.execute( - ''' - INSERT OR IGNORE INTO repo_iterations(repo_id, iteration) - VALUES(?, ?); - ''', - (repo_id, last_iteration) - ) - - cursor.execute( - ''' - SELECT - repo_iteration_id - FROM - repo_iterations - WHERE - repo_id = ? AND iteration = ?; - ''', - (repo_id, last_iteration) - ) - - (repo_iteration_id,), = cursor.fetchall() - - return repo_iteration_id - -def get_or_make_item(cursor: sqlite3.Cursor, type: str, identifier: str) -> int: - type_letter = {'resource': 'R', 'mapping': 'M'}[type] - - cursor.execute( - ''' - INSERT OR IGNORE INTO items(type, identifier) - VALUES(?, ?); - ''', - (type_letter, identifier) - ) - - cursor.execute( - ''' - SELECT - item_id - FROM - items - WHERE - type = ? AND identifier = ?; - ''', - (type_letter, identifier) - ) - - (item_id,), = cursor.fetchall() - - return item_id - -def get_or_make_item_version( - cursor: sqlite3.Cursor, - item_id: int, - repo_iteration_id: int, - version: versions.VerTuple, - definition: str -) -> int: - ver_str = versions.version_string(version) - - cursor.execute( - ''' - INSERT OR IGNORE INTO item_versions( - item_id, - version, - repo_iteration_id, - definition - ) - VALUES(?, ?, ?, ?); - ''', - (item_id, ver_str, repo_iteration_id, definition) - ) - - cursor.execute( - ''' - SELECT - item_version_id - FROM - item_versions - WHERE - item_id = ? AND version = ? AND repo_iteration_id = ?; - ''', - (item_id, ver_str, repo_iteration_id) - ) - - (item_version_id,), = cursor.fetchall() - - return item_version_id - -def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None: - cursor.execute( - ''' - INSERT OR IGNORE INTO mapping_statuses(item_id, enabled) - VALUES(?, 'N'); - ''', - (item_id,) - ) - -def get_or_make_file(cursor: sqlite3.Cursor, sha256: str, file_bytes: bytes) \ - -> int: - cursor.execute( - ''' - INSERT OR IGNORE INTO files(sha256, data) - VALUES(?, ?) - ''', - (sha256, file_bytes) - ) - - cursor.execute( - ''' - SELECT - file_id - FROM - files - WHERE - sha256 = ?; - ''', - (sha256,) - ) - - (file_id,), = cursor.fetchall() - - return file_id - -def make_file_use( - cursor: sqlite3.Cursor, - item_version_id: int, - file_id: int, - name: str, - type: str, - mime_type: str, - idx: int -) -> None: - cursor.execute( - ''' - INSERT OR IGNORE INTO file_uses( - item_version_id, - file_id, - name, - type, - mime_type, - idx - ) - VALUES(?, ?, ?, ?, ?, ?); - ''', - (item_version_id, file_id, name, type, mime_type, idx) - ) - def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \ -> t.Mapping[AnyInfoVar, int]: cursor.execute( @@ -508,10 +251,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): def __post_init__(self) -> None: self._prepare_database() - self._populate_database_with_stuff_from_temporary_malcontent_dir() - - with self.cursor(transaction=True) as cursor: - self.recompute_payloads(cursor) + self._rebuild_structures() def _prepare_database(self) -> None: """....""" @@ -546,7 +286,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): (db_haketilo_version,) = cursor.fetchone() if db_haketilo_version != '3.0b1': - raise HaketiloException(_('err.unknown_db_schema')) + raise HaketiloException(_('err.proxy.unknown_db_schema')) cursor.execute('PRAGMA FOREIGN_KEYS;') if cursor.fetchall() == []: @@ -556,88 +296,10 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): finally: cursor.close() - def _populate_database_with_stuff_from_temporary_malcontent_dir(self) \ - -> None: - malcontent_dir_path = self.store_dir / 'temporary_malcontent' - files_by_sha256_path = malcontent_dir_path / 'file' / 'sha256' - + def import_packages(self, malcontent_path: Path) -> None: with self.cursor(transaction=True) as cursor: - for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]: - info: item_infos.AnyInfo - for info, definition in read_items( - malcontent_dir_path, - info_type # type: ignore - ): - repo_iteration_id = get_or_make_repo_iteration( - cursor, - info.repo - ) - - item_id = get_or_make_item( - cursor, - info.type_name, - info.identifier - ) - - item_version_id = get_or_make_item_version( - cursor, - item_id, - repo_iteration_id, - info.version, - definition - ) - - if info_type is item_infos.MappingInfo: - make_mapping_status(cursor, item_id) - - file_ids_bytes = {} - - file_specifiers = [*info.source_copyright] - if isinstance(info, item_infos.ResourceInfo): - file_specifiers.extend(info.scripts) - - for file_spec in file_specifiers: - file_path = files_by_sha256_path / file_spec.sha256 - file_bytes = file_path.read_bytes() - - sha256 = hashlib.sha256(file_bytes).digest().hex() - assert sha256 == file_spec.sha256 - - file_id = get_or_make_file(cursor, sha256, file_bytes) - - file_ids_bytes[sha256] = (file_id, file_bytes) - - for idx, file_spec in enumerate(info.source_copyright): - file_id, file_bytes = file_ids_bytes[file_spec.sha256] - if file_bytes.isascii(): - mime = 'text/plain' - else: - mime = 'application/octet-stream' - - make_file_use( - cursor, - item_version_id = item_version_id, - file_id = file_id, - name = file_spec.name, - type = 'L', - mime_type = mime, - idx = idx - ) - - if isinstance(info, item_infos.MappingInfo): - continue - - for idx, file_spec in enumerate(info.scripts): - file_id, _ = file_ids_bytes[file_spec.sha256] - make_file_use( - cursor, - item_version_id = item_version_id, - file_id = file_id, - name = file_spec.name, - type = 'W', - mime_type = 'application/javascript', - idx = idx - ) + load_packages(self, cursor, malcontent_path) + self.recompute_payloads(cursor) def recompute_payloads(self, cursor: sqlite3.Cursor) -> None: assert self.connection.in_transaction @@ -700,7 +362,16 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): self._rebuild_structures(cursor) - def _rebuild_structures(self, cursor: sqlite3.Cursor) -> None: + def _rebuild_structures(self, cursor: t.Optional[sqlite3.Cursor] = None) \ + -> None: + """ + Recreation of data structures as done after every recomputation of + dependencies as well as at startup. + """ + if cursor is None: + with self.cursor() as new_cursor: + return self._rebuild_structures(new_cursor) + cursor.execute( ''' SELECT @@ -734,7 +405,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): enabled_status, identifier) = row - payload_ref = ConcretePayloadRef(str(payload_id_int)) + payload_ref = ConcretePayloadRef(str(payload_id_int), self) previous_data = self.payloads_data.get(payload_ref) if previous_data is not None: @@ -775,12 +446,11 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef: return ConcreteRepoIterationRef(repo_iteration_id) - def get_mapping(self, mapping_id: str) -> st.MappingRef: - return ConcreteMappingRef(mapping_id) + def mapping_store(self) -> st.MappingStore: + raise NotImplementedError() - def get_mapping_version(self, mapping_version_id: str) \ - -> st.MappingVersionRef: - return ConcreteMappingVersionRef(mapping_version_id) + def mapping_version_store(self) -> st.MappingVersionStore: + return mappings.ConcreteMappingVersionStore(self) def get_resource(self, resource_id: str) -> st.ResourceRef: return ConcreteResourceRef(resource_id) @@ -790,7 +460,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): return ConcreteResourceVersionRef(resource_version_id) def get_payload(self, payload_id: str) -> st.PayloadRef: - return 'not implemented' + raise NotImplementedError() def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \ -> st.RepoRef: |