aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/state_impl/concrete_state.py
diff options
context:
space:
mode:
authorWojtek Kosior <koszko@koszko.org>2022-07-27 15:56:24 +0200
committerWojtek Kosior <koszko@koszko.org>2022-08-10 17:25:05 +0200
commit879c41927171efc8d77d1de2739b18e2eb57580f (patch)
treede0e78afe2ea49e58c9bf2c662657392a00139ee /src/hydrilla/proxy/state_impl/concrete_state.py
parent52d12a4fa124daa1595529e3e7008276a7986d95 (diff)
downloadhaketilo-hydrilla-879c41927171efc8d77d1de2739b18e2eb57580f.tar.gz
haketilo-hydrilla-879c41927171efc8d77d1de2739b18e2eb57580f.zip
unfinished partial work
Diffstat (limited to 'src/hydrilla/proxy/state_impl/concrete_state.py')
-rw-r--r--src/hydrilla/proxy/state_impl/concrete_state.py704
1 files changed, 704 insertions, 0 deletions
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py
new file mode 100644
index 0000000..cd6698c
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl/concrete_state.py
@@ -0,0 +1,704 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy data and configuration (instantiatable HaketiloState subtype).
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+This module contains logic for keeping track of all settings, rules, mappings
+and resources.
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import secrets
+import io
+import hashlib
+import typing as t
+import dataclasses as dc
+
+from pathlib import Path
+
+import sqlite3
+
+from ...exceptions import HaketiloException
+from ...translations import smart_gettext as _
+from ... import pattern_tree
+from ... import url_patterns
+from ... import versions
+from ... import item_infos
+from ..simple_dependency_satisfying import ItemsCollection, ComputedPayload
+from .. import state as st
+from .. import policies
+from . import base
+
+
+here = Path(__file__).resolve().parent
+
+AnyInfo = t.Union[item_infos.ResourceInfo, item_infos.MappingInfo]
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class ConcreteRepoRef(st.RepoRef):
+ def remove(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+ def update(
+ self,
+ state: st.HaketiloState,
+ *,
+ name: t.Optional[str] = None,
+ url: t.Optional[str] = None
+ ) -> ConcreteRepoRef:
+ raise NotImplementedError()
+
+ def refresh(self, state: st.HaketiloState) -> ConcreteRepoIterationRef:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteRepoIterationRef(st.RepoIterationRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteMappingRef(st.MappingRef):
+ def disable(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+ def forget_enabled(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteMappingVersionRef(st.MappingVersionRef):
+ def enable(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteResourceRef(st.ResourceRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteResourceVersionRef(st.ResourceVersionRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcretePayloadRef(st.PayloadRef):
+ computed_payload: ComputedPayload = dc.field(hash=False, compare=False)
+
+ def get_data(self, state: st.HaketiloState) -> st.PayloadData:
+ return t.cast(ConcreteHaketiloState, state).payloads_data[self.id]
+
+ def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
+ return 'to implement'
+
+ def get_script_paths(self, state: st.HaketiloState) \
+ -> t.Iterator[t.Sequence[str]]:
+ for resource_info in self.computed_payload.resources:
+ for file_spec in resource_info.scripts:
+ yield (resource_info.identifier, *file_spec.name.split('/'))
+
+ def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
+ -> t.Optional[st.FileData]:
+ if len(path) == 0:
+ raise st.MissingItemError()
+
+ resource_identifier, *file_name_segments = path
+
+ file_name = '/'.join(file_name_segments)
+
+ script_sha256 = ''
+
+ matched_resource_info = False
+
+ for resource_info in self.computed_payload.resources:
+ if resource_info.identifier == resource_identifier:
+ matched_resource_info = True
+
+ for script_spec in resource_info.scripts:
+ if script_spec.name == file_name:
+ script_sha256 = script_spec.sha256
+
+ break
+
+ if not matched_resource_info:
+ raise st.MissingItemError(resource_identifier)
+
+ if script_sha256 == '':
+ return None
+
+ store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir
+ files_dir_path = store_dir_path / 'temporary_malcontent' / 'file'
+ file_path = files_dir_path / 'sha256' / script_sha256
+
+ return st.FileData(
+ type = 'application/javascript',
+ name = file_name,
+ contents = file_path.read_bytes()
+ )
+
+PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
+
+def register_payload(
+ policy_tree: PolicyTree,
+ payload_key: st.PayloadKey,
+ token: str
+) -> PolicyTree:
+ """...."""
+ payload_policy_factory = policies.PayloadPolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(
+ payload_key.pattern,
+ payload_policy_factory
+ )
+
+ resource_policy_factory = policies.PayloadResourcePolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(
+ payload_key.pattern.path_append(token, '***'),
+ resource_policy_factory
+ )
+
+ return policy_tree
+
+DataById = t.Mapping[str, st.PayloadData]
+
+AnyInfoVar = t.TypeVar(
+ 'AnyInfoVar',
+ item_infos.ResourceInfo,
+ item_infos.MappingInfo
+)
+
+# def newest_item_path(item_dir: Path) -> t.Optional[Path]:
+# available_versions = tuple(
+# versions.parse_normalize_version(ver_path.name)
+# for ver_path in item_dir.iterdir()
+# if ver_path.is_file()
+# )
+
+# if available_versions == ():
+# return None
+
+# newest_version = max(available_versions)
+
+# version_path = item_dir / versions.version_string(newest_version)
+
+# assert version_path.is_file()
+
+# return version_path
+
+def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
+ -> t.Iterator[tuple[AnyInfoVar, str]]:
+ item_type_path = malcontent_path / item_class.type_name
+ if not item_type_path.is_dir():
+ return
+
+ for item_path in item_type_path.iterdir():
+ if not item_path.is_dir():
+ continue
+
+ for item_version_path in item_path.iterdir():
+ definition = item_version_path.read_text()
+ item_info = item_class.load(io.StringIO(definition))
+
+ assert item_info.identifier == item_path.name
+ assert versions.version_string(item_info.version) == \
+ item_version_path.name
+
+ yield item_info, definition
+
+def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int:
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO repos(name, url, deleted, next_iteration)
+ VALUES(?, '<dummy_url>', TRUE, 2);
+ ''',
+ (repo_name,)
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ repo_id, next_iteration - 1
+ FROM
+ repos
+ WHERE
+ name = ?;
+ ''',
+ (repo_name,)
+ )
+
+ (repo_id, last_iteration), = cursor.fetchall()
+
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO repo_iterations(repo_id, iteration)
+ VALUES(?, ?);
+ ''',
+ (repo_id, last_iteration)
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ repo_iteration_id
+ FROM
+ repo_iterations
+ WHERE
+ repo_id = ? AND iteration = ?;
+ ''',
+ (repo_id, last_iteration)
+ )
+
+ (repo_iteration_id,), = cursor.fetchall()
+
+ return repo_iteration_id
+
+def get_or_make_item(cursor: sqlite3.Cursor, type: str, identifier: str) -> int:
+ type_letter = {'resource': 'R', 'mapping': 'M'}[type]
+
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO items(type, identifier)
+ VALUES(?, ?);
+ ''',
+ (type_letter, identifier)
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ item_id
+ FROM
+ items
+ WHERE
+ type = ? AND identifier = ?;
+ ''',
+ (type_letter, identifier)
+ )
+
+ (item_id,), = cursor.fetchall()
+
+ return item_id
+
+def get_or_make_item_version(
+ cursor: sqlite3.Cursor,
+ item_id: int,
+ repo_iteration_id: int,
+ definition: str,
+ info: AnyInfo
+) -> int:
+ ver_str = versions.version_string(info.version)
+
+ values = (
+ item_id,
+ ver_str,
+ repo_iteration_id,
+ definition,
+ info.allows_eval,
+ info.allows_cors_bypass
+ )
+
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO item_versions(
+ item_id,
+ version,
+ repo_iteration_id,
+ definition,
+ eval_allowed,
+ cors_bypass_allowed
+ )
+ VALUES(?, ?, ?, ?, ?, ?);
+ ''',
+ values
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ item_version_id
+ FROM
+ item_versions
+ WHERE
+ item_id = ? AND version = ? AND repo_iteration_id = ?;
+ ''',
+ (item_id, ver_str, repo_iteration_id)
+ )
+
+ (item_version_id,), = cursor.fetchall()
+
+ return item_version_id
+
+def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None:
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO mapping_statuses(item_id, enabled)
+ VALUES(?, 'N');
+ ''',
+ (item_id,)
+ )
+
+def get_or_make_file(cursor: sqlite3.Cursor, sha256: str, file_bytes: bytes) \
+ -> int:
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO files(sha256, data)
+ VALUES(?, ?)
+ ''',
+ (sha256, file_bytes)
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ file_id
+ FROM
+ files
+ WHERE
+ sha256 = ?;
+ ''',
+ (sha256,)
+ )
+
+ (file_id,), = cursor.fetchall()
+
+ return file_id
+
+def make_file_use(
+ cursor: sqlite3.Cursor,
+ item_version_id: int,
+ file_id: int,
+ name: str,
+ type: str,
+ mime_type: str,
+ idx: int
+) -> None:
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO file_uses(
+ item_version_id,
+ file_id,
+ name,
+ type,
+ mime_type,
+ idx
+ )
+ VALUES(?, ?, ?, ?, ?, ?);
+ ''',
+ (item_version_id, file_id, name, type, mime_type, idx)
+ )
+
+@dc.dataclass
+class ConcreteHaketiloState(base.HaketiloStateWithFields):
+ def __post_init__(self) -> None:
+ self._prepare_database()
+
+ self._populate_database_with_stuff_from_temporary_malcontent_dir()
+
+ with self.cursor() as cursor:
+ self.rebuild_structures(cursor)
+
+ def _prepare_database(self) -> None:
+ """...."""
+ cursor = self.connection.cursor()
+
+ try:
+ cursor.execute(
+ '''
+ SELECT COUNT(name)
+ FROM sqlite_master
+ WHERE name = 'general' AND type = 'table';
+ '''
+ )
+
+ (db_initialized,), = cursor.fetchall()
+
+ if not db_initialized:
+ cursor.executescript((here.parent / 'tables.sql').read_text())
+
+ else:
+ cursor.execute(
+ '''
+ SELECT haketilo_version
+ FROM general;
+ '''
+ )
+
+ (db_haketilo_version,) = cursor.fetchone()
+ if db_haketilo_version != '3.0b1':
+ raise HaketiloException(_('err.unknown_db_schema'))
+
+ cursor.execute('PRAGMA FOREIGN_KEYS;')
+ if cursor.fetchall() == []:
+ raise HaketiloException(_('err.proxy.no_sqlite_foreign_keys'))
+
+ cursor.execute('PRAGMA FOREIGN_KEYS=ON;')
+ finally:
+ cursor.close()
+
+ def _populate_database_with_stuff_from_temporary_malcontent_dir(self) \
+ -> None:
+ malcontent_dir_path = self.store_dir / 'temporary_malcontent'
+ files_by_sha256_path = malcontent_dir_path / 'file' / 'sha256'
+
+ with self.cursor(lock=True, transaction=True) as cursor:
+ for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]:
+ info: AnyInfo
+ for info, definition in read_items(
+ malcontent_dir_path,
+ info_type # type: ignore
+ ):
+ repo_iteration_id = get_or_make_repo_iteration(
+ cursor,
+ info.repo
+ )
+
+ item_id = get_or_make_item(
+ cursor,
+ info.type_name,
+ info.identifier
+ )
+
+ item_version_id = get_or_make_item_version(
+ cursor,
+ item_id,
+ repo_iteration_id,
+ definition,
+ info
+ )
+
+ if info_type is item_infos.MappingInfo:
+ make_mapping_status(cursor, item_id)
+
+ file_ids_bytes = {}
+
+ file_specifiers = [*info.source_copyright]
+ if isinstance(info, item_infos.ResourceInfo):
+ file_specifiers.extend(info.scripts)
+
+ for file_spec in file_specifiers:
+ file_path = files_by_sha256_path / file_spec.sha256
+ file_bytes = file_path.read_bytes()
+
+ sha256 = hashlib.sha256(file_bytes).digest().hex()
+ assert sha256 == file_spec.sha256
+
+ file_id = get_or_make_file(cursor, sha256, file_bytes)
+
+ file_ids_bytes[sha256] = (file_id, file_bytes)
+
+ for idx, file_spec in enumerate(info.source_copyright):
+ file_id, file_bytes = file_ids_bytes[file_spec.sha256]
+ if file_bytes.isascii():
+ mime = 'text/plain'
+ else:
+ mime = 'application/octet-stream'
+
+ make_file_use(
+ cursor,
+ item_version_id = item_version_id,
+ file_id = file_id,
+ name = file_spec.name,
+ type = 'L',
+ mime_type = mime,
+ idx = idx
+ )
+
+ if isinstance(info, item_infos.MappingInfo):
+ continue
+
+ for idx, file_spec in enumerate(info.scripts):
+ file_id, _ = file_ids_bytes[file_spec.sha256]
+ make_file_use(
+ cursor,
+ item_version_id = item_version_id,
+ file_id = file_id,
+ name = file_spec.name,
+ type = 'W',
+ mime_type = 'application/javascript',
+ idx = idx
+ )
+
+ def rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
+ cursor.execute(
+ '''
+ SELECT
+ item_id, type, version, definition
+ FROM
+ item_versions JOIN items USING (item_id);
+ '''
+ )
+
+ best_versions: dict[int, versions.VerTuple] = {}
+ definitions = {}
+ types = {}
+
+ for item_id, item_type, ver_str, definition in cursor.fetchall():
+ # TODO: what we're doing in this loop does not yet take different
+ # repos and different repo iterations into account.
+ ver = versions.parse_normalize_version(ver_str)
+ if best_versions.get(item_id, (0,)) < ver:
+ best_versions[item_id] = ver
+ definitions[item_id] = definition
+ types[item_id] = item_type
+
+ resources = {}
+ mappings = {}
+
+ for item_id, definition in definitions.items():
+ if types[item_id] == 'R':
+ r_info = item_infos.ResourceInfo.load(io.StringIO(definition))
+ resources[r_info.identifier] = r_info
+ else:
+ m_info = item_infos.MappingInfo.load(io.StringIO(definition))
+ mappings[m_info.identifier] = m_info
+
+ items_collection = ItemsCollection(resources, mappings)
+ computed_payloads = items_collection.compute_payloads()
+
+ payloads_data = {}
+
+ for mapping_info, by_pattern in computed_payloads.items():
+ for num, (pattern, payload) in enumerate(by_pattern.items()):
+ payload_id = f'{num}@{mapping_info.identifier}'
+
+ ref = ConcretePayloadRef(payload_id, payload)
+
+ data = st.PayloadData(
+ payload_ref = ref,
+ mapping_installed = True,
+ explicitly_enabled = True,
+ unique_token = secrets.token_urlsafe(16),
+ pattern = pattern,
+ eval_allowed = payload.allows_eval,
+ cors_bypass_allowed = payload.allows_cors_bypass
+ )
+
+ payloads_data[payload_id] = data
+
+ key = st.PayloadKey(
+ payload_ref = ref,
+ mapping_identifier = mapping_info.identifier,
+ pattern = pattern
+ )
+
+ self.policy_tree = register_payload(
+ self.policy_tree,
+ key,
+ data.unique_token
+ )
+
+ self.payloads_data = payloads_data
+
+ def get_repo(self, repo_id: str) -> st.RepoRef:
+ return ConcreteRepoRef(repo_id)
+
+ def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef:
+ return ConcreteRepoIterationRef(repo_iteration_id)
+
+ def get_mapping(self, mapping_id: str) -> st.MappingRef:
+ return ConcreteMappingRef(mapping_id)
+
+ def get_mapping_version(self, mapping_version_id: str) \
+ -> st.MappingVersionRef:
+ return ConcreteMappingVersionRef(mapping_version_id)
+
+ def get_resource(self, resource_id: str) -> st.ResourceRef:
+ return ConcreteResourceRef(resource_id)
+
+ def get_resource_version(self, resource_version_id: str) \
+ -> st.ResourceVersionRef:
+ return ConcreteResourceVersionRef(resource_version_id)
+
+ def get_payload(self, payload_id: str) -> st.PayloadRef:
+ return 'not implemented'
+
+ def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \
+ -> st.RepoRef:
+ raise NotImplementedError()
+
+ def get_settings(self) -> st.HaketiloGlobalSettings:
+ return st.HaketiloGlobalSettings(
+ mapping_use_mode = st.MappingUseMode.AUTO,
+ default_allow_scripts = True,
+ repo_refresh_seconds = 0
+ )
+
+ def update_settings(
+ self,
+ *,
+ mapping_use_mode: t.Optional[st.MappingUseMode] = None,
+ default_allow_scripts: t.Optional[bool] = None,
+ repo_refresh_seconds: t.Optional[int] = None
+ ) -> None:
+ raise NotImplementedError()
+
+ def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy:
+ """...."""
+ with self.lock:
+ policy_tree = self.policy_tree
+
+ try:
+ best_priority: int = 0
+ best_policy: t.Optional[policies.Policy] = None
+
+ for factories_set in policy_tree.search(url):
+ for stored_factory in sorted(factories_set):
+ factory = stored_factory.item
+
+ policy = factory.make_policy(self)
+
+ if policy.priority > best_priority:
+ best_priority = policy.priority
+ best_policy = policy
+ except Exception as e:
+ return policies.ErrorBlockPolicy(
+ builtin = True,
+ error = e
+ )
+
+ if best_policy is not None:
+ return best_policy
+
+ if self.get_settings().default_allow_scripts:
+ return policies.FallbackAllowPolicy()
+ else:
+ return policies.FallbackBlockPolicy()
+
+ @staticmethod
+ def make(store_dir: Path) -> 'ConcreteHaketiloState':
+ return ConcreteHaketiloState(
+ store_dir = store_dir,
+ connection = sqlite3.connect(str(store_dir / 'sqlite3.db'))
+ )