Diffstat (limited to 'src/hydrilla/proxy/oversimplified_state_impl.py')
-rw-r--r--  src/hydrilla/proxy/oversimplified_state_impl.py  392
1 file changed, 392 insertions, 0 deletions
diff --git a/src/hydrilla/proxy/oversimplified_state_impl.py b/src/hydrilla/proxy/oversimplified_state_impl.py
new file mode 100644
index 0000000..c082add
--- /dev/null
+++ b/src/hydrilla/proxy/oversimplified_state_impl.py
@@ -0,0 +1,392 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy data and configuration (instantiatable HaketiloState subtype).
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use of this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+This module contains logic for keeping track of all settings, rules, mappings
+and resources.
+"""
+
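+# A rough usage sketch (the parse_url() helper and the example paths are
+# assumptions drawn from how this module reads its data, not guarantees made
+# by this commit):
+#
+#     state = ConcreteHaketiloState.make(Path('/path/to/store_dir'))
+#     policy = state.select_policy(url_patterns.parse_url('https://example.com/'))
+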
+# Enable using this module with Python 3.7.
+from __future__ import annotations
+
+import secrets
+import threading
+import typing as t
+import dataclasses as dc
+
+from pathlib import Path
+
+from ..pattern_tree import PatternTree
+from .. import url_patterns
+from .. import versions
+from .. import item_infos
+from .simple_dependency_satisfying import ItemsCollection, ComputedPayload
+from . import state as st
+from . import policies
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class ConcreteRepoRef(st.RepoRef):
+ def remove(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+ def update(
+ self,
+ state: st.HaketiloState,
+ *,
+ name: t.Optional[str] = None,
+ url: t.Optional[str] = None
+ ) -> ConcreteRepoRef:
+ raise NotImplementedError()
+
+ def refresh(self, state: st.HaketiloState) -> ConcreteRepoIterationRef:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteRepoIterationRef(st.RepoIterationRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteMappingRef(st.MappingRef):
+ def disable(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+ def forget_enabled(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteMappingVersionRef(st.MappingVersionRef):
+ def enable(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteResourceRef(st.ResourceRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteResourceVersionRef(st.ResourceVersionRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcretePayloadRef(st.PayloadRef):
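+    # The resolved payload contents; excluded from hashing and comparison so
+    # that only the reference's id takes part in them.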
+ computed_payload: ComputedPayload = dc.field(hash=False, compare=False)
+
+ def get_data(self, state: st.HaketiloState) -> st.PayloadData:
+ return t.cast(ConcreteHaketiloState, state).payloads_data[self.id]
+
+ def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
+        raise NotImplementedError()
+
+ def get_script_paths(self, state: st.HaketiloState) \
+ -> t.Iterator[t.Sequence[str]]:
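+        # Yield each script's path as a sequence of the form
+        # (resource identifier, *script name segments).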
+ for resource_info in self.computed_payload.resources:
+ for file_spec in resource_info.scripts:
+ yield (resource_info.identifier, *file_spec.name.split('/'))
+
+ def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
+ -> t.Optional[st.FileData]:
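+        # 'path' is expected to have the same form as the sequences yielded by
+        # get_script_paths(): (resource identifier, *script name segments).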
+ if len(path) == 0:
+ raise st.MissingItemError()
+
+ resource_identifier, *file_name_segments = path
+
+ file_name = '/'.join(file_name_segments)
+
+ script_sha256 = ''
+
+ matched_resource_info = False
+
+ for resource_info in self.computed_payload.resources:
+ if resource_info.identifier == resource_identifier:
+ matched_resource_info = True
+
+ for script_spec in resource_info.scripts:
+ if script_spec.name == file_name:
+ script_sha256 = script_spec.sha256
+
+ break
+
+ if not matched_resource_info:
+ raise st.MissingItemError(resource_identifier)
+
+ if script_sha256 == '':
+ return None
+
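+        # Script files are looked up by their SHA256 hashes under the
+        # temporary malcontent directory inside the store.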
+ store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir
+ files_dir_path = store_dir_path / 'temporary_malcontent' / 'file'
+ file_path = files_dir_path / 'sha256' / script_sha256
+
+ return st.FileData(
+ type = 'application/javascript',
+ name = file_name,
+ contents = file_path.read_bytes()
+ )
+
+
+# @dc.dataclass(frozen=True, unsafe_hash=True)
+# class DummyPayloadRef(ConcretePayloadRef):
+# paths = {
+# ('someresource', 'somefolder', 'somescript.js'): st.FileData(
+# type = 'application/javascript',
+# name = 'somefolder/somescript.js',
+# contents = b'console.log("hello, mitmproxy")'
+# )
+# }
+
+# def get_data(self, state: st.HaketiloState) -> st.PayloadData:
+# parsed_pattern = next(url_patterns.parse_pattern('https://example.com'))
+
+# return st.PayloadData(
+# payload_ref = self,
+# mapping_installed = True,
+# explicitly_enabled = True,
+# unique_token = 'g54v45g456h4r',
+# pattern = parsed_pattern,
+# eval_allowed = True,
+# cors_bypass_allowed = True
+# )
+
+# def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
+# return ConcreteMappingVersionRef('somemapping')
+
+# def get_file_paths(self, state: st.HaketiloState) \
+# -> t.Iterable[t.Sequence[str]]:
+# return tuple(self.paths.keys())
+
+# def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
+# -> t.Optional[st.FileData]:
+# return self.paths[tuple(path)]
+
+
+PolicyTree = PatternTree[policies.PolicyFactory]
+
+def register_payload(
+ policy_tree: PolicyTree,
+ payload_key: st.PayloadKey,
+ token: str
+) -> PolicyTree:
+ """...."""
+ payload_policy_factory = policies.PayloadPolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(
+ payload_key.pattern,
+ payload_policy_factory
+ )
+
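+    # Additionally handle requests for the payload's own resource files, which
+    # are exposed under the payload's pattern extended with its unique token.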
+ resource_policy_factory = policies.PayloadResourcePolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(
+ payload_key.pattern.path_append(token, '***'),
+ resource_policy_factory
+ )
+
+ return policy_tree
+
+DataById = t.Mapping[str, st.PayloadData]
+
+AnyInfo = t.TypeVar('AnyInfo', item_infos.ResourceInfo, item_infos.MappingInfo)
+
+@dc.dataclass
+class ConcreteHaketiloState(st.HaketiloState):
+ store_dir: Path
+ # settings: state.HaketiloGlobalSettings
+ policy_tree: PolicyTree = PatternTree()
+ payloads_data: DataById = dc.field(default_factory=dict)
+
+ lock: threading.RLock = dc.field(default_factory=threading.RLock)
+
+ def __post_init__(self) -> None:
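+        # Load the newest versions of item definitions from the temporary
+        # malcontent directory, compute payloads from them and register
+        # corresponding policies in the policy tree.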
+ def newest_item_path(item_dir: Path) -> t.Optional[Path]:
+ available_versions = tuple(
+ versions.parse_normalize_version(ver_path.name)
+ for ver_path in item_dir.iterdir()
+ if ver_path.is_file()
+ )
+
+ if available_versions == ():
+ return None
+
+ newest_version = max(available_versions)
+
+ version_path = item_dir / versions.version_string(newest_version)
+
+ assert version_path.is_file()
+
+ return version_path
+
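+        # Read the newest available version of every item stored under
+        # 'dir_path'.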
+ def read_items(dir_path: Path, item_class: t.Type[AnyInfo]) \
+ -> t.Mapping[str, AnyInfo]:
+ items: dict[str, AnyInfo] = {}
+
+ for resource_dir in dir_path.iterdir():
+ if not resource_dir.is_dir():
+ continue
+
+ item_path = newest_item_path(resource_dir)
+ if item_path is None:
+ continue
+
+ item = item_class.load(item_path)
+
+ assert versions.version_string(item.version) == item_path.name
+ assert item.identifier == resource_dir.name
+
+ items[item.identifier] = item
+
+ return items
+
+ malcontent_dir = self.store_dir / 'temporary_malcontent'
+
+ items_collection = ItemsCollection(
+ read_items(malcontent_dir / 'resource', item_infos.ResourceInfo),
+ read_items(malcontent_dir / 'mapping', item_infos.MappingInfo)
+ )
+ computed_payloads = items_collection.compute_payloads()
+
+ payloads_data = {}
+
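+        # For every computed payload, prepare its runtime data and register
+        # policy factories for its pattern in the policy tree.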
+ for mapping_info, by_pattern in computed_payloads.items():
+ for num, (pattern, payload) in enumerate(by_pattern.items()):
+ payload_id = f'{num}@{mapping_info.identifier}'
+
+ ref = ConcretePayloadRef(payload_id, payload)
+
+ data = st.PayloadData(
+ payload_ref = ref,
+ mapping_installed = True,
+ explicitly_enabled = True,
+ unique_token = secrets.token_urlsafe(16),
+ pattern = pattern,
+ eval_allowed = payload.allows_eval,
+ cors_bypass_allowed = payload.allows_cors_bypass
+ )
+
+ payloads_data[payload_id] = data
+
+ key = st.PayloadKey(
+ payload_ref = ref,
+ mapping_identifier = mapping_info.identifier,
+ pattern = pattern
+ )
+
+ self.policy_tree = register_payload(
+ self.policy_tree,
+ key,
+ data.unique_token
+ )
+
+ self.payloads_data = payloads_data
+
+ def get_repo(self, repo_id: str) -> st.RepoRef:
+ return ConcreteRepoRef(repo_id)
+
+ def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef:
+ return ConcreteRepoIterationRef(repo_iteration_id)
+
+ def get_mapping(self, mapping_id: str) -> st.MappingRef:
+ return ConcreteMappingRef(mapping_id)
+
+ def get_mapping_version(self, mapping_version_id: str) \
+ -> st.MappingVersionRef:
+ return ConcreteMappingVersionRef(mapping_version_id)
+
+ def get_resource(self, resource_id: str) -> st.ResourceRef:
+ return ConcreteResourceRef(resource_id)
+
+ def get_resource_version(self, resource_version_id: str) \
+ -> st.ResourceVersionRef:
+ return ConcreteResourceVersionRef(resource_version_id)
+
+ def get_payload(self, payload_id: str) -> st.PayloadRef:
+        raise NotImplementedError()
+
+ def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \
+ -> st.RepoRef:
+ raise NotImplementedError()
+
+ def get_settings(self) -> st.HaketiloGlobalSettings:
+ return st.HaketiloGlobalSettings(
+ mapping_use_mode = st.MappingUseMode.AUTO,
+ default_allow_scripts = True,
+ repo_refresh_seconds = 0
+ )
+
+ def update_settings(
+ self,
+ *,
+ mapping_use_mode: t.Optional[st.MappingUseMode] = None,
+ default_allow_scripts: t.Optional[bool] = None,
+ repo_refresh_seconds: t.Optional[int] = None
+ ) -> None:
+ raise NotImplementedError()
+
+ def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy:
+ """...."""
+ with self.lock:
+ policy_tree = self.policy_tree
+
+ try:
+ best_priority: int = 0
+ best_policy: t.Optional[policies.Policy] = None
+
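+            # Of all policies whose registered patterns match the URL, choose
+            # the one with the highest priority.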
+ for factories_set in policy_tree.search(url):
+ for stored_factory in sorted(factories_set):
+ factory = stored_factory.item
+
+ policy = factory.make_policy(self)
+
+ if policy.priority > best_priority:
+ best_priority = policy.priority
+ best_policy = policy
+ except Exception as e:
+ return policies.ErrorBlockPolicy(
+ builtin = True,
+ error = e
+ )
+
+ if best_policy is not None:
+ return best_policy
+
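+        # No registered policy matched; fall back according to the global
+        # script-allowing setting.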
+ if self.get_settings().default_allow_scripts:
+ return policies.FallbackAllowPolicy()
+ else:
+ return policies.FallbackBlockPolicy()
+
+ @staticmethod
+ def make(store_dir: Path) -> 'ConcreteHaketiloState':
+ return ConcreteHaketiloState(store_dir=store_dir)