From 879c41927171efc8d77d1de2739b18e2eb57580f Mon Sep 17 00:00:00 2001
From: Wojtek Kosior
Date: Wed, 27 Jul 2022 15:56:24 +0200
Subject: unfinished partial work

---
 src/hydrilla/proxy/state_impl/concrete_state.py | 704 ++++++++++++++++++++++++
 1 file changed, 704 insertions(+)
 create mode 100644 src/hydrilla/proxy/state_impl/concrete_state.py

diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py
new file mode 100644
index 0000000..cd6698c
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl/concrete_state.py
@@ -0,0 +1,704 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy data and configuration (instantiatable HaketiloState subtype).
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use of this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+This module contains logic for keeping track of all settings, rules, mappings
+and resources.
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import secrets
+import io
+import hashlib
+import typing as t
+import dataclasses as dc
+
+from pathlib import Path
+
+import sqlite3
+
+from ...exceptions import HaketiloException
+from ...translations import smart_gettext as _
+from ... import pattern_tree
+from ... import url_patterns
+from ... import versions
+from ... import item_infos
+from ..simple_dependency_satisfying import ItemsCollection, ComputedPayload
+from .. import state as st
+from .. import policies
+from . import base
+
+
+here = Path(__file__).resolve().parent
+
+AnyInfo = t.Union[item_infos.ResourceInfo, item_infos.MappingInfo]
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class ConcreteRepoRef(st.RepoRef):
+    def remove(self, state: st.HaketiloState) -> None:
+        raise NotImplementedError()
+
+    def update(
+            self,
+            state: st.HaketiloState,
+            *,
+            name: t.Optional[str] = None,
+            url: t.Optional[str] = None
+    ) -> ConcreteRepoRef:
+        raise NotImplementedError()
+
+    def refresh(self, state: st.HaketiloState) -> ConcreteRepoIterationRef:
+        raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteRepoIterationRef(st.RepoIterationRef):
+    pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteMappingRef(st.MappingRef):
+    def disable(self, state: st.HaketiloState) -> None:
+        raise NotImplementedError()
+
+    def forget_enabled(self, state: st.HaketiloState) -> None:
+        raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteMappingVersionRef(st.MappingVersionRef):
+    def enable(self, state: st.HaketiloState) -> None:
+        raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteResourceRef(st.ResourceRef):
+    pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteResourceVersionRef(st.ResourceVersionRef):
+    pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcretePayloadRef(st.PayloadRef):
+    computed_payload: ComputedPayload = dc.field(hash=False, compare=False)
+
+    def get_data(self, state: st.HaketiloState) -> st.PayloadData:
+        return t.cast(ConcreteHaketiloState, state).payloads_data[self.id]
+
+    def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
+        return 'to implement'
+
+    def get_script_paths(self, state: st.HaketiloState) \
+        -> t.Iterator[t.Sequence[str]]:
+        for resource_info in self.computed_payload.resources:
+            for file_spec in resource_info.scripts:
+                yield (resource_info.identifier, *file_spec.name.split('/'))
+
+    def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
+        -> t.Optional[st.FileData]:
+        if len(path) == 0:
+            raise st.MissingItemError()
+
+        resource_identifier, *file_name_segments = path
+
+        file_name = '/'.join(file_name_segments)
+
+        script_sha256 = ''
+
+        matched_resource_info = False
+
+        for resource_info in self.computed_payload.resources:
+            if resource_info.identifier == resource_identifier:
+                matched_resource_info = True
+
+                for script_spec in resource_info.scripts:
+                    if script_spec.name == file_name:
+                        script_sha256 = script_spec.sha256
+
+                break
+
+        if not matched_resource_info:
+            raise st.MissingItemError(resource_identifier)
+
+        if script_sha256 == '':
+            return None
+
+        store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir
+        files_dir_path = store_dir_path / 'temporary_malcontent' / 'file'
+        file_path = files_dir_path / 'sha256' / script_sha256
+
+        return st.FileData(
+            type = 'application/javascript',
+            name = file_name,
+            contents = file_path.read_bytes()
+        )
+
+
+PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
+
+def register_payload(
+        policy_tree: PolicyTree,
+        payload_key: st.PayloadKey,
+        token: str
+) -> PolicyTree:
+    """...."""
+    payload_policy_factory = policies.PayloadPolicyFactory(
+        builtin = False,
+        payload_key = payload_key
+    )
+
+    policy_tree = policy_tree.register(
+        payload_key.pattern,
+        payload_policy_factory
+    )
+
+    resource_policy_factory = policies.PayloadResourcePolicyFactory(
+        builtin = False,
+        payload_key = payload_key
+    )
+
+    policy_tree = policy_tree.register(
+        payload_key.pattern.path_append(token, '***'),
+        resource_policy_factory
+    )
+
+    return policy_tree
+
+
+DataById = t.Mapping[str, st.PayloadData]
+
+AnyInfoVar = t.TypeVar(
+    'AnyInfoVar',
+    item_infos.ResourceInfo,
+    item_infos.MappingInfo
+)
+
+# def newest_item_path(item_dir: Path) -> t.Optional[Path]:
+#     available_versions = tuple(
+#         versions.parse_normalize_version(ver_path.name)
+#         for ver_path in item_dir.iterdir()
+#         if ver_path.is_file()
+#     )
+
+#     if available_versions == ():
+#         return None
+
+#     newest_version = max(available_versions)
+
+#     version_path = item_dir / versions.version_string(newest_version)
+
+#     assert version_path.is_file()
+
+#     return version_path
+
+def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
+    -> t.Iterator[tuple[AnyInfoVar, str]]:
+    item_type_path = malcontent_path / item_class.type_name
+    if not item_type_path.is_dir():
+        return
+
+    for item_path in item_type_path.iterdir():
+        if not item_path.is_dir():
+            continue
+
+        for item_version_path in item_path.iterdir():
+            definition = item_version_path.read_text()
+            item_info = item_class.load(io.StringIO(definition))
+
+            assert item_info.identifier == item_path.name
+            assert versions.version_string(item_info.version) == \
+                item_version_path.name
+
+            yield item_info, definition
+
+def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int:
+    cursor.execute(
+        '''
+        INSERT OR IGNORE INTO repos(name, url, deleted, next_iteration)
+        VALUES(?, '', TRUE, 2);
+        ''',
+        (repo_name,)
+    )
+
+    cursor.execute(
+        '''
+        SELECT
+            repo_id, next_iteration - 1
+        FROM
+            repos
+        WHERE
+            name = ?;
+        ''',
+        (repo_name,)
+    )
+
+    (repo_id, last_iteration), = cursor.fetchall()
+
+    cursor.execute(
+        '''
+        INSERT OR IGNORE INTO repo_iterations(repo_id, iteration)
+        VALUES(?, ?);
+        ''',
+        (repo_id, last_iteration)
+    )
+
+    cursor.execute(
+        '''
+        SELECT
+            repo_iteration_id
+        FROM
+            repo_iterations
+        WHERE
+            repo_id = ? AND iteration = ?;
+        ''',
+        (repo_id, last_iteration)
+    )
+
+    (repo_iteration_id,), = cursor.fetchall()
+
+    return repo_iteration_id
+
+def get_or_make_item(cursor: sqlite3.Cursor, type: str, identifier: str) -> int:
+    type_letter = {'resource': 'R', 'mapping': 'M'}[type]
+
+    cursor.execute(
+        '''
+        INSERT OR IGNORE INTO items(type, identifier)
+        VALUES(?, ?);
+        ''',
+        (type_letter, identifier)
+    )
+
+    cursor.execute(
+        '''
+        SELECT
+            item_id
+        FROM
+            items
+        WHERE
+            type = ? AND identifier = ?;
+        ''',
+        (type_letter, identifier)
+    )
+
+    (item_id,), = cursor.fetchall()
+
+    return item_id
+
+def get_or_make_item_version(
+        cursor: sqlite3.Cursor,
+        item_id: int,
+        repo_iteration_id: int,
+        definition: str,
+        info: AnyInfo
+) -> int:
+    ver_str = versions.version_string(info.version)
+
+    values = (
+        item_id,
+        ver_str,
+        repo_iteration_id,
+        definition,
+        info.allows_eval,
+        info.allows_cors_bypass
+    )
+
+    cursor.execute(
+        '''
+        INSERT OR IGNORE INTO item_versions(
+            item_id,
+            version,
+            repo_iteration_id,
+            definition,
+            eval_allowed,
+            cors_bypass_allowed
+        )
+        VALUES(?, ?, ?, ?, ?, ?);
+        ''',
+        values
+    )
+
+    cursor.execute(
+        '''
+        SELECT
+            item_version_id
+        FROM
+            item_versions
+        WHERE
+            item_id = ? AND version = ? AND repo_iteration_id = ?;
+        ''',
+        (item_id, ver_str, repo_iteration_id)
+    )
+
+    (item_version_id,), = cursor.fetchall()
+
+    return item_version_id
+
+def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None:
+    cursor.execute(
+        '''
+        INSERT OR IGNORE INTO mapping_statuses(item_id, enabled)
+        VALUES(?, 'N');
+        ''',
+        (item_id,)
+    )
+
+def get_or_make_file(cursor: sqlite3.Cursor, sha256: str, file_bytes: bytes) \
+    -> int:
+    cursor.execute(
+        '''
+        INSERT OR IGNORE INTO files(sha256, data)
+        VALUES(?, ?)
+        ''',
+        (sha256, file_bytes)
+    )
+
+    cursor.execute(
+        '''
+        SELECT
+            file_id
+        FROM
+            files
+        WHERE
+            sha256 = ?;
+        ''',
+        (sha256,)
+    )
+
+    (file_id,), = cursor.fetchall()
+
+    return file_id
+
+def make_file_use(
+        cursor: sqlite3.Cursor,
+        item_version_id: int,
+        file_id: int,
+        name: str,
+        type: str,
+        mime_type: str,
+        idx: int
+) -> None:
+    cursor.execute(
+        '''
+        INSERT OR IGNORE INTO file_uses(
+            item_version_id,
+            file_id,
+            name,
+            type,
+            mime_type,
+            idx
+        )
+        VALUES(?, ?, ?, ?, ?, ?);
+        ''',
+        (item_version_id, file_id, name, type, mime_type, idx)
+    )
+
+
+@dc.dataclass
+class ConcreteHaketiloState(base.HaketiloStateWithFields):
+    def __post_init__(self) -> None:
+        self._prepare_database()
+
+        self._populate_database_with_stuff_from_temporary_malcontent_dir()
+
+        with self.cursor() as cursor:
+            self.rebuild_structures(cursor)
+
+    def _prepare_database(self) -> None:
+        """...."""
+        cursor = self.connection.cursor()
+
+        try:
+            cursor.execute(
+                '''
+                SELECT COUNT(name)
+                FROM sqlite_master
+                WHERE name = 'general' AND type = 'table';
+                '''
+            )
+
+            (db_initialized,), = cursor.fetchall()
+
+            if not db_initialized:
+                cursor.executescript((here.parent / 'tables.sql').read_text())
+
+            else:
+                cursor.execute(
+                    '''
+                    SELECT haketilo_version
+                    FROM general;
+                    '''
+                )
+
+                (db_haketilo_version,) = cursor.fetchone()
+                if db_haketilo_version != '3.0b1':
+                    raise HaketiloException(_('err.unknown_db_schema'))
+
+            cursor.execute('PRAGMA FOREIGN_KEYS;')
+            if cursor.fetchall() == []:
+                raise HaketiloException(_('err.proxy.no_sqlite_foreign_keys'))
+
+            cursor.execute('PRAGMA FOREIGN_KEYS=ON;')
+        finally:
+            cursor.close()
+
+    def _populate_database_with_stuff_from_temporary_malcontent_dir(self) \
+        -> None:
+        malcontent_dir_path = self.store_dir / 'temporary_malcontent'
+        files_by_sha256_path = malcontent_dir_path / 'file' / 'sha256'
+
+        with self.cursor(lock=True, transaction=True) as cursor:
+            for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]:
+                info: AnyInfo
+                for info, definition in read_items(
+                        malcontent_dir_path,
+                        info_type # type: ignore
+                ):
+                    repo_iteration_id = get_or_make_repo_iteration(
+                        cursor,
+                        info.repo
+                    )
+
+                    item_id = get_or_make_item(
+                        cursor,
+                        info.type_name,
+                        info.identifier
+                    )
+
+                    item_version_id = get_or_make_item_version(
+                        cursor,
+                        item_id,
+                        repo_iteration_id,
+                        definition,
+                        info
+                    )
+
+                    if info_type is item_infos.MappingInfo:
+                        make_mapping_status(cursor, item_id)
+
+                    file_ids_bytes = {}
+
+                    file_specifiers = [*info.source_copyright]
+                    if isinstance(info, item_infos.ResourceInfo):
+                        file_specifiers.extend(info.scripts)
+
+                    for file_spec in file_specifiers:
+                        file_path = files_by_sha256_path / file_spec.sha256
+                        file_bytes = file_path.read_bytes()
+
+                        sha256 = hashlib.sha256(file_bytes).digest().hex()
+                        assert sha256 == file_spec.sha256
+
+                        file_id = get_or_make_file(cursor, sha256, file_bytes)
+
+                        file_ids_bytes[sha256] = (file_id, file_bytes)
+
+                    for idx, file_spec in enumerate(info.source_copyright):
+                        file_id, file_bytes = file_ids_bytes[file_spec.sha256]
+                        if file_bytes.isascii():
+                            mime = 'text/plain'
+                        else:
+                            mime = 'application/octet-stream'
+
+                        make_file_use(
+                            cursor,
+                            item_version_id = item_version_id,
+                            file_id = file_id,
+                            name = file_spec.name,
+                            type = 'L',
+                            mime_type = mime,
+                            idx = idx
+                        )
+
+                    if isinstance(info, item_infos.MappingInfo):
+                        continue
+
+                    for idx, file_spec in enumerate(info.scripts):
+                        file_id, _ = file_ids_bytes[file_spec.sha256]
+                        make_file_use(
+                            cursor,
+                            item_version_id = item_version_id,
+                            file_id = file_id,
+                            name = file_spec.name,
+                            type = 'W',
+                            mime_type = 'application/javascript',
+                            idx = idx
+                        )
+
+    def rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
+        cursor.execute(
+            '''
+            SELECT
+                item_id, type, version, definition
+            FROM
+                item_versions JOIN items USING (item_id);
+            '''
+        )
+
+        best_versions: dict[int, versions.VerTuple] = {}
+        definitions = {}
+        types = {}
+
+        for item_id, item_type, ver_str, definition in cursor.fetchall():
+            # TODO: what we're doing in this loop does not yet take different
+            # repos and different repo iterations into account.
+            ver = versions.parse_normalize_version(ver_str)
+            if best_versions.get(item_id, (0,)) < ver:
+                best_versions[item_id] = ver
+                definitions[item_id] = definition
+                types[item_id] = item_type
+
+        resources = {}
+        mappings = {}
+
+        for item_id, definition in definitions.items():
+            if types[item_id] == 'R':
+                r_info = item_infos.ResourceInfo.load(io.StringIO(definition))
+                resources[r_info.identifier] = r_info
+            else:
+                m_info = item_infos.MappingInfo.load(io.StringIO(definition))
+                mappings[m_info.identifier] = m_info
+
+        items_collection = ItemsCollection(resources, mappings)
+        computed_payloads = items_collection.compute_payloads()
+
+        payloads_data = {}
+
+        for mapping_info, by_pattern in computed_payloads.items():
+            for num, (pattern, payload) in enumerate(by_pattern.items()):
+                payload_id = f'{num}@{mapping_info.identifier}'
+
+                ref = ConcretePayloadRef(payload_id, payload)
+
+                data = st.PayloadData(
+                    payload_ref = ref,
+                    mapping_installed = True,
+                    explicitly_enabled = True,
+                    unique_token = secrets.token_urlsafe(16),
+                    pattern = pattern,
+                    eval_allowed = payload.allows_eval,
+                    cors_bypass_allowed = payload.allows_cors_bypass
+                )
+
+                payloads_data[payload_id] = data
+
+                key = st.PayloadKey(
+                    payload_ref = ref,
+                    mapping_identifier = mapping_info.identifier,
+                    pattern = pattern
+                )
+
+                self.policy_tree = register_payload(
+                    self.policy_tree,
+                    key,
+                    data.unique_token
+                )
+
+        self.payloads_data = payloads_data
+
+    def get_repo(self, repo_id: str) -> st.RepoRef:
+        return ConcreteRepoRef(repo_id)
+
+    def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef:
+        return ConcreteRepoIterationRef(repo_iteration_id)
+
+    def get_mapping(self, mapping_id: str) -> st.MappingRef:
+        return ConcreteMappingRef(mapping_id)
+
+    def get_mapping_version(self, mapping_version_id: str) \
+        -> st.MappingVersionRef:
+        return ConcreteMappingVersionRef(mapping_version_id)
+
+    def get_resource(self, resource_id: str) -> st.ResourceRef:
+        return ConcreteResourceRef(resource_id)
+
+    def get_resource_version(self, resource_version_id: str) \
+        -> st.ResourceVersionRef:
+        return ConcreteResourceVersionRef(resource_version_id)
+
+    def get_payload(self, payload_id: str) -> st.PayloadRef:
+        return 'not implemented'
+
+    def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \
+        -> st.RepoRef:
+        raise NotImplementedError()
+
+    def get_settings(self) -> st.HaketiloGlobalSettings:
+        return st.HaketiloGlobalSettings(
+            mapping_use_mode = st.MappingUseMode.AUTO,
+            default_allow_scripts = True,
+            repo_refresh_seconds = 0
+        )
+
+    def update_settings(
+            self,
+            *,
+            mapping_use_mode: t.Optional[st.MappingUseMode] = None,
+            default_allow_scripts: t.Optional[bool] = None,
+            repo_refresh_seconds: t.Optional[int] = None
+    ) -> None:
+        raise NotImplementedError()
+
+    def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy:
+        """...."""
+        with self.lock:
+            policy_tree = self.policy_tree
+
+        try:
+            best_priority: int = 0
+            best_policy: t.Optional[policies.Policy] = None
+
+            for factories_set in policy_tree.search(url):
+                for stored_factory in sorted(factories_set):
+                    factory = stored_factory.item
+
+                    policy = factory.make_policy(self)
+
+                    if policy.priority > best_priority:
+                        best_priority = policy.priority
+                        best_policy = policy
+        except Exception as e:
+            return policies.ErrorBlockPolicy(
+                builtin = True,
+                error = e
+            )
+
+        if best_policy is not None:
+            return best_policy
+
+        if self.get_settings().default_allow_scripts:
+            return policies.FallbackAllowPolicy()
+        else:
+            return policies.FallbackBlockPolicy()
+
+    @staticmethod
+    def make(store_dir: Path) -> 'ConcreteHaketiloState':
+        return ConcreteHaketiloState(
+            store_dir = store_dir,
+            connection = sqlite3.connect(str(store_dir / 'sqlite3.db'))
+        )
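
A minimal usage sketch of the new class (not part of the patch; the store_dir
path is made up and url_patterns.parse_url() is an assumption about the
parsing helper exposed by hydrilla.url_patterns):

    from pathlib import Path

    from hydrilla import url_patterns
    from hydrilla.proxy.state_impl.concrete_state import ConcreteHaketiloState

    # make() opens (or creates) sqlite3.db under store_dir; __post_init__()
    # then loads tables.sql if the schema is missing, imports items found
    # under store_dir/'temporary_malcontent' and builds the policy tree.
    state = ConcreteHaketiloState.make(Path('/tmp/haketilo_store'))

    # select_policy() walks the pattern tree and falls back to an allow or
    # block policy according to get_settings().default_allow_scripts.
    url = url_patterns.parse_url('https://example.com/')  # assumed helper
    policy = state.select_policy(url)
    print(type(policy).__name__)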