diff options
Diffstat (limited to 'src/hydrilla/proxy/state_impl/base.py')
-rw-r--r-- | src/hydrilla/proxy/state_impl/base.py | 280 |
1 files changed, 280 insertions, 0 deletions
diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py new file mode 100644 index 0000000..f8291d8 --- /dev/null +++ b/src/hydrilla/proxy/state_impl/base.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Haketilo proxy data and configuration (definition of fields of a class that +# will implement HaketiloState). +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use of this +# code in a proprietary program, I am not going to enforce this in +# court. + +""" +This module defines fields that will later be part of a concrete HaketiloState +subtype. +""" + +import sqlite3 +import threading +import secrets +import webbrowser +import dataclasses as dc +import typing as t + +from pathlib import Path +from contextlib import contextmanager +from abc import abstractmethod + +from ... import url_patterns +from ... import pattern_tree +from .. import simple_dependency_satisfying as sds +from .. import state as st +from .. import policies + + +@contextmanager +def temporary_ids_tables( + cursor: sqlite3.Cursor, + tables: t.Iterable[tuple[str, t.Iterable[int]]] +) -> t.Iterator[None]: + created: set[str] = set() + + try: + for name, ids in tables: + cursor.execute( + f'CREATE TEMPORARY TABLE "{name}"(id INTEGER PRIMARY KEY);' + ) + created.add(name) + + for id in ids: + cursor.execute(f'INSERT INTO "{name}" VALUES(?);', (id,)) + + yield + finally: + for name in created: + cursor.execute(f'DROP TABLE "{name}";') + + +@dc.dataclass(frozen=True) +class PolicyTree(pattern_tree.PatternTree[policies.PolicyFactory]): + SelfType = t.TypeVar('SelfType', bound='PolicyTree') + + def register_payload( + self: 'SelfType', + pattern: url_patterns.ParsedPattern, + payload_key: st.PayloadKey, + token: str + ) -> 'SelfType': + payload_policy_factory = policies.PayloadPolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = self.register(pattern, payload_policy_factory) + + resource_policy_factory = policies.PayloadResourcePolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register( + pattern.path_append(token, '***'), + resource_policy_factory + ) + + return policy_tree + +def mark_failed_file_installs( + cursor: sqlite3.Cursor, + file_sha256: str, + repo_id: int +) -> None: + cursor.execute( + ''' + WITH failed_items AS ( + SELECT DISTINCT + item_version_id + FROM + files AS f + JOIN file_uses AS fu USING (file_id) + JOIN item_versions_extra AS ive USING (item_version_id) + WHERE + f.sha256 = ? AND f.data IS NULL AND ive.repo_id = ? + ) + UPDATE + item_versions + SET + installed = 'F' + WHERE + item_version_id IN failed_items; + ''', + (file_sha256, repo_id) + ) + +NoLockArg = t.Union[t.Sequence[int], t.Literal['all_mappings_unlocked']] + +PayloadsData = t.Mapping[st.PayloadRef, st.PayloadData] + +# mypy needs to be corrected: +# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 +@dc.dataclass # type: ignore[misc] +class HaketiloStateWithFields(st.HaketiloState): + """....""" + store_dir: Path + _listen_host: str + _listen_port: int + _logger: st.Logger + connection: sqlite3.Connection + settings: st.HaketiloGlobalSettings + current_cursor: t.Optional[sqlite3.Cursor] = None + + secret: bytes = dc.field(default_factory=(lambda: secrets.token_bytes(16))) + + policy_tree: PolicyTree = PolicyTree() + payloads_data: PayloadsData = dc.field(default_factory=dict) + + lock: threading.RLock = dc.field(default_factory=threading.RLock) + + @contextmanager + def cursor(self, transaction: bool = False) \ + -> t.Iterator[sqlite3.Cursor]: + with self.lock: + start_transaction = \ + transaction and not self.connection.in_transaction + + if self.current_cursor is not None: + yield self.current_cursor + return + + try: + self.current_cursor = self.connection.cursor() + + if start_transaction: + self.current_cursor.execute('BEGIN TRANSACTION;') + + try: + yield self.current_cursor + + if start_transaction: + assert self.connection.in_transaction + self.current_cursor.execute('COMMIT TRANSACTION;') + except: + if start_transaction: + self.current_cursor.execute('ROLLBACK TRANSACTION;') + raise + except st.FileInstallationError as ex: + if start_transaction: + assert self.current_cursor is not None + mark_failed_file_installs( + cursor = self.current_cursor, + file_sha256 = ex.sha256, + repo_id = int(ex.repo_id) + ) + raise + finally: + self.current_cursor = None + + def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy: + """....""" + with self.lock: + policy_tree = self.policy_tree + + try: + best_priority: int = 0 + best_policy: t.Optional[policies.Policy] = None + + for factories_set in policy_tree.search(url): + for stored_factory in sorted(factories_set): + factory = stored_factory.item + + policy = factory.make_policy(self) + + if policy.priority > best_priority: + best_priority = policy.priority + best_policy = policy + except Exception as e: + return policies.ErrorBlockPolicy(self.settings, error=e) + + if best_policy is not None: + return best_policy + + if self.settings.default_allow_scripts: + return policies.FallbackAllowPolicy(self.settings) + else: + return policies.FallbackBlockPolicy(self.settings) + + @abstractmethod + def import_items(self, malcontent_path: Path, repo_id: int = 1) -> None: + ... + + @abstractmethod + def soft_prune_orphan_items(self) -> None: + ... + + @abstractmethod + def recompute_dependencies( + self, + unlocked_required_mappings: NoLockArg = [] + ) -> None: + ... + + @abstractmethod + def pull_missing_files(self) -> None: + """ + This function checks which packages marked as installed are missing + files in the database. It attempts to restore integrity by downloading + the files from their respective repositories. + """ + ... + + @abstractmethod + def rebuild_structures(self, *, payloads: bool = True, rules: bool = True) \ + -> None: + """ + Recreation of data structures as done after every recomputation of + dependencies as well as at startup. + """ + ... + + @property + def listen_host(self) -> str: + if self._listen_host != '0.0.0.0': + return '127.0.0.1' + + return self._listen_host + + @property + def listen_port(self) -> int: + return self._listen_port + + @property + def efective_listen_addr(self) -> str: + effective_host = self._listen_host + if self._listen_host == '0.0.0.0': + effective_host = '127.0.0.1' + + return f'http://{effective_host}:{self._listen_port}' + + def launch_browser(self) -> bool: + return webbrowser.open(self.efective_listen_addr) + + @property + def logger(self) -> st.Logger: + return self._logger |