diff options
Diffstat (limited to 'src/hydrilla/proxy/oversimplified_state_impl.py')
-rw-r--r-- | src/hydrilla/proxy/oversimplified_state_impl.py | 392 |
1 files changed, 392 insertions, 0 deletions
diff --git a/src/hydrilla/proxy/oversimplified_state_impl.py b/src/hydrilla/proxy/oversimplified_state_impl.py new file mode 100644 index 0000000..c082add --- /dev/null +++ b/src/hydrilla/proxy/oversimplified_state_impl.py @@ -0,0 +1,392 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Haketilo proxy data and configuration (instantiatable HaketiloState subtype). +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +This module contains logic for keeping track of all settings, rules, mappings +and resources. +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import secrets +import threading +import typing as t +import dataclasses as dc + +from pathlib import Path + +from ..pattern_tree import PatternTree +from .. import url_patterns +from .. import versions +from .. import item_infos +from .simple_dependency_satisfying import ItemsCollection, ComputedPayload +from . import state as st +from . import policies + + +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class ConcreteRepoRef(st.RepoRef): + def remove(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + def update( + self, + state: st.HaketiloState, + *, + name: t.Optional[str] = None, + url: t.Optional[str] = None + ) -> ConcreteRepoRef: + raise NotImplementedError() + + def refresh(self, state: st.HaketiloState) -> ConcreteRepoIterationRef: + raise NotImplementedError() + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteRepoIterationRef(st.RepoIterationRef): + pass + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteMappingRef(st.MappingRef): + def disable(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + def forget_enabled(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteMappingVersionRef(st.MappingVersionRef): + def enable(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteResourceRef(st.ResourceRef): + pass + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteResourceVersionRef(st.ResourceVersionRef): + pass + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcretePayloadRef(st.PayloadRef): + computed_payload: ComputedPayload = dc.field(hash=False, compare=False) + + def get_data(self, state: st.HaketiloState) -> st.PayloadData: + return t.cast(ConcreteHaketiloState, state).payloads_data[self.id] + + def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: + return 'to implement' + + def get_script_paths(self, state: st.HaketiloState) \ + -> t.Iterator[t.Sequence[str]]: + for resource_info in self.computed_payload.resources: + for file_spec in resource_info.scripts: + yield (resource_info.identifier, *file_spec.name.split('/')) + + def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ + -> t.Optional[st.FileData]: + if len(path) == 0: + raise st.MissingItemError() + + resource_identifier, *file_name_segments = path + + file_name = '/'.join(file_name_segments) + + script_sha256 = '' + + matched_resource_info = False + + for resource_info in self.computed_payload.resources: + if resource_info.identifier == resource_identifier: + matched_resource_info = True + + for script_spec in resource_info.scripts: + if script_spec.name == file_name: + script_sha256 = script_spec.sha256 + + break + + if not matched_resource_info: + raise st.MissingItemError(resource_identifier) + + if script_sha256 == '': + return None + + store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir + files_dir_path = store_dir_path / 'temporary_malcontent' / 'file' + file_path = files_dir_path / 'sha256' / script_sha256 + + return st.FileData( + type = 'application/javascript', + name = file_name, + contents = file_path.read_bytes() + ) + + +# @dc.dataclass(frozen=True, unsafe_hash=True) +# class DummyPayloadRef(ConcretePayloadRef): +# paths = { +# ('someresource', 'somefolder', 'somescript.js'): st.FileData( +# type = 'application/javascript', +# name = 'somefolder/somescript.js', +# contents = b'console.log("hello, mitmproxy")' +# ) +# } + +# def get_data(self, state: st.HaketiloState) -> st.PayloadData: +# parsed_pattern = next(url_patterns.parse_pattern('https://example.com')) + +# return st.PayloadData( +# payload_ref = self, +# mapping_installed = True, +# explicitly_enabled = True, +# unique_token = 'g54v45g456h4r', +# pattern = parsed_pattern, +# eval_allowed = True, +# cors_bypass_allowed = True +# ) + +# def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: +# return ConcreteMappingVersionRef('somemapping') + +# def get_file_paths(self, state: st.HaketiloState) \ +# -> t.Iterable[t.Sequence[str]]: +# return tuple(self.paths.keys()) + +# def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ +# -> t.Optional[st.FileData]: +# return self.paths[tuple(path)] + + +PolicyTree = PatternTree[policies.PolicyFactory] + +def register_payload( + policy_tree: PolicyTree, + payload_key: st.PayloadKey, + token: str +) -> PolicyTree: + """....""" + payload_policy_factory = policies.PayloadPolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register( + payload_key.pattern, + payload_policy_factory + ) + + resource_policy_factory = policies.PayloadResourcePolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register( + payload_key.pattern.path_append(token, '***'), + resource_policy_factory + ) + + return policy_tree + +DataById = t.Mapping[str, st.PayloadData] + +AnyInfo = t.TypeVar('AnyInfo', item_infos.ResourceInfo, item_infos.MappingInfo) + +@dc.dataclass +class ConcreteHaketiloState(st.HaketiloState): + store_dir: Path + # settings: state.HaketiloGlobalSettings + policy_tree: PolicyTree = PatternTree() + payloads_data: DataById = dc.field(default_factory=dict) + + lock: threading.RLock = dc.field(default_factory=threading.RLock) + + def __post_init__(self) -> None: + def newest_item_path(item_dir: Path) -> t.Optional[Path]: + available_versions = tuple( + versions.parse_normalize_version(ver_path.name) + for ver_path in item_dir.iterdir() + if ver_path.is_file() + ) + + if available_versions == (): + return None + + newest_version = max(available_versions) + + version_path = item_dir / versions.version_string(newest_version) + + assert version_path.is_file() + + return version_path + + def read_items(dir_path: Path, item_class: t.Type[AnyInfo]) \ + -> t.Mapping[str, AnyInfo]: + items: dict[str, AnyInfo] = {} + + for resource_dir in dir_path.iterdir(): + if not resource_dir.is_dir(): + continue + + item_path = newest_item_path(resource_dir) + if item_path is None: + continue + + item = item_class.load(item_path) + + assert versions.version_string(item.version) == item_path.name + assert item.identifier == resource_dir.name + + items[item.identifier] = item + + return items + + malcontent_dir = self.store_dir / 'temporary_malcontent' + + items_collection = ItemsCollection( + read_items(malcontent_dir / 'resource', item_infos.ResourceInfo), + read_items(malcontent_dir / 'mapping', item_infos.MappingInfo) + ) + computed_payloads = items_collection.compute_payloads() + + payloads_data = {} + + for mapping_info, by_pattern in computed_payloads.items(): + for num, (pattern, payload) in enumerate(by_pattern.items()): + payload_id = f'{num}@{mapping_info.identifier}' + + ref = ConcretePayloadRef(payload_id, payload) + + data = st.PayloadData( + payload_ref = ref, + mapping_installed = True, + explicitly_enabled = True, + unique_token = secrets.token_urlsafe(16), + pattern = pattern, + eval_allowed = payload.allows_eval, + cors_bypass_allowed = payload.allows_cors_bypass + ) + + payloads_data[payload_id] = data + + key = st.PayloadKey( + payload_ref = ref, + mapping_identifier = mapping_info.identifier, + pattern = pattern + ) + + self.policy_tree = register_payload( + self.policy_tree, + key, + data.unique_token + ) + + self.payloads_data = payloads_data + + def get_repo(self, repo_id: str) -> st.RepoRef: + return ConcreteRepoRef(repo_id) + + def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef: + return ConcreteRepoIterationRef(repo_iteration_id) + + def get_mapping(self, mapping_id: str) -> st.MappingRef: + return ConcreteMappingRef(mapping_id) + + def get_mapping_version(self, mapping_version_id: str) \ + -> st.MappingVersionRef: + return ConcreteMappingVersionRef(mapping_version_id) + + def get_resource(self, resource_id: str) -> st.ResourceRef: + return ConcreteResourceRef(resource_id) + + def get_resource_version(self, resource_version_id: str) \ + -> st.ResourceVersionRef: + return ConcreteResourceVersionRef(resource_version_id) + + def get_payload(self, payload_id: str) -> st.PayloadRef: + return 'not implemented' + + def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \ + -> st.RepoRef: + raise NotImplementedError() + + def get_settings(self) -> st.HaketiloGlobalSettings: + return st.HaketiloGlobalSettings( + mapping_use_mode = st.MappingUseMode.AUTO, + default_allow_scripts = True, + repo_refresh_seconds = 0 + ) + + def update_settings( + self, + *, + mapping_use_mode: t.Optional[st.MappingUseMode] = None, + default_allow_scripts: t.Optional[bool] = None, + repo_refresh_seconds: t.Optional[int] = None + ) -> None: + raise NotImplementedError() + + def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy: + """....""" + with self.lock: + policy_tree = self.policy_tree + + try: + best_priority: int = 0 + best_policy: t.Optional[policies.Policy] = None + + for factories_set in policy_tree.search(url): + for stored_factory in sorted(factories_set): + factory = stored_factory.item + + policy = factory.make_policy(self) + + if policy.priority > best_priority: + best_priority = policy.priority + best_policy = policy + except Exception as e: + return policies.ErrorBlockPolicy( + builtin = True, + error = e + ) + + if best_policy is not None: + return best_policy + + if self.get_settings().default_allow_scripts: + return policies.FallbackAllowPolicy() + else: + return policies.FallbackBlockPolicy() + + @staticmethod + def make(store_dir: Path) -> 'ConcreteHaketiloState': + return ConcreteHaketiloState(store_dir=store_dir) |