diff options
author | Wojtek Kosior <koszko@koszko.org> | 2022-07-27 15:56:24 +0200 |
---|---|---|
committer | Wojtek Kosior <koszko@koszko.org> | 2022-08-10 17:25:05 +0200 |
commit | 879c41927171efc8d77d1de2739b18e2eb57580f (patch) | |
tree | de0e78afe2ea49e58c9bf2c662657392a00139ee /src/hydrilla | |
parent | 52d12a4fa124daa1595529e3e7008276a7986d95 (diff) | |
download | haketilo-hydrilla-879c41927171efc8d77d1de2739b18e2eb57580f.tar.gz haketilo-hydrilla-879c41927171efc8d77d1de2739b18e2eb57580f.zip |
unfinished partial work
Diffstat (limited to 'src/hydrilla')
25 files changed, 3946 insertions, 794 deletions
diff --git a/src/hydrilla/item_infos.py b/src/hydrilla/item_infos.py index c366ab5..9ba47bd 100644 --- a/src/hydrilla/item_infos.py +++ b/src/hydrilla/item_infos.py @@ -31,48 +31,75 @@ # Enable using with Python 3.7. from __future__ import annotations +import sys + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + import typing as t import dataclasses as dc -from pathlib import Path, PurePath +from pathlib import Path, PurePosixPath +from abc import ABC -from immutables import Map, MapMutation +from immutables import Map from . import versions, json_instances -from .url_patterns import parse_pattern, ParsedUrl +from .url_patterns import parse_pattern, ParsedUrl, ParsedPattern from .exceptions import HaketiloException from .translations import smart_gettext as _ VerTuple = t.Tuple[int, ...] @dc.dataclass(frozen=True, unsafe_hash=True) -class ItemRef: +class ItemSpecifier: """....""" identifier: str -RefObjs = t.Sequence[t.Mapping[str, t.Any]] +SpecifierObjs = t.Sequence[t.Mapping[str, t.Any]] -def make_item_refs_seq(ref_objs: RefObjs) -> tuple[ItemRef, ...]: +def make_item_specifiers_seq(spec_objs: SpecifierObjs) \ + -> tuple[ItemSpecifier, ...]: """....""" - return tuple(ItemRef(ref['identifier']) for ref in ref_objs) + return tuple(ItemSpecifier(obj['identifier']) for obj in spec_objs) -def make_required_mappings(refs_objs: t.Any, schema_compat: int) \ - -> tuple[ItemRef, ...]: +def make_required_mappings(spec_objs: t.Any, schema_compat: int) \ + -> tuple[ItemSpecifier, ...]: """....""" if schema_compat < 2: return () - return make_item_refs_seq(refs_objs) + return make_item_specifiers_seq(spec_objs) @dc.dataclass(frozen=True, unsafe_hash=True) -class FileRef: +class FileSpecifier: """....""" name: str sha256: str -def make_file_refs_seq(ref_objs: RefObjs) -> tuple[FileRef, ...]: +def normalize_filename(name: str): + """ + This function eliminated double slashes in file name and ensures it does not + try to reference parent directories. + """ + path = PurePosixPath(name) + + if '.' in path.parts or '..' in path.parts: + msg = _('err.item_info.filename_invalid_{}').format(name) + raise HaketiloException(msg) + + return str(path) + +def make_file_specifiers_seq(spec_objs: SpecifierObjs) \ + -> tuple[FileSpecifier, ...]: """....""" - return tuple(FileRef(ref['file'], ref['sha256']) for ref in ref_objs) + return tuple( + FileSpecifier(normalize_filename(obj['file']), obj['sha256']) + for obj + in spec_objs + ) @dc.dataclass(frozen=True, unsafe_hash=True) class GeneratedBy: @@ -81,60 +108,107 @@ class GeneratedBy: version: t.Optional[str] @staticmethod - def make(generated_obj: t.Optional[t.Mapping[str, t.Any]]) -> \ + def make(generated_by_obj: t.Optional[t.Mapping[str, t.Any]]) -> \ t.Optional['GeneratedBy']: """....""" - if generated_obj is None: + if generated_by_obj is None: return None return GeneratedBy( - name = generated_obj['name'], - version = generated_obj.get('version') + name = generated_by_obj['name'], + version = generated_by_obj.get('version') ) -@dc.dataclass(frozen=True, unsafe_hash=True) -class ItemInfoBase: + +def make_eval_permission(perms_obj: t.Any, schema_compat: int) -> bool: + if schema_compat < 2: + return False + + return perms_obj.get('eval', False) + + +def make_cors_bypass_permission(perms_obj: t.Any, schema_compat: int) -> bool: + if schema_compat < 2: + return False + + return perms_obj.get('cors_bypass', False) + + +class Categorizable(Protocol): """....""" - repository: str # repository used in __hash__() - source_name: str = dc.field(hash=False) - source_copyright: tuple[FileRef, ...] = dc.field(hash=False) - version: VerTuple # version used in __hash__() - identifier: str # identifier used in __hash__() - uuid: t.Optional[str] = dc.field(hash=False) - long_name: str = dc.field(hash=False) - required_mappings: tuple[ItemRef, ...] = dc.field(hash=False) - generated_by: t.Optional[GeneratedBy] = dc.field(hash=False) - - def path_relative_to_type(self) -> str: - """ - Get a relative path to this item's JSON definition with respect to - directory containing items of this type. - """ - return f'{self.identifier}/{versions.version_string(self.version)}' - - def path(self) -> str: - """ - Get a relative path to this item's JSON definition with respect to - malcontent directory containing loadable items. - """ - return f'{self.type_name}/{self.path_relative_to_type()}' + uuid: t.Optional[str] + identifier: str - @property - def versioned_identifier(self): - """....""" - return f'{self.identifier}-{versions.version_string(self.version)}' +@dc.dataclass(frozen=True, unsafe_hash=True) +class ItemIdentity: + repo: str + repo_iteration: int + version: versions.VerTuple + identifier: str + +# mypy needs to be corrected: +# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 +@dc.dataclass(frozen=True) # type: ignore[misc] +class ItemInfoBase(ABC, ItemIdentity, Categorizable): + """....""" + type_name: t.ClassVar[str] = '!INVALID!' + + source_name: str = dc.field(hash=False) + source_copyright: tuple[FileSpecifier, ...] = dc.field(hash=False) + uuid: t.Optional[str] = dc.field(hash=False) + long_name: str = dc.field(hash=False) + allows_eval: bool = dc.field(hash=False) + allows_cors_bypass: bool = dc.field(hash=False) + required_mappings: tuple[ItemSpecifier, ...] = dc.field(hash=False) + generated_by: t.Optional[GeneratedBy] = dc.field(hash=False) + + # def path_relative_to_type(self) -> str: + # """ + # Get a relative path to this item's JSON definition with respect to + # directory containing items of this type. + # """ + # return f'{self.identifier}/{versions.version_string(self.version)}' + + # def path(self) -> str: + # """ + # Get a relative path to this item's JSON definition with respect to + # malcontent directory containing loadable items. + # """ + # return f'{self.type_name}/{self.path_relative_to_type()}' + + # @property + # def identity(self): + # """....""" + # return ItemIdentity( + # repository = self.repository, + # version = self.version, + # identifier = self.identifier + # ) + + # @property + # def versioned_identifier(self): + # """....""" + # return f'{self.identifier}-{versions.version_string(self.version)}' @staticmethod def _get_base_init_kwargs( - item_obj: t.Mapping[str, t.Any], - schema_compat: int, - repository: str + item_obj: t.Mapping[str, t.Any], + schema_compat: int, + repo: str, + repo_iteration: int ) -> t.Mapping[str, t.Any]: """....""" - source_copyright = make_file_refs_seq(item_obj['source_copyright']) + source_copyright = make_file_specifiers_seq( + item_obj['source_copyright'] + ) version = versions.normalize_version(item_obj['version']) + perms_obj = item_obj.get('permissions', {}) + + eval_perm = make_eval_permission(perms_obj, schema_compat) + cors_bypass_perm = make_cors_bypass_permission(perms_obj, schema_compat) + required_mappings = make_required_mappings( item_obj.get('required_mappings', []), schema_compat @@ -143,28 +217,29 @@ class ItemInfoBase: generated_by = GeneratedBy.make(item_obj.get('generated_by')) return Map( - repository = repository, - source_name = item_obj['source_name'], - source_copyright = source_copyright, - version = version, - identifier = item_obj['identifier'], - uuid = item_obj.get('uuid'), - long_name = item_obj['long_name'], - required_mappings = required_mappings, - generated_by = generated_by + repo = repo, + repo_iteration = repo_iteration, + source_name = item_obj['source_name'], + source_copyright = source_copyright, + version = version, + identifier = item_obj['identifier'], + uuid = item_obj.get('uuid'), + long_name = item_obj['long_name'], + allows_eval = eval_perm, + allows_cors_bypass = cors_bypass_perm, + required_mappings = required_mappings, + generated_by = generated_by ) - # class property - type_name = '!INVALID!' - -InstanceOrPath = t.Union[Path, str, dict[str, t.Any]] @dc.dataclass(frozen=True, unsafe_hash=True) class ResourceInfo(ItemInfoBase): """....""" - revision: int = dc.field(hash=False) - dependencies: tuple[ItemRef, ...] = dc.field(hash=False) - scripts: tuple[FileRef, ...] = dc.field(hash=False) + type_name: t.ClassVar[str] = 'resource' + + revision: int = dc.field(hash=False) + dependencies: tuple[ItemSpecifier, ...] = dc.field(hash=False) + scripts: tuple[FileSpecifier, ...] = dc.field(hash=False) @property def versioned_identifier(self): @@ -173,41 +248,70 @@ class ResourceInfo(ItemInfoBase): @staticmethod def make( - item_obj: t.Mapping[str, t.Any], - schema_compat: int, - repository: str + item_obj: t.Mapping[str, t.Any], + schema_compat: int, + repo: str, + repo_iteration: int ) -> 'ResourceInfo': """....""" base_init_kwargs = ItemInfoBase._get_base_init_kwargs( item_obj, schema_compat, - repository + repo, + repo_iteration + ) + + dependencies = make_item_specifiers_seq( + item_obj.get('dependencies', []) + ) + + scripts = make_file_specifiers_seq( + item_obj.get('scripts', []) ) return ResourceInfo( **base_init_kwargs, revision = item_obj['revision'], - dependencies = make_item_refs_seq(item_obj.get('dependencies', [])), - scripts = make_file_refs_seq(item_obj.get('scripts', [])), + dependencies = dependencies, + scripts = scripts ) @staticmethod - def load(instance_or_path: 'InstanceOrPath', repository: str) \ - -> 'ResourceInfo': + def load( + instance_or_path: json_instances.InstanceOrPathOrIO, + repo: str = '<dummyrepo>', + repo_iteration: int = -1 + ) -> 'ResourceInfo': """....""" - return _load_item_info(ResourceInfo, instance_or_path, repository) + return _load_item_info( + ResourceInfo, + instance_or_path, + repo, + repo_iteration + ) - # class property - type_name = 'resource' + # def __lt__(self, other: 'ResourceInfo') -> bool: + # """....""" + # return ( + # self.identifier, + # self.version, + # self.revision, + # self.repository + # ) < ( + # other.identifier, + # other.version, + # other.revision, + # other.repository + # ) def make_payloads(payloads_obj: t.Mapping[str, t.Any]) \ - -> t.Mapping[ParsedUrl, ItemRef]: + -> t.Mapping[ParsedPattern, ItemSpecifier]: """....""" - mapping: list[tuple[ParsedUrl, ItemRef]] = [] + mapping: list[tuple[ParsedPattern, ItemSpecifier]] = [] - for pattern, ref_obj in payloads_obj.items(): - ref = ItemRef(ref_obj['identifier']) + for pattern, spec_obj in payloads_obj.items(): + ref = ItemSpecifier(spec_obj['identifier']) mapping.extend((parsed, ref) for parsed in parse_pattern(pattern)) return Map(mapping) @@ -215,19 +319,23 @@ def make_payloads(payloads_obj: t.Mapping[str, t.Any]) \ @dc.dataclass(frozen=True, unsafe_hash=True) class MappingInfo(ItemInfoBase): """....""" - payloads: t.Mapping[ParsedUrl, ItemRef] = dc.field(hash=False) + type_name: t.ClassVar[str] = 'mapping' + + payloads: t.Mapping[ParsedPattern, ItemSpecifier] = dc.field(hash=False) @staticmethod def make( - item_obj: t.Mapping[str, t.Any], - schema_compat: int, - repository: str + item_obj: t.Mapping[str, t.Any], + schema_compat: int, + repo: str, + repo_iteration: int ) -> 'MappingInfo': """....""" base_init_kwargs = ItemInfoBase._get_base_init_kwargs( item_obj, schema_compat, - repository + repo, + repo_iteration ) return MappingInfo( @@ -237,10 +345,23 @@ class MappingInfo(ItemInfoBase): ) @staticmethod - def load(instance_or_path: 'InstanceOrPath', repository: str) \ - -> 'MappingInfo': + def load( + instance_or_path: json_instances.InstanceOrPathOrIO, + repo: str = '<dummyrepo>', + repo_iteration: int = -1 + ) -> 'MappingInfo': """....""" - return _load_item_info(MappingInfo, instance_or_path, repository) + return _load_item_info( + MappingInfo, + instance_or_path, + repo, + repo_iteration + ) + + # def __lt__(self, other: 'MappingInfo') -> bool: + # """....""" + # return (self.identifier, self.version, self.repository) < \ + # (other.identifier, other.version, other.repository) # class property type_name = 'mapping' @@ -250,8 +371,9 @@ LoadedType = t.TypeVar('LoadedType', ResourceInfo, MappingInfo) def _load_item_info( info_type: t.Type[LoadedType], - instance_or_path: InstanceOrPath, - repository: str + instance_or_path: json_instances.InstanceOrPathOrIO, + repo: str, + repo_iteration: int ) -> LoadedType: """Read, validate and autocomplete a mapping/resource description.""" instance = json_instances.read_instance(instance_or_path) @@ -264,81 +386,156 @@ def _load_item_info( return info_type.make( t.cast('dict[str, t.Any]', instance), schema_compat, - repository + repo, + repo_iteration ) -VersionedType = t.TypeVar('VersionedType', ResourceInfo, MappingInfo) - -@dc.dataclass(frozen=True) -class VersionedItemInfo(t.Generic[VersionedType]): - """Stores data of multiple versions of given resource/mapping.""" - uuid: t.Optional[str] = None - identifier: str = '<dummy>' - _by_version: Map[VerTuple, VersionedType] = Map() - _initialized: bool = False - - def register(self, item_info: VersionedType) -> 'VersionedInfoSelfType': - """ - Make item info queryable by version. Perform sanity checks for uuid. - """ - identifier = item_info.identifier - if self._initialized: - assert identifier == self.identifier - - if self.uuid is not None: - uuid: t.Optional[str] = self.uuid - if item_info.uuid is not None and self.uuid != item_info.uuid: - raise HaketiloException(_('uuid_mismatch_{identifier}') - .format(identifier=identifier)) - else: - uuid = item_info.uuid - - by_version = self._by_version.set(item_info.version, item_info) - - return VersionedItemInfo( - identifier = identifier, - uuid = uuid, - _by_version = by_version, - _initialized = True - ) - - def unregister(self, version: VerTuple) -> 'VersionedInfoSelfType': - """....""" - try: - by_version = self._by_version.delete(version) - except KeyError: - by_version = self._by_version - - return dc.replace(self, _by_version=by_version) - - def is_empty(self) -> bool: - """....""" - return len(self._by_version) == 0 - - def newest_version(self) -> VerTuple: - """....""" - assert not self.is_empty() - - return max(self._by_version.keys()) - - def get_newest(self) -> VersionedType: - """Find and return info of the newest version of item.""" - newest = self._by_version[self.newest_version()] - assert newest is not None - return newest - - def get_by_ver(self, ver: t.Iterable[int]) -> t.Optional[VersionedType]: - """ - Find and return info of the specified version of the item (or None if - absent). - """ - return self._by_version.get(tuple(ver)) - - def get_all(self) -> t.Iterator[VersionedType]: - """Generate item info for all its versions, from oldest ot newest.""" - for version in sorted(self._by_version.keys()): - yield self._by_version[version] - -# Below we define 1 type used by recursively-typed VersionedItemInfo. -VersionedInfoSelfType = VersionedItemInfo[VersionedType] +# CategorizedType = t.TypeVar( +# 'CategorizedType', +# bound=Categorizable +# ) + +# CategorizedUpdater = t.Callable[ +# [t.Optional[CategorizedType]], +# t.Optional[CategorizedType] +# ] + +# CategoryKeyType = t.TypeVar('CategoryKeyType', bound=t.Hashable) + +# @dc.dataclass(frozen=True) +# class CategorizedItemInfo(Categorizable, t.Generic[CategorizedType, CategoryKeyType]): +# """....""" +# SelfType = t.TypeVar( +# 'SelfType', +# bound = 'CategorizedItemInfo[CategorizedType, CategoryKeyType]' +# ) + +# uuid: t.Optional[str] = None +# identifier: str = '<dummy>' +# items: Map[CategoryKeyType, CategorizedType] = Map() +# _initialized: bool = False + +# def _update( +# self: 'SelfType', +# key: CategoryKeyType, +# updater: CategorizedUpdater +# ) -> 'SelfType': +# """...... Perform sanity checks for uuid.""" +# uuid = self.uuid + +# items = self.items.mutate() + +# updated = updater(items.get(key)) +# if updated is None: +# items.pop(key, None) + +# identifier = self.identifier +# else: +# items[key] = updated + +# identifier = updated.identifier +# if self._initialized: +# assert identifier == self.identifier + +# if uuid is not None: +# if updated.uuid is not None and uuid != updated.uuid: +# raise HaketiloException(_('uuid_mismatch_{identifier}') +# .format(identifier=identifier)) +# else: +# uuid = updated.uuid + +# return dc.replace( +# self, +# identifier = identifier, +# uuid = uuid, +# items = items.finish(), +# _initialized = self._initialized or updated is not None +# ) + +# def is_empty(self) -> bool: +# """....""" +# return len(self.items) == 0 + + +# VersionedType = t.TypeVar('VersionedType', ResourceInfo, MappingInfo) + +# class VersionedItemInfo( +# CategorizedItemInfo[VersionedType, VerTuple], +# t.Generic[VersionedType] +# ): +# """Stores data of multiple versions of given resource/mapping.""" +# SelfType = t.TypeVar('SelfType', bound='VersionedItemInfo[VersionedType]') + +# def register(self: 'SelfType', item_info: VersionedType) -> 'SelfType': +# """ +# Make item info queryable by version. Perform sanity checks for uuid. +# """ +# return self._update(item_info.version, lambda old_info: item_info) + +# def unregister(self: 'SelfType', version: VerTuple) -> 'SelfType': +# """....""" +# return self._update(version, lambda old_info: None) + +# def newest_version(self) -> VerTuple: +# """....""" +# assert not self.is_empty() + +# return max(self.items.keys()) + +# def get_newest(self) -> VersionedType: +# """Find and return info of the newest version of item.""" +# newest = self.items[self.newest_version()] +# assert newest is not None +# return newest + +# def get_by_ver(self, ver: t.Iterable[int]) -> t.Optional[VersionedType]: +# """ +# Find and return info of the specified version of the item (or None if +# absent). +# """ +# return self.items.get(tuple(ver)) + +# def get_all(self) -> t.Iterator[VersionedType]: +# """Generate item info for all its versions, from oldest to newest.""" +# for version in sorted(self.items.keys()): +# yield self.items[version] + + +# MultiRepoType = t.TypeVar('MultiRepoType', ResourceInfo, MappingInfo) +# MultiRepoVersioned = VersionedItemInfo[MultiRepoType] + +# class MultiRepoItemInfo( +# CategorizedItemInfo[MultiRepoVersioned, str], +# t.Generic[MultiRepoType] +# ): +# SelfType = t.TypeVar('SelfType', bound='MultiRepoItemInfo[MultiRepoType]') + +# def register(self: 'SelfType', item_info: MultiRepoType) -> 'SelfType': +# """ +# Make item info queryable by version. Perform sanity checks for uuid. +# """ +# def updater(old_item: t.Optional[MultiRepoVersioned]) \ +# -> MultiRepoVersioned: +# """....""" +# if old_item is None: +# old_item = VersionedItemInfo() + +# return old_item.register(item_info) + +# return self._update(item_info.repository, updater) + +# def unregister(self: 'SelfType', version: VerTuple, repository: str) \ +# -> 'SelfType': +# """....""" +# def updater(old_item: t.Optional[MultiRepoVersioned]) -> \ +# t.Optional[MultiRepoVersioned]: +# """....""" +# if old_item is None: +# return None + +# new_item = old_item.unregister(version) + +# return None if new_item.is_empty() else new_item + +# return self._update(repository, updater) diff --git a/src/hydrilla/json_instances.py b/src/hydrilla/json_instances.py index 40b213b..33a3785 100644 --- a/src/hydrilla/json_instances.py +++ b/src/hydrilla/json_instances.py @@ -34,6 +34,7 @@ from __future__ import annotations import re import json import os +import io import typing as t from pathlib import Path, PurePath @@ -158,15 +159,22 @@ def parse_instance(text: str) -> object: """Parse 'text' as JSON with additional '//' comments support.""" return json.loads(strip_json_comments(text)) -InstanceOrPath = t.Union[Path, str, dict[str, t.Any]] +InstanceOrPathOrIO = t.Union[Path, str, io.TextIOBase, dict[str, t.Any]] -def read_instance(instance_or_path: InstanceOrPath) -> object: +def read_instance(instance_or_path: InstanceOrPathOrIO) -> object: """....""" if isinstance(instance_or_path, dict): return instance_or_path - with open(instance_or_path, 'rt') as handle: + if isinstance(instance_or_path, io.TextIOBase): + handle = instance_or_path + else: + handle = t.cast(io.TextIOBase, open(instance_or_path, 'rt')) + + try: text = handle.read() + finally: + handle.close() try: return parse_instance(text) diff --git a/src/hydrilla/mitmproxy_launcher/__init__.py b/src/hydrilla/mitmproxy_launcher/__init__.py new file mode 100644 index 0000000..d382ead --- /dev/null +++ b/src/hydrilla/mitmproxy_launcher/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: CC0-1.0 + +# Copyright (C) 2022 Wojtek Kosior <koszko@koszko.org> +# +# Available under the terms of Creative Commons Zero v1.0 Universal. diff --git a/src/hydrilla/mitmproxy_launcher/launch.py b/src/hydrilla/mitmproxy_launcher/launch.py index c826598..765a9ce 100644 --- a/src/hydrilla/mitmproxy_launcher/launch.py +++ b/src/hydrilla/mitmproxy_launcher/launch.py @@ -42,6 +42,12 @@ import click from .. import _version from ..translations import smart_gettext as _ +addon_script_text = ''' +from hydrilla.proxy.addon import HaketiloAddon + +addons = [HaketiloAddon()] +''' + @click.command(help=_('cli_help.haketilo')) @click.option('-p', '--port', default=8080, type=click.IntRange(0, 65535), help=_('cli_opt.haketilo.port')) @@ -61,17 +67,13 @@ def launch(port: int, directory: str): script_path = directory_path / 'addon.py' - script_path.write_text(''' -from hydrilla.mitmproxy_addon.addon import Haketilo - -addons = [Haketilo()] -''') + script_path.write_text(addon_script_text) code = sp.call(['mitmdump', '-p', str(port), - '--set', f'confdir={directory_path / "mitmproxy"}' + '--set', f'confdir={directory_path / "mitmproxy"}', '--set', 'upstream_cert=false', - '--set', f'haketilo_dir={directory_path}' + '--set', f'haketilo_dir={directory_path}', '--scripts', str(script_path)]) sys.exit(code) diff --git a/src/hydrilla/pattern_tree.py b/src/hydrilla/pattern_tree.py index 1128a06..99f45a5 100644 --- a/src/hydrilla/pattern_tree.py +++ b/src/hydrilla/pattern_tree.py @@ -31,44 +31,37 @@ This module defines data structures for querying data using URL patterns. # Enable using with Python 3.7. from __future__ import annotations -import sys import typing as t import dataclasses as dc from immutables import Map -from .url_patterns import ParsedUrl, parse_url +from .url_patterns import ParsedPattern, ParsedUrl, parse_url#, catchall_pattern from .translations import smart_gettext as _ WrapperStoredType = t.TypeVar('WrapperStoredType', bound=t.Hashable) -@dc.dataclass(frozen=True, unsafe_hash=True) +@dc.dataclass(frozen=True, unsafe_hash=True, order=True) class StoredTreeItem(t.Generic[WrapperStoredType]): """ In the Pattern Tree, each item is stored together with the pattern used to register it. """ - pattern: ParsedUrl item: WrapperStoredType + pattern: ParsedPattern -# if sys.version_info >= (3, 8): -# CopyableType = t.TypeVar('CopyableType', bound='Copyable') - -# class Copyable(t.Protocol): -# """Certain classes in Pattern Tree depend on this interface.""" -# def copy(self: CopyableType) -> CopyableType: -# """Make a distinct instance with the same properties as this one.""" -# ... -# else: -# Copyable = t.Any NodeStoredType = t.TypeVar('NodeStoredType') @dc.dataclass(frozen=True) class PatternTreeNode(t.Generic[NodeStoredType]): """....""" - children: 'NodeChildrenType' = Map() + SelfType = t.TypeVar('SelfType', bound='PatternTreeNode[NodeStoredType]') + + ChildrenType = Map[str, SelfType] + + children: 'ChildrenType' = Map() literal_match: t.Optional[NodeStoredType] = None def is_empty(self) -> bool: @@ -76,17 +69,17 @@ class PatternTreeNode(t.Generic[NodeStoredType]): return len(self.children) == 0 and self.literal_match is None def update_literal_match( - self, + self: 'SelfType', new_match_item: t.Optional[NodeStoredType] - ) -> 'NodeSelfType': + ) -> 'SelfType': """....""" return dc.replace(self, literal_match=new_match_item) - def get_child(self, child_key: str) -> t.Optional['NodeSelfType']: + def get_child(self: 'SelfType', child_key: str) -> t.Optional['SelfType']: """....""" return self.children.get(child_key) - def remove_child(self, child_key: str) -> 'NodeSelfType': + def remove_child(self: 'SelfType', child_key: str) -> 'SelfType': """....""" try: children = self.children.delete(child_key) @@ -95,19 +88,15 @@ class PatternTreeNode(t.Generic[NodeStoredType]): return dc.replace(self, children=children) - def set_child(self, child_key: str, child: 'NodeSelfType') \ - -> 'NodeSelfType': + def set_child(self: 'SelfType', child_key: str, child: 'SelfType') \ + -> 'SelfType': """....""" return dc.replace(self, children=self.children.set(child_key, child)) -# Below we define 2 types used by recursively-typed PatternTreeNode. -NodeSelfType = PatternTreeNode[NodeStoredType] -NodeChildrenType = Map[str, NodeSelfType] - BranchStoredType = t.TypeVar('BranchStoredType') -ItemUpdater = t.Callable[ +BranchItemUpdater = t.Callable[ [t.Optional[BranchStoredType]], t.Optional[BranchStoredType] ] @@ -115,18 +104,22 @@ ItemUpdater = t.Callable[ @dc.dataclass(frozen=True) class PatternTreeBranch(t.Generic[BranchStoredType]): """....""" + SelfType = t.TypeVar( + 'SelfType', + bound = 'PatternTreeBranch[BranchStoredType]' + ) + root_node: PatternTreeNode[BranchStoredType] = PatternTreeNode() def is_empty(self) -> bool: """....""" return self.root_node.is_empty() - # def copy(self) -> 'BranchSelfType': - # """....""" - # return dc.replace(self) - - def update(self, segments: t.Iterable[str], item_updater: ItemUpdater) \ - -> 'BranchSelfType': + def update( + self: 'SelfType', + segments: t.Iterable[str], + item_updater: BranchItemUpdater + ) -> 'SelfType': """ ....... """ @@ -188,9 +181,6 @@ class PatternTreeBranch(t.Generic[BranchStoredType]): if condition(): yield match_node.literal_match -# Below we define 1 type used by recursively-typed PatternTreeBranch. -BranchSelfType = PatternTreeBranch[BranchStoredType] - FilterStoredType = t.TypeVar('FilterStoredType', bound=t.Hashable) FilterWrappedType = StoredTreeItem[FilterStoredType] @@ -218,19 +208,21 @@ class PatternTree(t.Generic[TreeStoredType]): is to make it possible to quickly retrieve all known patterns that match a given URL. """ - _by_scheme_and_port: TreeRoot = Map() + SelfType = t.TypeVar('SelfType', bound='PatternTree[TreeStoredType]') + + _by_scheme_and_port: TreeRoot = Map() def _register( - self, - parsed_pattern: ParsedUrl, + self: 'SelfType', + parsed_pattern: ParsedPattern, item: TreeStoredType, register: bool = True - ) -> 'TreeSelfType': + ) -> 'SelfType': """ Make an item wrapped in StoredTreeItem object queryable through the Pattern Tree by the given parsed URL pattern. """ - wrapped_item = StoredTreeItem(parsed_pattern, item) + wrapped_item = StoredTreeItem(item, parsed_pattern) def item_updater(item_set: t.Optional[StoredSet]) \ -> t.Optional[StoredSet]: @@ -276,36 +268,21 @@ class PatternTree(t.Generic[TreeStoredType]): return dc.replace(self, _by_scheme_and_port=new_root) - # def _register( - # self, - # url_pattern: str, - # item: TreeStoredType, - # register: bool = True - # ) -> 'TreeSelfType': - # """ - # .... - # """ - # tree = self - - # for parsed_pat in parse_pattern(url_pattern): - # wrapped_item = StoredTreeItem(parsed_pat, item) - # tree = tree._register_with_parsed_pattern( - # parsed_pat, - # wrapped_item, - # register - # ) - - # return tree - - def register(self, parsed_pattern: ParsedUrl, item: TreeStoredType) \ - -> 'TreeSelfType': + def register( + self: 'SelfType', + parsed_pattern: ParsedPattern, + item: TreeStoredType + ) -> 'SelfType': """ Make item queryable through the Pattern Tree by the given URL pattern. """ return self._register(parsed_pattern, item) - def deregister(self, parsed_pattern: ParsedUrl, item: TreeStoredType) \ - -> 'TreeSelfType': + def deregister( + self: 'SelfType', + parsed_pattern: ParsedPattern, + item: TreeStoredType + ) -> 'SelfType': """ Make item no longer queryable through the Pattern Tree by the given URL pattern. @@ -334,6 +311,3 @@ class PatternTree(t.Generic[TreeStoredType]): items = filter_by_trailing_slash(item_set, with_slash) if len(items) > 0: yield items - -# Below we define 1 type used by recursively-typed PatternTree. -TreeSelfType = PatternTree[TreeStoredType] diff --git a/src/hydrilla/proxy/addon.py b/src/hydrilla/proxy/addon.py index 7d6487b..16c2841 100644 --- a/src/hydrilla/proxy/addon.py +++ b/src/hydrilla/proxy/addon.py @@ -32,58 +32,70 @@ from addon script. # Enable using with Python 3.7. from __future__ import annotations -import os.path +import sys import typing as t import dataclasses as dc +import traceback as tb from threading import Lock from pathlib import Path from contextlib import contextmanager -from mitmproxy import http, addonmanager, ctx +# for mitmproxy 6.* +from mitmproxy.net import http +if not hasattr(http, 'Headers'): + # for mitmproxy 8.* + from mitmproxy import http # type: ignore + +from mitmproxy.http import HTTPFlow + +from mitmproxy import addonmanager, ctx from mitmproxy.script import concurrent -from .flow_handlers import make_flow_handler, FlowHandler -from .state import HaketiloState +from ..exceptions import HaketiloException from ..translations import smart_gettext as _ +from ..url_patterns import parse_url +from .state_impl import ConcreteHaketiloState +from . import policies -FlowHandlers = dict[int, FlowHandler] -StateUpdater = t.Callable[[HaketiloState], None] +DefaultGetValue = t.TypeVar('DefaultGetValue', object, None) -HTTPHandlerFun = t.Callable[ - ['HaketiloAddon', http.HTTPFlow], - t.Optional[StateUpdater] -] - -def http_event_handler(handler_fun: HTTPHandlerFun): - """....decorator""" - def wrapped_handler(self: 'HaketiloAddon', flow: http.HTTPFlow): +class MitmproxyHeadersWrapper(): + """....""" + def __init__(self, headers: http.Headers) -> None: """....""" - with self.configured_lock: - assert self.configured + self.headers = headers - assert self.state is not None + __getitem__ = lambda self, key: self.headers[key] + get_all = lambda self, key: self.headers.get_all(key) - state_updater = handler_fun(self, flow) + def get(self, key: str, default: DefaultGetValue = None) \ + -> t.Union[str, DefaultGetValue]: + """....""" + value = self.headers.get(key) - if state_updater is not None: - state_updater(self.state) + if value is None: + return default + else: + return t.cast(str, value) - return wrapped_handler + def items(self) -> t.Iterable[tuple[str, str]]: + """....""" + return self.headers.items(multi=True) @dc.dataclass class HaketiloAddon: """ ....... """ - configured: bool = False - configured_lock: Lock = dc.field(default_factory=Lock) + configured: bool = False + configured_lock: Lock = dc.field(default_factory=Lock) - state: t.Optional[HaketiloState] = None + flow_policies: dict[int, policies.Policy] = dc.field(default_factory=dict) + policies_lock: Lock = dc.field(default_factory=Lock) - flow_handlers: FlowHandlers = dc.field(default_factory=dict) - handlers_lock: Lock = dc.field(default_factory=Lock) + state: t.Optional[ConcreteHaketiloState] = None def load(self, loader: addonmanager.Loader) -> None: """....""" @@ -104,74 +116,165 @@ class HaketiloAddon: ctx.log.warn(_('haketilo_dir_already_configured')) return - haketilo_dir = Path(ctx.options.haketilo_dir) - self.state = HaketiloState(haketilo_dir / 'store') + try: + haketilo_dir = Path(ctx.options.haketilo_dir) + + self.state = ConcreteHaketiloState.make(haketilo_dir / 'store') + except Exception as e: + tb.print_exception(None, e, e.__traceback__) + sys.exit(1) - def assign_handler(self, flow: http.HTTPFlow, flow_handler: FlowHandler) \ - -> None: + self.configured = True + + def try_get_policy(self, flow: HTTPFlow, fail_ok: bool = True) -> \ + t.Optional[policies.Policy]: """....""" - with self.handlers_lock: - self.flow_handlers[id(flow)] = flow_handler + with self.policies_lock: + policy = self.flow_policies.get(id(flow)) + + if policy is None: + try: + parsed_url = parse_url(flow.request.url) + except HaketiloException: + if fail_ok: + return None + else: + raise + + assert self.state is not None + + policy = self.state.select_policy(parsed_url) + + with self.policies_lock: + self.flow_policies[id(flow)] = policy + + return policy + + def get_policy(self, flow: HTTPFlow) -> policies.Policy: + return t.cast(policies.Policy, self.try_get_policy(flow, fail_ok=False)) - def lookup_handler(self, flow: http.HTTPFlow) -> FlowHandler: + def forget_policy(self, flow: HTTPFlow) -> None: """....""" - with self.handlers_lock: - return self.flow_handlers[id(flow)] + with self.policies_lock: + self.flow_policies.pop(id(flow), None) - def forget_handler(self, flow: http.HTTPFlow) -> None: + @contextmanager + def http_safe_event_handling(self, flow: HTTPFlow) -> t.Iterator: """....""" - with self.handlers_lock: - self.flow_handlers.pop(id(flow), None) + with self.configured_lock: + assert self.configured + + try: + yield + except Exception as e: + tb_string = ''.join(tb.format_exception(None, e, e.__traceback__)) + error_text = _('err.proxy.unknown_error_{}_try_again')\ + .format(tb_string)\ + .encode() + flow.response = http.Response.make( + status_code = 500, + content = error_text, + headers = [(b'Content-Type', b'text/plain; charset=utf-8')] + ) + + self.forget_policy(flow) @concurrent - @http_event_handler - def requestheaders(self, flow: http.HTTPFlow) -> t.Optional[StateUpdater]: + def requestheaders(self, flow: HTTPFlow) -> None: + # TODO: don't account for mitmproxy 6 in the code + # Mitmproxy 6 causes even more strange behavior than described below. + # This cannot be easily worked around. Let's just use version 8 and + # make an APT package for it. """ - ..... + Under mitmproxy 8 this handler deduces an appropriate policy for flow's + URL and assigns it to the flow. Under mitmproxy 6 the URL is not yet + available at this point, so the handler effectively does nothing. """ - assert self.state is not None - - policy = self.state.select_policy(flow.request.url) + with self.http_safe_event_handling(flow): + policy = self.try_get_policy(flow) - flow_handler = make_flow_handler(flow, policy) - - self.assign_handler(flow, flow_handler) - - return flow_handler.on_requestheaders() + if policy is not None: + if not policy.process_request: + flow.request.stream = True + if policy.anticache: + flow.request.anticache() @concurrent - @http_event_handler - def request(self, flow: http.HTTPFlow) -> t.Optional[StateUpdater]: + def request(self, flow: HTTPFlow) -> None: """ .... """ - return self.lookup_handler(flow).on_request() + if flow.request.stream: + return - @concurrent - @http_event_handler - def responseheaders(self, flow: http.HTTPFlow) -> t.Optional[StateUpdater]: + with self.http_safe_event_handling(flow): + policy = self.get_policy(flow) + + request_info = policies.RequestInfo( + url = parse_url(flow.request.url), + method = flow.request.method, + headers = MitmproxyHeadersWrapper(flow.request.headers), + body = flow.request.get_content(strict=False) or b'' + ) + + result = policy.consume_request(request_info) + + if result is not None: + if isinstance(result, policies.ProducedRequest): + flow.request = http.Request.make( + url = result.url, + method = result.method, + headers = http.Headers(result.headers), + content = result.body + ) + else: + # isinstance(result, policies.ProducedResponse) + flow.response = http.Response.make( + status_code = result.status_code, + headers = http.Headers(result.headers), + content = result.body + ) + + def responseheaders(self, flow: HTTPFlow) -> None: """ ...... """ - return self.lookup_handler(flow).on_responseheaders() + assert flow.response is not None + + with self.http_safe_event_handling(flow): + policy = self.get_policy(flow) + + if not policy.process_response: + flow.response.stream = True @concurrent - @http_event_handler - def response(self, flow: http.HTTPFlow) -> t.Optional[StateUpdater]: + def response(self, flow: HTTPFlow) -> None: """ ...... """ - updater = self.lookup_handler(flow).on_response() + assert flow.response is not None - self.forget_handler(flow) + if flow.response.stream: + return - return updater + with self.http_safe_event_handling(flow): + policy = self.get_policy(flow) - @http_event_handler - def error(self, flow: http.HTTPFlow) -> None: - """....""" - self.forget_handler(flow) + response_info = policies.ResponseInfo( + url = parse_url(flow.request.url), + status_code = flow.response.status_code, + headers = MitmproxyHeadersWrapper(flow.response.headers), + body = flow.response.get_content(strict=False) or b'' + ) -addons = [ - HaketiloAddon() -] + result = policy.consume_response(response_info) + if result is not None: + flow.response.status_code = result.status_code + flow.response.headers = http.Headers(result.headers) + flow.response.set_content(result.body) + + self.forget_policy(flow) + + def error(self, flow: HTTPFlow) -> None: + """....""" + self.forget_policy(flow) diff --git a/src/hydrilla/proxy/csp.py b/src/hydrilla/proxy/csp.py new file mode 100644 index 0000000..59d93f2 --- /dev/null +++ b/src/hydrilla/proxy/csp.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Tools for working with Content Security Policy headers. +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +..... +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import re +import typing as t +import dataclasses as dc + +from immutables import Map, MapMutation + +from .policies.base import IHeaders + + +header_names_and_dispositions = ( + ('content-security-policy', 'enforce'), + ('content-security-policy-report-only', 'report'), + ('x-content-security-policy', 'enforce'), + ('x-content-security-policy', 'report'), + ('x-webkit-csp', 'enforce'), + ('x-webkit-csp', 'report') +) + +enforce_header_names_set = { + name for name, disposition in header_names_and_dispositions + if disposition == 'enforce' +} + +@dc.dataclass +class ContentSecurityPolicy: + directives: Map[str, t.Sequence[str]] + header_name: str + disposition: str + + def serialize(self) -> str: + """....""" + serialized_directives = [] + for name, value_list in self.directives.items(): + serialized_directives.append(f'{name} {" ".join(value_list)}') + + return ';'.join(serialized_directives) + + @staticmethod + def deserialize( + serialized: str, + header_name: str, + disposition: str = 'enforce' + ) -> 'ContentSecurityPolicy': + """....""" + # For more info, see: + # https://www.w3.org/TR/CSP3/#parse-serialized-policy + empty_directives: Map[str, t.Sequence[str]] = Map() + + directives = empty_directives.mutate() + + for serialized_directive in serialized.split(';'): + if not serialized_directive.isascii(): + continue + + serialized_directive = serialized_directive.strip() + if len(serialized_directive) == 0: + continue + + tokens = serialized_directive.split() + directive_name = tokens.pop(0).lower() + directive_value = tokens + + # Specs mention giving warnings for duplicate directive names but + # from our proxy's perspective this is not important right now. + if directive_name in directives: + continue + + directives[directive_name] = directive_value + + return ContentSecurityPolicy( + directives = directives.finish(), + header_name = header_name, + disposition = disposition + ) + +def extract(headers: IHeaders) -> tuple[ContentSecurityPolicy, ...]: + """....""" + csp_policies = [] + + for header_name, disposition in header_names_and_dispositions: + for serialized_list in headers.get_all(header_name): + for serialized in serialized_list.split(','): + policy = ContentSecurityPolicy.deserialize( + serialized, + header_name, + disposition + ) + + if policy.directives != Map(): + csp_policies.append(policy) + + return tuple(csp_policies) diff --git a/src/hydrilla/proxy/flow_handlers.py b/src/hydrilla/proxy/flow_handlers.py deleted file mode 100644 index 605c7f9..0000000 --- a/src/hydrilla/proxy/flow_handlers.py +++ /dev/null @@ -1,383 +0,0 @@ -# SPDX-License-Identifier: GPL-3.0-or-later - -# Logic for modifying mitmproxy's HTTP flows. -# -# This file is part of Hydrilla&Haketilo. -# -# Copyright (C) 2022 Wojtek Kosior -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <https://www.gnu.org/licenses/>. -# -# -# I, Wojtek Kosior, thereby promise not to sue for violation of this -# file's license. Although I request that you do not make use this code -# in a proprietary program, I am not going to enforce this in court. - -""" -This module's file gets passed to Mitmproxy as addon script and makes it serve -as Haketilo proxy. -""" - -# Enable using with Python 3.7. -from __future__ import annotations - -import re -import typing as t -import dataclasses as dc - -import bs4 # type: ignore - -from mitmproxy import http -from mitmproxy.net.http import Headers -from mitmproxy.script import concurrent - -from .state import HaketiloState -from . import policies - -StateUpdater = t.Callable[[HaketiloState], None] - -@dc.dataclass(frozen=True) -class FlowHandler: - """....""" - flow: http.HTTPFlow - policy: policies.Policy - - stream_request: bool = False - stream_response: bool = False - - def on_requestheaders(self) -> t.Optional[StateUpdater]: - """....""" - if self.stream_request: - self.flow.request.stream = True - - return None - - def on_request(self) -> t.Optional[StateUpdater]: - """....""" - return None - - def on_responseheaders(self) -> t.Optional[StateUpdater]: - """....""" - assert self.flow.response is not None - - if self.stream_response: - self.flow.response.stream = True - - return None - - def on_response(self) -> t.Optional[StateUpdater]: - """....""" - return None - -@dc.dataclass(frozen=True) -class FlowHandlerAllowScripts(FlowHandler): - """....""" - policy: policies.AllowPolicy - - stream_request: bool = True - stream_response: bool = True - -csp_header_names_and_dispositions = ( - ('content-security-policy', 'enforce'), - ('content-security-policy-report-only', 'report'), - ('x-content-security-policy', 'enforce'), - ('x-content-security-policy', 'report'), - ('x-webkit-csp', 'enforce'), - ('x-webkit-csp', 'report') -) - -csp_enforce_header_names_set = { - name for name, disposition in csp_header_names_and_dispositions - if disposition == 'enforce' -} - -@dc.dataclass -class ContentSecurityPolicy: - directives: dict[str, list[str]] - header_name: str - disposition: str - - @staticmethod - def deserialize( - serialized: str, - header_name: str, - disposition: str = 'enforce' - ) -> 'ContentSecurityPolicy': - """....""" - # For more info, see: - # https://www.w3.org/TR/CSP3/#parse-serialized-policy - directives = {} - - for serialized_directive in serialized.split(';'): - if not serialized_directive.isascii(): - continue - - serialized_directive = serialized_directive.strip() - if len(serialized_directive) == 0: - continue - - tokens = serialized_directive.split() - directive_name = tokens.pop(0).lower() - directive_value = tokens - - # Specs mention giving warnings for duplicate directive names but - # from our proxy's perspective this is not important right now. - if directive_name in directives: - continue - - directives[directive_name] = directive_value - - return ContentSecurityPolicy(directives, header_name, disposition) - - def serialize(self) -> str: - """....""" - serialized_directives = [] - for name, value_list in self.directives.items(): - serialized_directives.append(f'{name} {" ".join(value_list)}') - - return ';'.join(serialized_directives) - -def extract_csp(headers: Headers) -> tuple[ContentSecurityPolicy, ...]: - """....""" - csp_policies = [] - - for header_name, disposition in csp_header_names_and_dispositions: - for serialized_list in headers.get(header_name, ''): - for serialized in serialized_list.split(','): - policy = ContentSecurityPolicy.deserialize( - serialized, - header_name, - disposition - ) - - if policy.directives != {}: - csp_policies.append(policy) - - return tuple(csp_policies) - -csp_script_directive_names = ( - 'script-src', - 'script-src-elem', - 'script-src-attr' -) - -@dc.dataclass(frozen=True) -class FlowHandlerBlockScripts(FlowHandler): - policy: policies.BlockPolicy - - stream_request: bool = True - stream_response: bool = True - - def on_responseheaders(self) -> t.Optional[StateUpdater]: - """....""" - super().on_responseheaders() - - assert self.flow.response is not None - - csp_policies = extract_csp(self.flow.response.headers) - - for header_name, _ in csp_header_names_and_dispositions: - del self.flow.response.headers[header_name] - - for policy in csp_policies: - if policy.disposition != 'enforce': - continue - - policy.directives.pop('report-to') - policy.directives.pop('report-uri') - - self.flow.response.headers.add( - policy.header_name, - policy.serialize() - ) - - extra_csp = ';'.join(( - "script-src 'none'", - "script-src-elem 'none'", - "script-src-attr 'none'" - )) - - self.flow.response.headers.add('Content-Security-Policy', extra_csp) - - return None - -# For details of 'Content-Type' header's structure, see: -# https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.1 -content_type_reg = re.compile(r''' -^ -(?P<mime>[\w-]+/[\w-]+) -\s* -(?: - ; - (?:[^;]*;)* # match possible parameter other than "charset" -) -\s* -charset= # no whitespace allowed in parameter as per RFC -(?P<encoding> - [\w-]+ - | - "[\w-]+" # quotes are optional per RFC -) -(?:;[^;]+)* # match possible parameter other than "charset" -$ # forbid possible dangling characters after closing '"' -''', re.VERBOSE | re.IGNORECASE) - -def deduce_content_type(headers: Headers) \ - -> tuple[t.Optional[str], t.Optional[str]]: - """....""" - content_type = headers.get('content-type') - if content_type is None: - return (None, None) - - match = content_type_reg.match(content_type) - if match is None: - return (None, None) - - mime, encoding = match.group('mime'), match.group('encoding') - - if encoding is not None: - encoding = encoding.lower() - - return mime, encoding - -UTF8_BOM = b'\xEF\xBB\xBF' -BOMs = ( - (UTF8_BOM, 'utf-8'), - (b'\xFE\xFF', 'utf-16be'), - (b'\xFF\xFE', 'utf-16le') -) - -def block_attr(element: bs4.PageElement, atrr_name: str) -> None: - """....""" - # TODO: implement - pass - -@dc.dataclass(frozen=True) -class FlowHandlerInjectPayload(FlowHandler): - """....""" - policy: policies.PayloadPolicy - - stream_request: bool = True - - def __post_init__(self) -> None: - """....""" - script_src = f"script-src {self.policy.assets_base_url()}" - if self.policy.is_eval_allowed(): - script_src = f"{script_src} 'unsafe-eval'" - - self.new_csp = '; '.join(( - script_src, - "script-src-elem 'none'", - "script-src-attr 'none'" - )) - - def on_responseheaders(self) -> t.Optional[StateUpdater]: - """....""" - super().on_responseheaders() - - assert self.flow.response is not None - - for header_name, _ in csp_header_names_and_dispositions: - del self.flow.response.headers[header_name] - - self.flow.response.headers.add('Content-Security-Policy', self.new_csp) - - return None - - def on_response(self) -> t.Optional[StateUpdater]: - """....""" - super().on_response() - - assert self.flow.response is not None - - if self.flow.response.content is None: - return None - - mime, encoding = deduce_content_type(self.flow.response.headers) - if mime is None or 'html' not in mime: - return None - - # A UTF BOM overrides encoding specified by the header. - for bom, encoding_name in BOMs: - if self.flow.response.content.startswith(bom): - encoding = encoding_name - - soup = bs4.BeautifulSoup( - markup = self.flow.response.content, - from_encoding = encoding, - features = 'html5lib' - ) - - # Inject scripts. - script_parent = soup.find('body') or soup.find('html') - if script_parent is None: - return None - - for url in self.policy.script_urls(): - script_parent.append(bs4.Tag(name='script', attrs={'src': url})) - - # Remove Content Security Policy that could possibly block injected - # scripts. - for meta in soup.select('head meta[http-equiv]'): - header_name = meta.attrs.get('http-equiv', '').lower().strip() - if header_name in csp_enforce_header_names_set: - block_attr(meta, 'http-equiv') - block_attr(meta, 'content') - - # Appending a three-byte Byte Order Mark (BOM) will force the browser to - # decode this as UTF-8 regardless of the 'Content-Type' header. See: - # https://www.w3.org/International/tests/repository/html5/the-input-byte-stream/results-basics#precedence - self.flow.response.content = UTF8_BOM + soup.encode() - - return None - -@dc.dataclass(frozen=True) -class FlowHandlerMetaResource(FlowHandler): - """....""" - policy: policies.MetaResourcePolicy - - def on_request(self) -> t.Optional[StateUpdater]: - """....""" - super().on_request() - # TODO: implement - #self.flow.response = .... - - return None - -def make_flow_handler(flow: http.HTTPFlow, policy: policies.Policy) \ - -> FlowHandler: - """....""" - if isinstance(policy, policies.BlockPolicy): - return FlowHandlerBlockScripts(flow, policy) - - if isinstance(policy, policies.AllowPolicy): - return FlowHandlerAllowScripts(flow, policy) - - if isinstance(policy, policies.PayloadPolicy): - return FlowHandlerInjectPayload(flow, policy) - - assert isinstance(policy, policies.MetaResourcePolicy) - # def response_creator(request: http.HTTPRequest) -> http.HTTPResponse: - # """....""" - # replacement_details = make_replacement_resource( - # policy.replacement, - # request.path - # ) - - # return http.HTTPResponse.make( - # replacement_details.status_code, - # replacement_details.content, - # replacement_details.content_type - # ) - return FlowHandlerMetaResource(flow, policy) diff --git a/src/hydrilla/proxy/oversimplified_state_impl.py b/src/hydrilla/proxy/oversimplified_state_impl.py new file mode 100644 index 0000000..c082add --- /dev/null +++ b/src/hydrilla/proxy/oversimplified_state_impl.py @@ -0,0 +1,392 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Haketilo proxy data and configuration (instantiatable HaketiloState subtype). +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +This module contains logic for keeping track of all settings, rules, mappings +and resources. +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import secrets +import threading +import typing as t +import dataclasses as dc + +from pathlib import Path + +from ..pattern_tree import PatternTree +from .. import url_patterns +from .. import versions +from .. import item_infos +from .simple_dependency_satisfying import ItemsCollection, ComputedPayload +from . import state as st +from . import policies + + +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class ConcreteRepoRef(st.RepoRef): + def remove(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + def update( + self, + state: st.HaketiloState, + *, + name: t.Optional[str] = None, + url: t.Optional[str] = None + ) -> ConcreteRepoRef: + raise NotImplementedError() + + def refresh(self, state: st.HaketiloState) -> ConcreteRepoIterationRef: + raise NotImplementedError() + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteRepoIterationRef(st.RepoIterationRef): + pass + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteMappingRef(st.MappingRef): + def disable(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + def forget_enabled(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteMappingVersionRef(st.MappingVersionRef): + def enable(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteResourceRef(st.ResourceRef): + pass + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteResourceVersionRef(st.ResourceVersionRef): + pass + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcretePayloadRef(st.PayloadRef): + computed_payload: ComputedPayload = dc.field(hash=False, compare=False) + + def get_data(self, state: st.HaketiloState) -> st.PayloadData: + return t.cast(ConcreteHaketiloState, state).payloads_data[self.id] + + def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: + return 'to implement' + + def get_script_paths(self, state: st.HaketiloState) \ + -> t.Iterator[t.Sequence[str]]: + for resource_info in self.computed_payload.resources: + for file_spec in resource_info.scripts: + yield (resource_info.identifier, *file_spec.name.split('/')) + + def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ + -> t.Optional[st.FileData]: + if len(path) == 0: + raise st.MissingItemError() + + resource_identifier, *file_name_segments = path + + file_name = '/'.join(file_name_segments) + + script_sha256 = '' + + matched_resource_info = False + + for resource_info in self.computed_payload.resources: + if resource_info.identifier == resource_identifier: + matched_resource_info = True + + for script_spec in resource_info.scripts: + if script_spec.name == file_name: + script_sha256 = script_spec.sha256 + + break + + if not matched_resource_info: + raise st.MissingItemError(resource_identifier) + + if script_sha256 == '': + return None + + store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir + files_dir_path = store_dir_path / 'temporary_malcontent' / 'file' + file_path = files_dir_path / 'sha256' / script_sha256 + + return st.FileData( + type = 'application/javascript', + name = file_name, + contents = file_path.read_bytes() + ) + + +# @dc.dataclass(frozen=True, unsafe_hash=True) +# class DummyPayloadRef(ConcretePayloadRef): +# paths = { +# ('someresource', 'somefolder', 'somescript.js'): st.FileData( +# type = 'application/javascript', +# name = 'somefolder/somescript.js', +# contents = b'console.log("hello, mitmproxy")' +# ) +# } + +# def get_data(self, state: st.HaketiloState) -> st.PayloadData: +# parsed_pattern = next(url_patterns.parse_pattern('https://example.com')) + +# return st.PayloadData( +# payload_ref = self, +# mapping_installed = True, +# explicitly_enabled = True, +# unique_token = 'g54v45g456h4r', +# pattern = parsed_pattern, +# eval_allowed = True, +# cors_bypass_allowed = True +# ) + +# def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: +# return ConcreteMappingVersionRef('somemapping') + +# def get_file_paths(self, state: st.HaketiloState) \ +# -> t.Iterable[t.Sequence[str]]: +# return tuple(self.paths.keys()) + +# def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ +# -> t.Optional[st.FileData]: +# return self.paths[tuple(path)] + + +PolicyTree = PatternTree[policies.PolicyFactory] + +def register_payload( + policy_tree: PolicyTree, + payload_key: st.PayloadKey, + token: str +) -> PolicyTree: + """....""" + payload_policy_factory = policies.PayloadPolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register( + payload_key.pattern, + payload_policy_factory + ) + + resource_policy_factory = policies.PayloadResourcePolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register( + payload_key.pattern.path_append(token, '***'), + resource_policy_factory + ) + + return policy_tree + +DataById = t.Mapping[str, st.PayloadData] + +AnyInfo = t.TypeVar('AnyInfo', item_infos.ResourceInfo, item_infos.MappingInfo) + +@dc.dataclass +class ConcreteHaketiloState(st.HaketiloState): + store_dir: Path + # settings: state.HaketiloGlobalSettings + policy_tree: PolicyTree = PatternTree() + payloads_data: DataById = dc.field(default_factory=dict) + + lock: threading.RLock = dc.field(default_factory=threading.RLock) + + def __post_init__(self) -> None: + def newest_item_path(item_dir: Path) -> t.Optional[Path]: + available_versions = tuple( + versions.parse_normalize_version(ver_path.name) + for ver_path in item_dir.iterdir() + if ver_path.is_file() + ) + + if available_versions == (): + return None + + newest_version = max(available_versions) + + version_path = item_dir / versions.version_string(newest_version) + + assert version_path.is_file() + + return version_path + + def read_items(dir_path: Path, item_class: t.Type[AnyInfo]) \ + -> t.Mapping[str, AnyInfo]: + items: dict[str, AnyInfo] = {} + + for resource_dir in dir_path.iterdir(): + if not resource_dir.is_dir(): + continue + + item_path = newest_item_path(resource_dir) + if item_path is None: + continue + + item = item_class.load(item_path) + + assert versions.version_string(item.version) == item_path.name + assert item.identifier == resource_dir.name + + items[item.identifier] = item + + return items + + malcontent_dir = self.store_dir / 'temporary_malcontent' + + items_collection = ItemsCollection( + read_items(malcontent_dir / 'resource', item_infos.ResourceInfo), + read_items(malcontent_dir / 'mapping', item_infos.MappingInfo) + ) + computed_payloads = items_collection.compute_payloads() + + payloads_data = {} + + for mapping_info, by_pattern in computed_payloads.items(): + for num, (pattern, payload) in enumerate(by_pattern.items()): + payload_id = f'{num}@{mapping_info.identifier}' + + ref = ConcretePayloadRef(payload_id, payload) + + data = st.PayloadData( + payload_ref = ref, + mapping_installed = True, + explicitly_enabled = True, + unique_token = secrets.token_urlsafe(16), + pattern = pattern, + eval_allowed = payload.allows_eval, + cors_bypass_allowed = payload.allows_cors_bypass + ) + + payloads_data[payload_id] = data + + key = st.PayloadKey( + payload_ref = ref, + mapping_identifier = mapping_info.identifier, + pattern = pattern + ) + + self.policy_tree = register_payload( + self.policy_tree, + key, + data.unique_token + ) + + self.payloads_data = payloads_data + + def get_repo(self, repo_id: str) -> st.RepoRef: + return ConcreteRepoRef(repo_id) + + def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef: + return ConcreteRepoIterationRef(repo_iteration_id) + + def get_mapping(self, mapping_id: str) -> st.MappingRef: + return ConcreteMappingRef(mapping_id) + + def get_mapping_version(self, mapping_version_id: str) \ + -> st.MappingVersionRef: + return ConcreteMappingVersionRef(mapping_version_id) + + def get_resource(self, resource_id: str) -> st.ResourceRef: + return ConcreteResourceRef(resource_id) + + def get_resource_version(self, resource_version_id: str) \ + -> st.ResourceVersionRef: + return ConcreteResourceVersionRef(resource_version_id) + + def get_payload(self, payload_id: str) -> st.PayloadRef: + return 'not implemented' + + def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \ + -> st.RepoRef: + raise NotImplementedError() + + def get_settings(self) -> st.HaketiloGlobalSettings: + return st.HaketiloGlobalSettings( + mapping_use_mode = st.MappingUseMode.AUTO, + default_allow_scripts = True, + repo_refresh_seconds = 0 + ) + + def update_settings( + self, + *, + mapping_use_mode: t.Optional[st.MappingUseMode] = None, + default_allow_scripts: t.Optional[bool] = None, + repo_refresh_seconds: t.Optional[int] = None + ) -> None: + raise NotImplementedError() + + def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy: + """....""" + with self.lock: + policy_tree = self.policy_tree + + try: + best_priority: int = 0 + best_policy: t.Optional[policies.Policy] = None + + for factories_set in policy_tree.search(url): + for stored_factory in sorted(factories_set): + factory = stored_factory.item + + policy = factory.make_policy(self) + + if policy.priority > best_priority: + best_priority = policy.priority + best_policy = policy + except Exception as e: + return policies.ErrorBlockPolicy( + builtin = True, + error = e + ) + + if best_policy is not None: + return best_policy + + if self.get_settings().default_allow_scripts: + return policies.FallbackAllowPolicy() + else: + return policies.FallbackBlockPolicy() + + @staticmethod + def make(store_dir: Path) -> 'ConcreteHaketiloState': + return ConcreteHaketiloState(store_dir=store_dir) diff --git a/src/hydrilla/proxy/policies/__init__.py b/src/hydrilla/proxy/policies/__init__.py new file mode 100644 index 0000000..66e07ee --- /dev/null +++ b/src/hydrilla/proxy/policies/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: CC0-1.0 + +# Copyright (C) 2022 Wojtek Kosior <koszko@koszko.org> +# +# Available under the terms of Creative Commons Zero v1.0 Universal. + +from .base import * + +from .payload import PayloadPolicyFactory + +from .payload_resource import PayloadResourcePolicyFactory + +from .rule import RuleBlockPolicyFactory, RuleAllowPolicyFactory + +from .fallback import FallbackAllowPolicy, FallbackBlockPolicy, ErrorBlockPolicy diff --git a/src/hydrilla/proxy/policies/base.py b/src/hydrilla/proxy/policies/base.py new file mode 100644 index 0000000..3bde6f2 --- /dev/null +++ b/src/hydrilla/proxy/policies/base.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Base defintions for policies for altering HTTP requests. +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +..... +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import sys + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + +import dataclasses as dc +import typing as t +import enum + +from abc import ABC, abstractmethod + +from immutables import Map + +from ...url_patterns import ParsedUrl +from .. import state + + +class PolicyPriority(int, enum.Enum): + """....""" + _ONE = 1 + _TWO = 2 + _THREE = 3 + +DefaultGetValue = t.TypeVar('DefaultGetValue', object, None) + +class IHeaders(Protocol): + """....""" + def __getitem__(self, key: str) -> str: ... + + def get_all(self, key: str) -> t.Iterable[str]: ... + + def get(self, key: str, default: DefaultGetValue = None) \ + -> t.Union[str, DefaultGetValue]: ... + + def items(self) -> t.Iterable[tuple[str, str]]: ... + +def encode_headers_items(headers: t.Iterable[tuple[str, str]]) \ + -> t.Iterable[tuple[bytes, bytes]]: + """....""" + for name, value in headers: + yield name.encode(), value.encode() + +@dc.dataclass(frozen=True) +class ProducedRequest: + """....""" + url: str + method: str + headers: t.Iterable[tuple[bytes, bytes]] + body: bytes + +@dc.dataclass(frozen=True) +class RequestInfo: + """....""" + url: ParsedUrl + method: str + headers: IHeaders + body: bytes + + def make_produced_request(self) -> ProducedRequest: + """....""" + return ProducedRequest( + url = self.url.orig_url, + method = self.method, + headers = encode_headers_items(self.headers.items()), + body = self.body + ) + +@dc.dataclass(frozen=True) +class ProducedResponse: + """....""" + status_code: int + headers: t.Iterable[tuple[bytes, bytes]] + body: bytes + +@dc.dataclass(frozen=True) +class ResponseInfo: + """....""" + url: ParsedUrl + status_code: int + headers: IHeaders + body: bytes + + def make_produced_response(self) -> ProducedResponse: + """....""" + return ProducedResponse( + status_code = self.status_code, + headers = encode_headers_items(self.headers.items()), + body = self.body + ) + +class Policy(ABC): + """....""" + process_request: t.ClassVar[bool] = False + process_response: t.ClassVar[bool] = False + + priority: t.ClassVar[PolicyPriority] + + @property + def anticache(self) -> bool: + return self.process_request or self.process_response + + def consume_request(self, request_info: RequestInfo) \ + -> t.Optional[t.Union[ProducedRequest, ProducedResponse]]: + """....""" + return None + + def consume_response(self, response_info: ResponseInfo) \ + -> t.Optional[ProducedResponse]: + """....""" + return None + + +# mypy needs to be corrected: +# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class PolicyFactory(ABC): + """....""" + builtin: bool + + @abstractmethod + def make_policy(self, haketilo_state: state.HaketiloState) \ + -> t.Optional[Policy]: + """....""" + ... + + def __lt__(self, other: 'PolicyFactory'): + """....""" + return sorting_keys.get(self.__class__.__name__, 999) < \ + sorting_keys.get(other.__class__.__name__, 999) + +sorting_order = ( + 'PayloadResourcePolicyFactory', + + 'PayloadPolicyFactory', + + 'RuleBlockPolicyFactory', + 'RuleAllowPolicyFactory', + + 'FallbackPolicyFactory' +) + +sorting_keys = Map((cls, name) for name, cls in enumerate(sorting_order)) diff --git a/src/hydrilla/proxy/policies.py b/src/hydrilla/proxy/policies/fallback.py index 5e9451b..75da61c 100644 --- a/src/hydrilla/proxy/policies.py +++ b/src/hydrilla/proxy/policies/fallback.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later -# Various policies for altering HTTP requests. +# Policies for blocking and allowing JS when no other policies match. # # This file is part of Hydrilla&Haketilo. # @@ -24,53 +24,37 @@ # file's license. Although I request that you do not make use this code # in a proprietary program, I am not going to enforce this in court. -import dataclasses as dc -import typing as t +""" +..... +""" -from abc import ABC +# Enable using with Python 3.7. +from __future__ import annotations -class Policy(ABC): - pass - -class PayloadPolicy(Policy): - """....""" - def assets_base_url(self) -> str: - """....""" - return 'https://example.com/static/' - - def script_urls(self) -> t.Sequence[str]: - """....""" - # TODO: implement - return ('https://example.com/static/somescript.js',) - - def is_eval_allowed(self) -> bool: - """....""" - # TODO: implement - return True +import dataclasses as dc +import typing as t +import enum -class MetaResourcePolicy(Policy): - pass +from abc import ABC, abstractmethod -class AllowPolicy(Policy): - pass +from .. import state +from . import base +from .rule import AllowPolicy, BlockPolicy -@dc.dataclass -class RuleAllowPolicy(AllowPolicy): - pattern: str class FallbackAllowPolicy(AllowPolicy): - pass - -class BlockPolicy(Policy): - pass + """.....""" + priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE -@dc.dataclass -class RuleBlockPolicy(BlockPolicy): - pattern: str class FallbackBlockPolicy(BlockPolicy): - pass + """....""" + priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE -@dc.dataclass + +@dc.dataclass(frozen=True) class ErrorBlockPolicy(BlockPolicy): + """....""" error: Exception + + builtin: bool = True diff --git a/src/hydrilla/proxy/policies/payload.py b/src/hydrilla/proxy/policies/payload.py new file mode 100644 index 0000000..d616f1b --- /dev/null +++ b/src/hydrilla/proxy/policies/payload.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Policies for applying payload injections to HTTP requests. +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +..... +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import dataclasses as dc +import typing as t +import re + +import bs4 # type: ignore + +from ...url_patterns import ParsedUrl +from .. import state +from .. import csp +from . import base + +@dc.dataclass(frozen=True) # type: ignore[misc] +class PayloadAwarePolicy(base.Policy): + """....""" + haketilo_state: state.HaketiloState + payload_data: state.PayloadData + + def assets_base_url(self, request_url: ParsedUrl): + """....""" + token = self.payload_data.unique_token + + base_path_segments = (*self.payload_data.pattern.path_segments, token) + + return f'{request_url.url_without_path}/{"/".join(base_path_segments)}/' + + +@dc.dataclass(frozen=True) # type: ignore[misc] +class PayloadAwarePolicyFactory(base.PolicyFactory): + """....""" + payload_key: state.PayloadKey + + @property + def payload_ref(self) -> state.PayloadRef: + """....""" + return self.payload_key.payload_ref + + def __lt__(self, other: base.PolicyFactory) -> bool: + """....""" + if isinstance(other, type(self)): + return self.payload_key < other.payload_key + + return super().__lt__(other) + + +# For details of 'Content-Type' header's structure, see: +# https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.1 +content_type_reg = re.compile(r''' +^ +(?P<mime>[\w-]+/[\w-]+) +\s* +(?: + ; + (?:[^;]*;)* # match possible parameter other than "charset" +) +\s* +charset= # no whitespace allowed in parameter as per RFC +(?P<encoding> + [\w-]+ + | + "[\w-]+" # quotes are optional per RFC +) +(?:;[^;]+)* # match possible parameter other than "charset" +$ # forbid possible dangling characters after closing '"' +''', re.VERBOSE | re.IGNORECASE) + +def deduce_content_type(headers: base.IHeaders) \ + -> tuple[t.Optional[str], t.Optional[str]]: + """....""" + content_type = headers.get('content-type') + if content_type is None: + return (None, None) + + match = content_type_reg.match(content_type) + if match is None: + return (None, None) + + mime, encoding = match.group('mime'), match.group('encoding') + + if encoding is not None: + encoding = encoding.lower() + + return mime, encoding + +UTF8_BOM = b'\xEF\xBB\xBF' +BOMs = ( + (UTF8_BOM, 'utf-8'), + (b'\xFE\xFF', 'utf-16be'), + (b'\xFF\xFE', 'utf-16le') +) + +def block_attr(element: bs4.PageElement, attr_name: str) -> None: + """ + Disable HTML node attributes by prepending `blocked-'. This allows them to + still be relatively easily accessed in case they contain some useful data. + """ + blocked_value = element.attrs.pop(attr_name, None) + + while blocked_value is not None: + attr_name = f'blocked-{attr_name}' + next_blocked_value = element.attrs.pop(attr_name, None) + element.attrs[attr_name] = blocked_value + + blocked_value = next_blocked_value + +@dc.dataclass(frozen=True) +class PayloadInjectPolicy(PayloadAwarePolicy): + """....""" + process_response: t.ClassVar[bool] = True + + priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._TWO + + def _new_csp(self, request_url: ParsedUrl) -> str: + """....""" + assets_base = self.assets_base_url(request_url) + + script_src = f"script-src {assets_base}" + + if self.payload_data.eval_allowed: + script_src = f"{script_src} 'unsafe-eval'" + + return '; '.join(( + script_src, + "script-src-elem 'none'", + "script-src-attr 'none'" + )) + + def _modify_headers(self, response_info: base.ResponseInfo) \ + -> t.Iterable[tuple[bytes, bytes]]: + """....""" + for header_name, header_value in response_info.headers.items(): + if header_name.lower() not in csp.header_names_and_dispositions: + yield header_name.encode(), header_value.encode() + + new_csp = self._new_csp(response_info.url) + + yield b'Content-Security-Policy', new_csp.encode() + + def _script_urls(self, url: ParsedUrl) -> t.Iterable[str]: + """....""" + base_url = self.assets_base_url(url) + payload_ref = self.payload_data.payload_ref + + for path in payload_ref.get_script_paths(self.haketilo_state): + yield base_url + '/'.join(('static', *path)) + + def _modify_body( + self, + url: ParsedUrl, + body: bytes, + encoding: t.Optional[str] + ) -> bytes: + """....""" + soup = bs4.BeautifulSoup( + markup = body, + from_encoding = encoding, + features = 'html5lib' + ) + + # Inject scripts. + script_parent = soup.find('body') or soup.find('html') + if script_parent is None: + return body + + for script_url in self._script_urls(url): + tag = bs4.Tag(name='script', attrs={'src': script_url}) + script_parent.append(tag) + + # Remove Content Security Policy that could possibly block injected + # scripts. + for meta in soup.select('head meta[http-equiv]'): + header_name = meta.attrs.get('http-equiv', '').lower().strip() + if header_name in csp.enforce_header_names_set: + block_attr(meta, 'http-equiv') + block_attr(meta, 'content') + + # Appending a three-byte Byte Order Mark (BOM) will force the browser to + # decode this as UTF-8 regardless of the 'Content-Type' header. See: + # https://www.w3.org/International/tests/repository/html5/the-input-byte-stream/results-basics#precedence + return UTF8_BOM + soup.encode() + + def _consume_response_unsafe(self, response_info: base.ResponseInfo) \ + -> base.ProducedResponse: + """....""" + new_response = response_info.make_produced_response() + + new_headers = self._modify_headers(response_info) + + new_response = dc.replace(new_response, headers=new_headers) + + mime, encoding = deduce_content_type(response_info.headers) + if mime is None or 'html' not in mime.lower(): + return new_response + + data = response_info.body + if data is None: + data = b'' + + # A UTF BOM overrides encoding specified by the header. + for bom, encoding_name in BOMs: + if data.startswith(bom): + encoding = encoding_name + + new_data = self._modify_body(response_info.url, data, encoding) + + return dc.replace(new_response, body=new_data) + + def consume_response(self, response_info: base.ResponseInfo) \ + -> base.ProducedResponse: + """....""" + try: + return self._consume_response_unsafe(response_info) + except Exception as e: + # TODO: actually describe the errors + import traceback + + error_info_list = traceback.format_exception( + type(e), + e, + e.__traceback__ + ) + + return base.ProducedResponse( + 500, + ((b'Content-Type', b'text/plain; charset=utf-8'),), + '\n'.join(error_info_list).encode() + ) + + +class AutoPayloadInjectPolicy(PayloadInjectPolicy): + """....""" + priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE + + def _modify_body( + self, + url: ParsedUrl, + body: bytes, + encoding: t.Optional[str] + ) -> bytes: + """....""" + payload_ref = self.payload_data.payload_ref + mapping_ref = payload_ref.get_mapping(self.haketilo_state) + mapping_ref.enable(self.haketilo_state) + + return super()._modify_body(url, body, encoding) + + +@dc.dataclass(frozen=True) +class PayloadSuggestPolicy(PayloadAwarePolicy): + """....""" + priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE + + def make_response(self, request_info: base.RequestInfo) \ + -> base.ProducedResponse: + """....""" + # TODO: implement + return base.ProducedResponse(200, ((b'a', b'b'),), b'') + + +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class PayloadPolicyFactory(PayloadAwarePolicyFactory): + """....""" + def make_policy(self, haketilo_state: state.HaketiloState) \ + -> t.Optional[base.Policy]: + """....""" + try: + payload_data = self.payload_ref.get_data(haketilo_state) + except: + return None + + if payload_data.explicitly_enabled: + return PayloadInjectPolicy(haketilo_state, payload_data) + + mode = haketilo_state.get_settings().mapping_use_mode + + if mode == state.MappingUseMode.QUESTION: + return PayloadSuggestPolicy(haketilo_state, payload_data) + + if mode == state.MappingUseMode.WHEN_ENABLED: + return None + + # mode == state.MappingUseMode.AUTO + return AutoPayloadInjectPolicy(haketilo_state, payload_data) diff --git a/src/hydrilla/proxy/policies/payload_resource.py b/src/hydrilla/proxy/policies/payload_resource.py new file mode 100644 index 0000000..84d0919 --- /dev/null +++ b/src/hydrilla/proxy/policies/payload_resource.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Policies for resolving HTTP requests with local resources. +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +..... + +We make file resources available to HTTP clients by mapping them +at: + http(s)://<pattern-matching_origin>/<pattern_path>/<token>/ +where <token> is a per-session secret unique for every mapping. +For example, a payload with pattern like the following: + http*://***.example.com/a/b/** +Could cause resources to be mapped (among others) at each of: + https://example.com/a/b/**/Da2uiF2UGfg/ + https://www.example.com/a/b/**/Da2uiF2UGfg/ + http://gnome.vs.kde.example.com/a/b/**/Da2uiF2UGfg/ + +Unauthorized web pages running in the user's browser are exected to be +unable to guess the secret. This way we stop them from spying on the +user and from interfering with Haketilo's normal operation. + +This is only a soft prevention method. With some mechanisms +(e.g. service workers), under certain scenarios, it might be possible +to bypass it. Thus, to make the risk slightly smaller, we also block +the unauthorized accesses that we can detect. + +Since a web page authorized to access the resources may only be served +when the corresponding mapping is enabled (or AUTO mode is on), we +consider accesses to non-enabled mappings' resources a security breach +and block them by responding with 403 Not Found. +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import dataclasses as dc +import typing as t + +from ...translations import smart_gettext as _ +from .. import state +from . import base +from .payload import PayloadAwarePolicy, PayloadAwarePolicyFactory + + +@dc.dataclass(frozen=True) +class PayloadResourcePolicy(PayloadAwarePolicy): + """....""" + process_request: t.ClassVar[bool] = True + + priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._THREE + + def _make_file_resource_response(self, path: tuple[str, ...]) \ + -> base.ProducedResponse: + """....""" + try: + file_data = self.payload_data.payload_ref.get_file_data( + self.haketilo_state, + path + ) + except state.MissingItemError: + return resource_blocked_response + + if file_data is None: + return base.ProducedResponse( + 404, + [(b'Content-Type', b'text/plain; charset=utf-8')], + _('api.file_not_found').encode() + ) + + return base.ProducedResponse( + 200, + ((b'Content-Type', file_data.type.encode()),), + file_data.contents + ) + + def consume_request(self, request_info: base.RequestInfo) \ + -> base.ProducedResponse: + """....""" + # Payload resource pattern has path of the form: + # "/some/arbitrary/segments/<per-session_token>/***" + # + # Corresponding requests shall have path of the form: + # "/some/arbitrary/segments/<per-session_token>/actual/resource/path" + # + # Here we need to extract the "/actual/resource/path" part. + segments_to_drop = len(self.payload_data.pattern.path_segments) + 1 + resource_path = request_info.url.path_segments[segments_to_drop:] + + if resource_path == (): + return resource_blocked_response + elif resource_path[0] == 'static': + return self._make_file_resource_response(resource_path[1:]) + elif resource_path[0] == 'api': + # TODO: implement Haketilo APIs + return resource_blocked_response + else: + return resource_blocked_response + + +resource_blocked_response = base.ProducedResponse( + 403, + [(b'Content-Type', b'text/plain; charset=utf-8')], + _('api.resource_not_enabled_for_access').encode() +) + +@dc.dataclass(frozen=True) +class BlockedResponsePolicy(base.Policy): + """....""" + process_request: t.ClassVar[bool] = True + + priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._THREE + + def consume_request(self, request_info: base.RequestInfo) \ + -> base.ProducedResponse: + """....""" + return resource_blocked_response + + +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class PayloadResourcePolicyFactory(PayloadAwarePolicyFactory): + """....""" + def make_policy(self, haketilo_state: state.HaketiloState) \ + -> t.Union[PayloadResourcePolicy, BlockedResponsePolicy]: + """....""" + try: + payload_data = self.payload_ref.get_data(haketilo_state) + except state.MissingItemError: + return BlockedResponsePolicy() + + if not payload_data.explicitly_enabled and \ + haketilo_state.get_settings().mapping_use_mode != \ + state.MappingUseMode.AUTO: + return BlockedResponsePolicy() + + return PayloadResourcePolicy(haketilo_state, payload_data) + + diff --git a/src/hydrilla/proxy/policies/rule.py b/src/hydrilla/proxy/policies/rule.py new file mode 100644 index 0000000..eb70147 --- /dev/null +++ b/src/hydrilla/proxy/policies/rule.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Policies for blocking and allowing JS in pages fetched with HTTP. +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +..... +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import dataclasses as dc +import typing as t + +from ...url_patterns import ParsedPattern +from .. import csp +from .. import state +from . import base + + +class AllowPolicy(base.Policy): + """....""" + priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._TWO + +class BlockPolicy(base.Policy): + """....""" + process_response: t.ClassVar[bool] = True + + priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._TWO + + def _modify_headers(self, response_info: base.ResponseInfo) \ + -> t.Iterable[tuple[bytes, bytes]]: + """....""" + csp_policies = csp.extract(response_info.headers) + + for header_name, header_value in response_info.headers.items(): + if header_name.lower() not in csp.header_names_and_dispositions: + yield header_name.encode(), header_value.encode() + + for policy in csp_policies: + if policy.disposition != 'enforce': + continue + + directives = policy.directives.mutate() + directives.pop('report-to', None) + directives.pop('report-uri', None) + + policy = dc.replace(policy, directives=directives.finish()) + + yield policy.header_name.encode(), policy.serialize().encode() + + extra_csp = ';'.join(( + "script-src 'none'", + "script-src-elem 'none'", + "script-src-attr 'none'" + )) + + yield b'Content-Security-Policy', extra_csp.encode() + + + def consume_response(self, response_info: base.ResponseInfo) \ + -> base.ProducedResponse: + """....""" + new_response = response_info.make_produced_response() + + new_headers = self._modify_headers(response_info) + + return dc.replace(new_response, headers=new_headers) + +@dc.dataclass(frozen=True) +class RuleAllowPolicy(AllowPolicy): + """....""" + pattern: ParsedPattern + + +@dc.dataclass(frozen=True) +class RuleBlockPolicy(BlockPolicy): + """....""" + pattern: ParsedPattern + + +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class RulePolicyFactory(base.PolicyFactory): + """....""" + pattern: ParsedPattern + + def __lt__(self, other: base.PolicyFactory) -> bool: + """....""" + if type(other) is type(self): + return super().__lt__(other) + + assert isinstance(other, RulePolicyFactory) + + return self.pattern < other.pattern + + +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class RuleBlockPolicyFactory(RulePolicyFactory): + """....""" + def make_policy(self, haketilo_state: state.HaketiloState) \ + -> RuleBlockPolicy: + """....""" + return RuleBlockPolicy(self.pattern) + + +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class RuleAllowPolicyFactory(RulePolicyFactory): + """....""" + def make_policy(self, haketilo_state: state.HaketiloState) \ + -> RuleAllowPolicy: + """....""" + return RuleAllowPolicy(self.pattern) diff --git a/src/hydrilla/proxy/simple_dependency_satisfying.py b/src/hydrilla/proxy/simple_dependency_satisfying.py new file mode 100644 index 0000000..9716fe5 --- /dev/null +++ b/src/hydrilla/proxy/simple_dependency_satisfying.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Haketilo proxy payloads dependency resolution. +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +..... +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import dataclasses as dc +import typing as t + +from .. import item_infos +from .. import url_patterns + +@dc.dataclass +class ComputedPayload: + resources: list[item_infos.ResourceInfo] = dc.field(default_factory=list) + + allows_eval: bool = False + allows_cors_bypass: bool = False + +SingleMappingPayloads = t.Mapping[ + url_patterns.ParsedPattern, + ComputedPayload +] + +ComputedPayloadsDict = dict[ + item_infos.MappingInfo, + SingleMappingPayloads +] + +empty_identifiers_set: set[str] = set() + +@dc.dataclass(frozen=True) +class ItemsCollection: + resources: t.Mapping[str, item_infos.ResourceInfo] + mappings: t.Mapping[str, item_infos.MappingInfo] + + def _satisfy_payload_resource_rec( + self, + resource_identifier: str, + processed_resources: set[str], + computed_payload: ComputedPayload + ) -> t.Optional[ComputedPayload]: + if resource_identifier in processed_resources: + # We forbid circular dependencies. + return None + + resource_info = self.resources.get(resource_identifier) + if resource_info is None: + return None + + if resource_info in computed_payload.resources: + return computed_payload + + processed_resources.add(resource_identifier) + + if resource_info.allows_eval: + computed_payload.allows_eval = True + + if resource_info.allows_cors_bypass: + computed_payload.allows_cors_bypass = True + + for dependency_spec in resource_info.dependencies: + if self._satisfy_payload_resource_rec( + dependency_spec.identifier, + processed_resources, + computed_payload + ) is None: + return None + + processed_resources.remove(resource_identifier) + + computed_payload.resources.append(resource_info) + + return computed_payload + + def _satisfy_payload_resource(self, resource_identifier: str) \ + -> t.Optional[ComputedPayload]: + return self._satisfy_payload_resource_rec( + resource_identifier, + set(), + ComputedPayload() + ) + + def _compute_payloads_no_mapping_requirements(self) -> ComputedPayloadsDict: + computed_result: ComputedPayloadsDict = ComputedPayloadsDict() + + for mapping_info in self.mappings.values(): + by_pattern: dict[url_patterns.ParsedPattern, ComputedPayload] = {} + + failure = False + + for pattern, resource_spec in mapping_info.payloads.items(): + computed_payload = self._satisfy_payload_resource( + resource_spec.identifier + ) + if computed_payload is None: + failure = True + break + + if mapping_info.allows_eval: + computed_payload.allows_eval = True + + if mapping_info.allows_cors_bypass: + computed_payload.allows_cors_bypass = True + + by_pattern[pattern] = computed_payload + + if not failure: + computed_result[mapping_info] = by_pattern + + return computed_result + + def _mark_mappings_bad( + self, + identifier: str, + reverse_mapping_deps: t.Mapping[str, set[str]], + bad_mappings: set[str] + ) -> None: + if identifier in bad_mappings: + return + + bad_mappings.add(identifier) + + for requiring in reverse_mapping_deps.get(identifier, ()): + self._mark_mappings_bad( + requiring, + reverse_mapping_deps, + bad_mappings + ) + + def compute_payloads(self) -> ComputedPayloadsDict: + computed_result = self._compute_payloads_no_mapping_requirements() + + reverse_mapping_deps: dict[str, set[str]] = {} + + for mapping_info, by_pattern in computed_result.items(): + specs_to_resolve = [*mapping_info.required_mappings] + + for computed_payload in by_pattern.values(): + for resource_info in computed_payload.resources: + specs_to_resolve.extend(resource_info.required_mappings) + + for required_mapping_spec in specs_to_resolve: + identifier = required_mapping_spec.identifier + requiring = reverse_mapping_deps.setdefault(identifier, set()) + requiring.add(mapping_info.identifier) + + bad_mappings: set[str] = set() + + for required_identifier in reverse_mapping_deps.keys(): + if self.mappings.get(required_identifier) not in computed_result: + self._mark_mappings_bad( + required_identifier, + reverse_mapping_deps, + bad_mappings + ) + + for identifier in bad_mappings: + if identifier in self.mappings: + computed_result.pop(self.mappings[identifier], None) + + return computed_result diff --git a/src/hydrilla/proxy/state.py b/src/hydrilla/proxy/state.py index fc01536..e22c9fe 100644 --- a/src/hydrilla/proxy/state.py +++ b/src/hydrilla/proxy/state.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later -# Haketilo proxy data and configuration. +# Haketilo proxy data and configuration (interface definition through abstract +# class). # # This file is part of Hydrilla&Haketilo. # @@ -25,49 +26,271 @@ # in a proprietary program, I am not going to enforce this in court. """ -This module contains logic for keeping track of all settings, rules, mappings -and resources. +This module defines API for keeping track of all settings, rules, mappings and +resources. """ # Enable using with Python 3.7. from __future__ import annotations -import typing as t import dataclasses as dc +import typing as t -from threading import Lock from pathlib import Path +from abc import ABC, abstractmethod +from enum import Enum + +from immutables import Map + +from ..versions import VerTuple +from ..url_patterns import ParsedPattern + + +class EnabledStatus(Enum): + """ + ENABLED - User wished to always apply given mapping when it matched. + + DISABLED - User wished to never apply given mapping. + + AUTO_ENABLED - User has not configured given mapping but it will still be + used when automatic application of mappings is turned on. + + NO_MARK - User has not configured given mapping and it won't be used. + """ + ENABLED = 'E' + DISABLED = 'D' + AUTO_ENABLED = 'A' + NO_MARK = 'N' + +@dc.dataclass(frozen=True, unsafe_hash=True) +class Ref: + """....""" + id: str + + +# mypy needs to be corrected: +# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class RepoRef(Ref): + """....""" + @abstractmethod + def remove(self, state: 'HaketiloState') -> None: + """....""" + ... + + @abstractmethod + def update( + self, + state: 'HaketiloState', + *, + name: t.Optional[str] = None, + url: t.Optional[str] = None + ) -> 'RepoRef': + """....""" + ... + + @abstractmethod + def refresh(self, state: 'HaketiloState') -> 'RepoIterationRef': + """....""" + ... -from ..pattern_tree import PatternTree -from .store import HaketiloStore -from . import policies -def make_pattern_tree_with_builtin_policies() -> PatternTree[policies.Policy]: +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class RepoIterationRef(Ref): """....""" - # TODO: implement - return PatternTree() + pass -tree_field = dc.field(default_factory=make_pattern_tree_with_builtin_policies) -@dc.dataclass -class HaketiloState(HaketiloStore): +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class MappingRef(Ref): """....""" - pattern_tree: PatternTree[policies.Policy] = tree_field - default_allow: bool = False + @abstractmethod + def disable(self, state: 'HaketiloState') -> None: + """....""" + ... + + @abstractmethod + def forget_enabled(self, state: 'HaketiloState') -> None: + """....""" + ... - state_lock: Lock = dc.field(default_factory=Lock) - def select_policy(self, url: str, allow_disabled=False) -> policies.Policy: +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class MappingVersionRef(Ref): + """....""" + @abstractmethod + def enable(self, state: 'HaketiloState') -> None: """....""" - with self.state_lock: - pattern_tree = self.pattern_tree + ... + +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class ResourceRef(Ref): + """....""" + pass + - try: - for policy_set in pattern_tree.search(url): - # if policy.enabled or allow_disabled: - # return policy - pass +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class ResourceVersionRef(Ref): + """....""" + pass - return policies.FallbackBlockPolicy() - except Exception as e: - return policies.ErrorBlockPolicy(e) + +@dc.dataclass(frozen=True) +class PayloadKey: + """....""" + payload_ref: 'PayloadRef' + + mapping_identifier: str + # mapping_version: VerTuple + # mapping_repo: str + # mapping_repo_iteration: int + pattern: ParsedPattern + + def __lt__(self, other: 'PayloadKey') -> bool: + """....""" + return ( + self.mapping_identifier, + # other.mapping_version, + # self.mapping_repo, + # other.mapping_repo_iteration, + self.pattern + ) < ( + other.mapping_identifier, + # self.mapping_version, + # other.mapping_repo, + # self.mapping_repo_iteration, + other.pattern + ) + +@dc.dataclass(frozen=True) +class PayloadData: + """....""" + payload_ref: 'PayloadRef' + + mapping_installed: bool + explicitly_enabled: bool + unique_token: str + pattern: ParsedPattern + eval_allowed: bool + cors_bypass_allowed: bool + +@dc.dataclass(frozen=True) +class FileData: + type: str + name: str + contents: bytes + +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class PayloadRef(Ref): + """....""" + @abstractmethod + def get_data(self, state: 'HaketiloState') -> PayloadData: + """....""" + ... + + @abstractmethod + def get_mapping(self, state: 'HaketiloState') -> MappingVersionRef: + """....""" + ... + + @abstractmethod + def get_script_paths(self, state: 'HaketiloState') \ + -> t.Iterable[t.Sequence[str]]: + """....""" + ... + + @abstractmethod + def get_file_data(self, state: 'HaketiloState', path: t.Sequence[str]) \ + -> t.Optional[FileData]: + """....""" + ... + + +class MappingUseMode(Enum): + """ + AUTO - Apply mappings except for those explicitly disabled. + + WHEN_ENABLED - Only apply mappings explicitly marked as enabled. Don't apply + unmarked nor explicitly disabled mappings. + + QUESTION - Automatically apply mappings that are explicitly enabled. Ask + whether to enable unmarked mappings. Don't apply explicitly disabled + ones. + """ + AUTO = 'A' + WHEN_ENABLED = 'W' + QUESTION = 'Q' + +@dc.dataclass(frozen=True) +class HaketiloGlobalSettings: + """....""" + mapping_use_mode: MappingUseMode + default_allow_scripts: bool + repo_refresh_seconds: int + + +class MissingItemError(ValueError): + """....""" + pass + + +class HaketiloState(ABC): + """....""" + @abstractmethod + def get_repo(self, repo_id: str) -> RepoRef: + """....""" + ... + + @abstractmethod + def get_repo_iteration(self, repo_iteration_id: str) -> RepoIterationRef: + """....""" + ... + + @abstractmethod + def get_mapping(self, mapping_id: str) -> MappingRef: + """....""" + ... + + @abstractmethod + def get_mapping_version(self, mapping_version_id: str) \ + -> MappingVersionRef: + """....""" + ... + + @abstractmethod + def get_resource(self, resource_id: str) -> ResourceRef: + """....""" + ... + + @abstractmethod + def get_resource_version(self, resource_version_id: str) \ + -> ResourceVersionRef: + """....""" + ... + + @abstractmethod + def get_payload(self, payload_id: str) -> PayloadRef: + """....""" + ... + + @abstractmethod + def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \ + -> RepoRef: + """....""" + ... + + @abstractmethod + def get_settings(self) -> HaketiloGlobalSettings: + """....""" + ... + + @abstractmethod + def update_settings( + self, + *, + mapping_use_mode: t.Optional[MappingUseMode] = None, + default_allow_scripts: t.Optional[bool] = None, + repo_refresh_seconds: t.Optional[int] = None + ) -> None: + """....""" + ... diff --git a/src/hydrilla/proxy/state_impl.py b/src/hydrilla/proxy/state_impl.py new file mode 100644 index 0000000..059fee9 --- /dev/null +++ b/src/hydrilla/proxy/state_impl.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Haketilo proxy data and configuration. +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +This module contains logic for keeping track of all settings, rules, mappings +and resources. +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import secrets +import threading +import typing as t +import dataclasses as dc + +from pathlib import Path + +from immutables import Map + +from ..url_patterns import ParsedUrl +from ..pattern_tree import PatternTree +from .store import HaketiloStore +from . import state +from . import policies + + +PolicyTree = PatternTree[policies.PolicyFactory] + + +def register_builtin_policies(policy_tree: PolicyTree) -> PolicyTree: + """....""" + # TODO: implement + pass + + +def register_payload( + policy_tree: PolicyTree, + payload_ref: state.PayloadRef, + token: str +) -> tuple[PolicyTree, t.Iterable[policies.PolicyFactory]]: + """....""" + payload_policy_factory = policies.PayloadPolicyFactory( + builtin = False, + payload_ref = payload_ref + ) + + policy_tree = policy_tree.register( + payload_ref.pattern, + payload_policy_factory + ) + + resource_policy_factory = policies.PayloadResourcePolicyFactory( + builtin = False, + payload_ref = payload_ref + ) + + policy_tree = policy_tree.register( + payload_ref.pattern.path_append(token, '***'), + resource_policy_factory + ) + + return policy_tree, (payload_policy_factory, resource_policy_factory) + + +def register_mapping( + policy_tree: PolicyTree, + payload_refs: t.Iterable[state.PayloadRef], + token: str +) -> tuple[PolicyTree, t.Iterable[policies.PolicyFactory]]: + """....""" + policy_factories: list[policies.PolicyFactory] = [] + + for payload_ref in payload_refs: + policy_tree, factories = register_payload( + policy_tree, + payload_ref, + token + ) + + policy_factories.extend(factories) + + return policy_tree, policy_factories + + +@dc.dataclass(frozen=True) +class RichMappingData(state.MappingData): + """....""" + policy_factories: t.Iterable[policies.PolicyFactory] + + +# @dc.dataclass(frozen=True) +# class HaketiloData: +# settings: state.HaketiloGlobalSettings +# policy_tree: PolicyTree +# mappings_data: Map[state.MappingRef, RichMappingData] = Map() + + +MonitoredMethodType = t.TypeVar('MonitoredMethodType', bound=t.Callable) + +def with_lock(wrapped_fun: MonitoredMethodType) -> MonitoredMethodType: + """....""" + def wrapper(self: 'ConcreteHaketiloState', *args, **kwargs): + """....""" + with self.lock: + return wrapped_fun(self, *args, **kwargs) + + return t.cast(MonitoredMethodType, wrapper) + +@dc.dataclass +class ConcreteHaketiloState(state.HaketiloState): + """....""" + store: HaketiloStore + settings: state.HaketiloGlobalSettings + + policy_tree: PolicyTree = PatternTree() + mappings_data: Map[state.MappingRef, RichMappingData] = Map() + + lock: threading.RLock = dc.field(default_factory=threading.RLock) + + def __post_init__(self) -> None: + """....""" + self.policy_tree = register_builtin_policies(self.policy_tree) + + self._init_mappings() + + def _init_mappings(self) -> None: + """....""" + store_mappings_data = self.store.load_installed_mappings_data() + + payload_items = self.store.load_payloads_data().items() + for mapping_ref, payload_refs in payload_items: + installed = True + enabled_status = store_mappings_data.get(mapping_ref) + + if enabled_status is None: + installed = False + enabled_status = state.EnabledStatus.NO_MARK + + self._register_mapping( + mapping_ref, + payload_refs, + enabled = enabled_status, + installed = installed + ) + + @with_lock + def _register_mapping( + self, + mapping_ref: state.MappingRef, + payload_refs: t.Iterable[state.PayloadRef], + enabled: state.EnabledStatus, + installed: bool + ) -> None: + """....""" + token = secrets.token_urlsafe(8) + + self.policy_tree, factories = register_mapping( + self.policy_tree, + payload_refs, + token + ) + + runtime_data = RichMappingData( + mapping_ref = mapping_ref, + enabled_status = enabled, + unique_token = token, + policy_factories = factories + ) + + self.mappings_data = self.mappings_data.set(mapping_ref, runtime_data) + + @with_lock + def get_mapping_data(self, mapping_ref: state.MappingRef) \ + -> state.MappingData: + try: + return self.mappings_data[mapping_ref] + except KeyError: + raise state.MissingItemError('no such mapping') + + @with_lock + def get_payload_data(self, payload_ref: state.PayloadRef) \ + -> state.PayloadData: + # TODO!!! + try: + return t.cast(state.PayloadData, None) + except: + raise state.MissingItemError('no such payload') + + @with_lock + def get_file_paths(self, payload_ref: state.PayloadRef) \ + -> t.Iterable[t.Sequence[str]]: + # TODO!!! + return [] + + @with_lock + def get_file_data( + self, + payload_ref: state.PayloadRef, + file_path: t.Sequence[str] + ) -> t.Optional[state.FileData]: + if len(file_path) == 0: + raise state.MissingItemError('empty file path') + + path_str = '/'.join(file_path[1:]) + + return self.store.load_file_data(payload_ref, file_path[0], path_str) + + @with_lock + def ensure_installed(self, mapping: state.MappingRef) -> None: + # TODO!!! + pass + + @with_lock + def get_settings(self) -> state.HaketiloGlobalSettings: + return self.settings + + @with_lock + def update_settings(self, updater: state.SettingsUpdater) -> None: + new_settings = updater(self.settings) + + self.store.write_global_settings(new_settings) + + self.settings = new_settings + + def select_policy(self, url: str) -> policies.Policy: + """....""" + with self.lock: + policy_tree = self.policy_tree + + try: + best_priority: int = 0 + best_policy: t.Optional[policies.Policy] = None + + for factories_set in policy_tree.search(url): + for stored_factory in sorted(factories_set): + factory = stored_factory.item + + policy = factory.make_policy(self) + policy_priority = policy.priority() + + if policy_priority > best_priority: + best_priority = policy_priority + best_policy = policy + except Exception as e: + return policies.ErrorBlockPolicy( + builtin = True, + error = e + ) + + if best_policy is not None: + return best_policy + + if self.get_settings().default_allow_scripts: + return policies.FallbackAllowPolicy() + else: + return policies.FallbackBlockPolicy() + + @staticmethod + def make(store_dir: Path) -> 'ConcreteHaketiloState': + """....""" + store = HaketiloStore(store_dir) + settings = store.load_global_settings() + + return ConcreteHaketiloState(store=store, settings=settings) diff --git a/src/hydrilla/proxy/state_impl/__init__.py b/src/hydrilla/proxy/state_impl/__init__.py new file mode 100644 index 0000000..5398cdd --- /dev/null +++ b/src/hydrilla/proxy/state_impl/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: CC0-1.0 + +# Copyright (C) 2022 Wojtek Kosior <koszko@koszko.org> +# +# Available under the terms of Creative Commons Zero v1.0 Universal. + +from .concrete_state import ConcreteHaketiloState diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py new file mode 100644 index 0000000..78a50c0 --- /dev/null +++ b/src/hydrilla/proxy/state_impl/base.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Haketilo proxy data and configuration (definition of fields of a class that +# will implement HaketiloState). +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +This module defines fields that will later be part of a concrete HaketiloState +subtype. +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import sqlite3 +import threading +import dataclasses as dc +import typing as t + +from pathlib import Path +from contextlib import contextmanager + +import sqlite3 + +from immutables import Map + +from ... import pattern_tree +from .. import state +from .. import policies + + +PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory] + +#PayloadsDataMap = Map[state.PayloadRef, state.PayloadData] +DataById = t.Mapping[str, state.PayloadData] + +# mypy needs to be corrected: +# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704 +@dc.dataclass # type: ignore[misc] +class HaketiloStateWithFields(state.HaketiloState): + """....""" + store_dir: Path + connection: sqlite3.Connection + current_cursor: t.Optional[sqlite3.Cursor] = None + #settings: state.HaketiloGlobalSettings + + policy_tree: PolicyTree = PolicyTree() + #payloads_data: PayloadsDataMap = Map() + payloads_data: DataById = dc.field(default_factory=dict) + + lock: threading.RLock = dc.field(default_factory=threading.RLock) + + @contextmanager + def cursor(self, lock: bool = False, transaction: bool = False) \ + -> t.Iterator[sqlite3.Cursor]: + """....""" + start_transaction = transaction and not self.connection.in_transaction + + assert lock or not start_transaction + + try: + if lock: + self.lock.acquire() + + if self.current_cursor is not None: + yield self.current_cursor + return + + self.current_cursor = self.connection.cursor() + + if start_transaction: + self.current_cursor.execute('BEGIN TRANSACTION;') + + try: + yield self.current_cursor + + if start_transaction: + self.current_cursor.execute('COMMIT TRANSACTION;') + except: + if start_transaction: + self.current_cursor.execute('ROLLBACK TRANSACTION;') + raise + finally: + self.current_cursor = None + + if lock: + self.lock.release() + + def rebuild_structures(self, cursor: sqlite3.Cursor) -> None: + """....""" + ... diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py new file mode 100644 index 0000000..cd6698c --- /dev/null +++ b/src/hydrilla/proxy/state_impl/concrete_state.py @@ -0,0 +1,704 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Haketilo proxy data and configuration (instantiatable HaketiloState subtype). +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +This module contains logic for keeping track of all settings, rules, mappings +and resources. +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import secrets +import io +import hashlib +import typing as t +import dataclasses as dc + +from pathlib import Path + +import sqlite3 + +from ...exceptions import HaketiloException +from ...translations import smart_gettext as _ +from ... import pattern_tree +from ... import url_patterns +from ... import versions +from ... import item_infos +from ..simple_dependency_satisfying import ItemsCollection, ComputedPayload +from .. import state as st +from .. import policies +from . import base + + +here = Path(__file__).resolve().parent + +AnyInfo = t.Union[item_infos.ResourceInfo, item_infos.MappingInfo] + +@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc] +class ConcreteRepoRef(st.RepoRef): + def remove(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + def update( + self, + state: st.HaketiloState, + *, + name: t.Optional[str] = None, + url: t.Optional[str] = None + ) -> ConcreteRepoRef: + raise NotImplementedError() + + def refresh(self, state: st.HaketiloState) -> ConcreteRepoIterationRef: + raise NotImplementedError() + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteRepoIterationRef(st.RepoIterationRef): + pass + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteMappingRef(st.MappingRef): + def disable(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + def forget_enabled(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteMappingVersionRef(st.MappingVersionRef): + def enable(self, state: st.HaketiloState) -> None: + raise NotImplementedError() + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteResourceRef(st.ResourceRef): + pass + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteResourceVersionRef(st.ResourceVersionRef): + pass + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcretePayloadRef(st.PayloadRef): + computed_payload: ComputedPayload = dc.field(hash=False, compare=False) + + def get_data(self, state: st.HaketiloState) -> st.PayloadData: + return t.cast(ConcreteHaketiloState, state).payloads_data[self.id] + + def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef: + return 'to implement' + + def get_script_paths(self, state: st.HaketiloState) \ + -> t.Iterator[t.Sequence[str]]: + for resource_info in self.computed_payload.resources: + for file_spec in resource_info.scripts: + yield (resource_info.identifier, *file_spec.name.split('/')) + + def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \ + -> t.Optional[st.FileData]: + if len(path) == 0: + raise st.MissingItemError() + + resource_identifier, *file_name_segments = path + + file_name = '/'.join(file_name_segments) + + script_sha256 = '' + + matched_resource_info = False + + for resource_info in self.computed_payload.resources: + if resource_info.identifier == resource_identifier: + matched_resource_info = True + + for script_spec in resource_info.scripts: + if script_spec.name == file_name: + script_sha256 = script_spec.sha256 + + break + + if not matched_resource_info: + raise st.MissingItemError(resource_identifier) + + if script_sha256 == '': + return None + + store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir + files_dir_path = store_dir_path / 'temporary_malcontent' / 'file' + file_path = files_dir_path / 'sha256' / script_sha256 + + return st.FileData( + type = 'application/javascript', + name = file_name, + contents = file_path.read_bytes() + ) + +PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory] + +def register_payload( + policy_tree: PolicyTree, + payload_key: st.PayloadKey, + token: str +) -> PolicyTree: + """....""" + payload_policy_factory = policies.PayloadPolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register( + payload_key.pattern, + payload_policy_factory + ) + + resource_policy_factory = policies.PayloadResourcePolicyFactory( + builtin = False, + payload_key = payload_key + ) + + policy_tree = policy_tree.register( + payload_key.pattern.path_append(token, '***'), + resource_policy_factory + ) + + return policy_tree + +DataById = t.Mapping[str, st.PayloadData] + +AnyInfoVar = t.TypeVar( + 'AnyInfoVar', + item_infos.ResourceInfo, + item_infos.MappingInfo +) + +# def newest_item_path(item_dir: Path) -> t.Optional[Path]: +# available_versions = tuple( +# versions.parse_normalize_version(ver_path.name) +# for ver_path in item_dir.iterdir() +# if ver_path.is_file() +# ) + +# if available_versions == (): +# return None + +# newest_version = max(available_versions) + +# version_path = item_dir / versions.version_string(newest_version) + +# assert version_path.is_file() + +# return version_path + +def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \ + -> t.Iterator[tuple[AnyInfoVar, str]]: + item_type_path = malcontent_path / item_class.type_name + if not item_type_path.is_dir(): + return + + for item_path in item_type_path.iterdir(): + if not item_path.is_dir(): + continue + + for item_version_path in item_path.iterdir(): + definition = item_version_path.read_text() + item_info = item_class.load(io.StringIO(definition)) + + assert item_info.identifier == item_path.name + assert versions.version_string(item_info.version) == \ + item_version_path.name + + yield item_info, definition + +def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int: + cursor.execute( + ''' + INSERT OR IGNORE INTO repos(name, url, deleted, next_iteration) + VALUES(?, '<dummy_url>', TRUE, 2); + ''', + (repo_name,) + ) + + cursor.execute( + ''' + SELECT + repo_id, next_iteration - 1 + FROM + repos + WHERE + name = ?; + ''', + (repo_name,) + ) + + (repo_id, last_iteration), = cursor.fetchall() + + cursor.execute( + ''' + INSERT OR IGNORE INTO repo_iterations(repo_id, iteration) + VALUES(?, ?); + ''', + (repo_id, last_iteration) + ) + + cursor.execute( + ''' + SELECT + repo_iteration_id + FROM + repo_iterations + WHERE + repo_id = ? AND iteration = ?; + ''', + (repo_id, last_iteration) + ) + + (repo_iteration_id,), = cursor.fetchall() + + return repo_iteration_id + +def get_or_make_item(cursor: sqlite3.Cursor, type: str, identifier: str) -> int: + type_letter = {'resource': 'R', 'mapping': 'M'}[type] + + cursor.execute( + ''' + INSERT OR IGNORE INTO items(type, identifier) + VALUES(?, ?); + ''', + (type_letter, identifier) + ) + + cursor.execute( + ''' + SELECT + item_id + FROM + items + WHERE + type = ? AND identifier = ?; + ''', + (type_letter, identifier) + ) + + (item_id,), = cursor.fetchall() + + return item_id + +def get_or_make_item_version( + cursor: sqlite3.Cursor, + item_id: int, + repo_iteration_id: int, + definition: str, + info: AnyInfo +) -> int: + ver_str = versions.version_string(info.version) + + values = ( + item_id, + ver_str, + repo_iteration_id, + definition, + info.allows_eval, + info.allows_cors_bypass + ) + + cursor.execute( + ''' + INSERT OR IGNORE INTO item_versions( + item_id, + version, + repo_iteration_id, + definition, + eval_allowed, + cors_bypass_allowed + ) + VALUES(?, ?, ?, ?, ?, ?); + ''', + values + ) + + cursor.execute( + ''' + SELECT + item_version_id + FROM + item_versions + WHERE + item_id = ? AND version = ? AND repo_iteration_id = ?; + ''', + (item_id, ver_str, repo_iteration_id) + ) + + (item_version_id,), = cursor.fetchall() + + return item_version_id + +def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None: + cursor.execute( + ''' + INSERT OR IGNORE INTO mapping_statuses(item_id, enabled) + VALUES(?, 'N'); + ''', + (item_id,) + ) + +def get_or_make_file(cursor: sqlite3.Cursor, sha256: str, file_bytes: bytes) \ + -> int: + cursor.execute( + ''' + INSERT OR IGNORE INTO files(sha256, data) + VALUES(?, ?) + ''', + (sha256, file_bytes) + ) + + cursor.execute( + ''' + SELECT + file_id + FROM + files + WHERE + sha256 = ?; + ''', + (sha256,) + ) + + (file_id,), = cursor.fetchall() + + return file_id + +def make_file_use( + cursor: sqlite3.Cursor, + item_version_id: int, + file_id: int, + name: str, + type: str, + mime_type: str, + idx: int +) -> None: + cursor.execute( + ''' + INSERT OR IGNORE INTO file_uses( + item_version_id, + file_id, + name, + type, + mime_type, + idx + ) + VALUES(?, ?, ?, ?, ?, ?); + ''', + (item_version_id, file_id, name, type, mime_type, idx) + ) + +@dc.dataclass +class ConcreteHaketiloState(base.HaketiloStateWithFields): + def __post_init__(self) -> None: + self._prepare_database() + + self._populate_database_with_stuff_from_temporary_malcontent_dir() + + with self.cursor() as cursor: + self.rebuild_structures(cursor) + + def _prepare_database(self) -> None: + """....""" + cursor = self.connection.cursor() + + try: + cursor.execute( + ''' + SELECT COUNT(name) + FROM sqlite_master + WHERE name = 'general' AND type = 'table'; + ''' + ) + + (db_initialized,), = cursor.fetchall() + + if not db_initialized: + cursor.executescript((here.parent / 'tables.sql').read_text()) + + else: + cursor.execute( + ''' + SELECT haketilo_version + FROM general; + ''' + ) + + (db_haketilo_version,) = cursor.fetchone() + if db_haketilo_version != '3.0b1': + raise HaketiloException(_('err.unknown_db_schema')) + + cursor.execute('PRAGMA FOREIGN_KEYS;') + if cursor.fetchall() == []: + raise HaketiloException(_('err.proxy.no_sqlite_foreign_keys')) + + cursor.execute('PRAGMA FOREIGN_KEYS=ON;') + finally: + cursor.close() + + def _populate_database_with_stuff_from_temporary_malcontent_dir(self) \ + -> None: + malcontent_dir_path = self.store_dir / 'temporary_malcontent' + files_by_sha256_path = malcontent_dir_path / 'file' / 'sha256' + + with self.cursor(lock=True, transaction=True) as cursor: + for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]: + info: AnyInfo + for info, definition in read_items( + malcontent_dir_path, + info_type # type: ignore + ): + repo_iteration_id = get_or_make_repo_iteration( + cursor, + info.repo + ) + + item_id = get_or_make_item( + cursor, + info.type_name, + info.identifier + ) + + item_version_id = get_or_make_item_version( + cursor, + item_id, + repo_iteration_id, + definition, + info + ) + + if info_type is item_infos.MappingInfo: + make_mapping_status(cursor, item_id) + + file_ids_bytes = {} + + file_specifiers = [*info.source_copyright] + if isinstance(info, item_infos.ResourceInfo): + file_specifiers.extend(info.scripts) + + for file_spec in file_specifiers: + file_path = files_by_sha256_path / file_spec.sha256 + file_bytes = file_path.read_bytes() + + sha256 = hashlib.sha256(file_bytes).digest().hex() + assert sha256 == file_spec.sha256 + + file_id = get_or_make_file(cursor, sha256, file_bytes) + + file_ids_bytes[sha256] = (file_id, file_bytes) + + for idx, file_spec in enumerate(info.source_copyright): + file_id, file_bytes = file_ids_bytes[file_spec.sha256] + if file_bytes.isascii(): + mime = 'text/plain' + else: + mime = 'application/octet-stream' + + make_file_use( + cursor, + item_version_id = item_version_id, + file_id = file_id, + name = file_spec.name, + type = 'L', + mime_type = mime, + idx = idx + ) + + if isinstance(info, item_infos.MappingInfo): + continue + + for idx, file_spec in enumerate(info.scripts): + file_id, _ = file_ids_bytes[file_spec.sha256] + make_file_use( + cursor, + item_version_id = item_version_id, + file_id = file_id, + name = file_spec.name, + type = 'W', + mime_type = 'application/javascript', + idx = idx + ) + + def rebuild_structures(self, cursor: sqlite3.Cursor) -> None: + cursor.execute( + ''' + SELECT + item_id, type, version, definition + FROM + item_versions JOIN items USING (item_id); + ''' + ) + + best_versions: dict[int, versions.VerTuple] = {} + definitions = {} + types = {} + + for item_id, item_type, ver_str, definition in cursor.fetchall(): + # TODO: what we're doing in this loop does not yet take different + # repos and different repo iterations into account. + ver = versions.parse_normalize_version(ver_str) + if best_versions.get(item_id, (0,)) < ver: + best_versions[item_id] = ver + definitions[item_id] = definition + types[item_id] = item_type + + resources = {} + mappings = {} + + for item_id, definition in definitions.items(): + if types[item_id] == 'R': + r_info = item_infos.ResourceInfo.load(io.StringIO(definition)) + resources[r_info.identifier] = r_info + else: + m_info = item_infos.MappingInfo.load(io.StringIO(definition)) + mappings[m_info.identifier] = m_info + + items_collection = ItemsCollection(resources, mappings) + computed_payloads = items_collection.compute_payloads() + + payloads_data = {} + + for mapping_info, by_pattern in computed_payloads.items(): + for num, (pattern, payload) in enumerate(by_pattern.items()): + payload_id = f'{num}@{mapping_info.identifier}' + + ref = ConcretePayloadRef(payload_id, payload) + + data = st.PayloadData( + payload_ref = ref, + mapping_installed = True, + explicitly_enabled = True, + unique_token = secrets.token_urlsafe(16), + pattern = pattern, + eval_allowed = payload.allows_eval, + cors_bypass_allowed = payload.allows_cors_bypass + ) + + payloads_data[payload_id] = data + + key = st.PayloadKey( + payload_ref = ref, + mapping_identifier = mapping_info.identifier, + pattern = pattern + ) + + self.policy_tree = register_payload( + self.policy_tree, + key, + data.unique_token + ) + + self.payloads_data = payloads_data + + def get_repo(self, repo_id: str) -> st.RepoRef: + return ConcreteRepoRef(repo_id) + + def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef: + return ConcreteRepoIterationRef(repo_iteration_id) + + def get_mapping(self, mapping_id: str) -> st.MappingRef: + return ConcreteMappingRef(mapping_id) + + def get_mapping_version(self, mapping_version_id: str) \ + -> st.MappingVersionRef: + return ConcreteMappingVersionRef(mapping_version_id) + + def get_resource(self, resource_id: str) -> st.ResourceRef: + return ConcreteResourceRef(resource_id) + + def get_resource_version(self, resource_version_id: str) \ + -> st.ResourceVersionRef: + return ConcreteResourceVersionRef(resource_version_id) + + def get_payload(self, payload_id: str) -> st.PayloadRef: + return 'not implemented' + + def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \ + -> st.RepoRef: + raise NotImplementedError() + + def get_settings(self) -> st.HaketiloGlobalSettings: + return st.HaketiloGlobalSettings( + mapping_use_mode = st.MappingUseMode.AUTO, + default_allow_scripts = True, + repo_refresh_seconds = 0 + ) + + def update_settings( + self, + *, + mapping_use_mode: t.Optional[st.MappingUseMode] = None, + default_allow_scripts: t.Optional[bool] = None, + repo_refresh_seconds: t.Optional[int] = None + ) -> None: + raise NotImplementedError() + + def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy: + """....""" + with self.lock: + policy_tree = self.policy_tree + + try: + best_priority: int = 0 + best_policy: t.Optional[policies.Policy] = None + + for factories_set in policy_tree.search(url): + for stored_factory in sorted(factories_set): + factory = stored_factory.item + + policy = factory.make_policy(self) + + if policy.priority > best_priority: + best_priority = policy.priority + best_policy = policy + except Exception as e: + return policies.ErrorBlockPolicy( + builtin = True, + error = e + ) + + if best_policy is not None: + return best_policy + + if self.get_settings().default_allow_scripts: + return policies.FallbackAllowPolicy() + else: + return policies.FallbackBlockPolicy() + + @staticmethod + def make(store_dir: Path) -> 'ConcreteHaketiloState': + return ConcreteHaketiloState( + store_dir = store_dir, + connection = sqlite3.connect(str(store_dir / 'sqlite3.db')) + ) diff --git a/src/hydrilla/proxy/store.py b/src/hydrilla/proxy/store.py index 72852d8..4978b65 100644 --- a/src/hydrilla/proxy/store.py +++ b/src/hydrilla/proxy/store.py @@ -29,12 +29,153 @@ # Enable using with Python 3.7. from __future__ import annotations +import threading import dataclasses as dc +import typing as t from pathlib import Path +from enum import Enum + +from immutables import Map + +from .. url_patterns import parse_pattern +from .. import versions +from . import state + + +@dc.dataclass(frozen=True, eq=False) +class StoredItemRef(state.ItemRef): + item_id: int + + def __eq__(self, other: object) -> bool: + return isinstance(other, StoredItemRef) and \ + self.item_id == other.item_id + + def __hash__(self) -> int: + return hash(self.item_id) + + def _id(self) -> str: + return str(self.item_id) + + +@dc.dataclass(frozen=True, eq=False) +class StoredPayloadRef(state.PayloadRef): + payload_id: int + + def __eq__(self, other: object) -> bool: + return isinstance(other, StoredPayloadRef) and \ + self.payload_id == other.payload_id + + def __hash__(self) -> int: + return hash(self.payload_id) + + def _id(self) -> str: + return str(self.payload_id) + + +# class ItemStoredData: +# """....""" +# def __init__( +# self, +# item_id: int +# ty#pe: ItemType +# repository_id: int +# version: str +# identifier: str +# orphan: bool +# installed: bool +# enabled: EnabledStatus +# ) -> None: +# """....""" +# self.item_id = item_id +# self.type = ItemType(type) +# self.repository_id = repository_id +# self.version = parse +# identifier: str +# orphan: bool +# installed: bool +# enabled: EnabledStatus + @dc.dataclass class HaketiloStore: """....""" store_dir: Path - # TODO: implement + + lock: threading.RLock = dc.field(default_factory=threading.RLock) + + # def load_all_resources(self) -> t.Sequence[item_infos.ResourceInfo]: + # """....""" + # # TODO: implement + # with self.lock: + # return [] + + def load_installed_mappings_data(self) \ + -> t.Mapping[state.MappingRef, state.EnabledStatus]: + """....""" + # TODO: implement + with self.lock: + dummy_item_ref = StoredItemRef( + item_id = 47, + identifier = 'somemapping', + version = versions.parse_normalize_version('1.2.3'), + repository = 'somerepo', + orphan = False + ) + + return Map({ + state.MappingRef(dummy_item_ref): state.EnabledStatus.ENABLED + }) + + def load_payloads_data(self) \ + -> t.Mapping[state.MappingRef, t.Iterable[state.PayloadRef]]: + """....""" + # TODO: implement + with self.lock: + dummy_item_ref = StoredItemRef( + item_id = 47, + identifier = 'somemapping', + version = versions.parse_normalize_version('1.2.3'), + repository = 'somerepo', + orphan = False + ) + + dummy_mapping_ref = state.MappingRef(dummy_item_ref) + + payload_refs = [] + for parsed_pattern in parse_pattern('http*://example.com/a/***'): + dummy_payload_ref = StoredPayloadRef( + payload_id = 22, + mapping_ref = dummy_mapping_ref, + pattern = parsed_pattern + ) + + payload_refs.append(dummy_payload_ref) + + return Map({dummy_mapping_ref: payload_refs}) + + def load_file_data( + self, + payload_ref: state.PayloadRef, + resource_identifier: str, + file_path: t.Sequence[str] + ) -> t.Optional[state.FileData]: + # TODO: implement + with self.lock: + return None + + def load_global_settings(self) -> state.HaketiloGlobalSettings: + """....""" + # TODO: implement + with self.lock: + return state.HaketiloGlobalSettings( + state.MappingApplicationMode.WHEN_ENABLED, + False + ) + + def write_global_settings(self, settings: state.HaketiloGlobalSettings) \ + -> None: + """....""" + # TODO: implement + with self.lock: + pass diff --git a/src/hydrilla/proxy/tables.sql b/src/hydrilla/proxy/tables.sql new file mode 100644 index 0000000..53539a7 --- /dev/null +++ b/src/hydrilla/proxy/tables.sql @@ -0,0 +1,235 @@ +-- SPDX-License-Identifier: GPL-3.0-or-later + +-- SQLite tables definitions for Haketilo proxy. +-- +-- This file is part of Hydrilla&Haketilo. +-- +-- Copyright (C) 2022 Wojtek Kosior +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU General Public License as published by +-- the Free Software Foundation, either version 3 of the License, or +-- (at your option) any later version. +-- +-- This program is distributed in the hope that it will be useful, +-- but WITHOUT ANY WARRANTY; without even the implied warranty of +-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +-- GNU General Public License for more details. +-- +-- You should have received a copy of the GNU General Public License +-- along with this program. If not, see <https://www.gnu.org/licenses/>. +-- +-- +-- I, Wojtek Kosior, thereby promise not to sue for violation of this +-- file's license. Although I request that you do not make use this code +-- in a proprietary program, I am not going to enforce this in court. + +BEGIN TRANSACTION; + +CREATE TABLE general( + haketilo_version VARCHAR NOT NULL, + default_allow_scripts BOOLEAN NOT NULL, + repo_refresh_seconds INTEGER NOT NULL, + -- "mapping_use_mode" determines whether current mode is AUTO, + -- WHEN_ENABLED or QUESTION. + mapping_use_mode CHAR(1) NOT NULL, + + CHECK (rowid = 1), + CHECK (mapping_use_mode IN ('A', 'W', 'Q')), + CHECK (haketilo_version = '3.0b1') +); + +INSERT INTO general( + rowid, + haketilo_version, + default_allow_scripts, + repo_refresh_seconds, + mapping_use_mode +) VALUES( + 1, + '3.0b1', + FALSE, + 24 * 60 * 60, + 'Q' +); + +CREATE TABLE rules( + rule_id INTEGER PRIMARY KEY, + + pattern VARCHAR NOT NULL, + allow_scripts BOOLEAN NOT NULL, + + UNIQUE (pattern) +); + +CREATE TABLE repos( + repo_id INTEGER PRIMARY KEY, + + name VARCHAR NOT NULL, + url VARCHAR NOT NULL, + deleted BOOLEAN NOT NULL, + next_iteration INTEGER NOT NULL, + active_iteration_id INTEGER NULL, + last_refreshed INTEGER NULL, + + UNIQUE (name), + + FOREIGN KEY (active_iteration_id) + REFERENCES repo_iterations(repo_iteration_id) + ON DELETE SET NULL +); + +CREATE TABLE repo_iterations( + repo_iteration_id INTEGER PRIMARY KEY, + + repo_id INTEGER NOT NULL, + iteration INTEGER NOT NULL, + + UNIQUE (repo_id, iteration), + + FOREIGN KEY (repo_id) + REFERENCES repos (repo_id) +); + +CREATE VIEW orphan_iterations +AS +SELECT + ri.repo_iteration_id, + ri.repo_id, + ri.iteration +FROM + repo_iterations AS ri + JOIN repos AS r USING (repo_id) +WHERE + COALESCE(r.active_iteration_id != ri.repo_iteration_id, TRUE); + +CREATE TABLE items( + item_id INTEGER PRIMARY KEY, + + -- "type" determines whether it's resource or mapping. + type CHAR(1) NOT NULL, + identifier VARCHAR NOT NULL, + + UNIQUE (type, identifier), + CHECK (type IN ('R', 'M')) +); + +CREATE TABLE mapping_statuses( + -- The item with this id shall be a mapping ("type" = 'M'). + item_id INTEGER PRIMARY KEY, + + -- "enabled" determines whether mapping's status is ENABLED, + -- DISABLED or NO_MARK. + enabled CHAR(1) NOT NULL, + enabled_version_id INTEGER NULL, + -- "frozen" determines whether an enabled mapping is to be kept in its + -- EXACT_VERSION, is to be updated only with versions from the same + -- REPOSITORY or is NOT_FROZEN at all. + frozen CHAR(1) NULL, + + CHECK (NOT (enabled = 'D' AND enabled_version_id IS NOT NULL)), + CHECK (NOT (enabled = 'E' AND enabled_version_id IS NULL)), + + CHECK ((frozen IS NULL) = (enabled != 'E')), + CHECK (frozen IS NULL OR frozen in ('E', 'R', 'N')) +); + +CREATE TABLE item_versions( + item_version_id INTEGER PRIMARY KEY, + + item_id INTEGER NOT NULL, + version VARCHAR NOT NULL, + repo_iteration_id INTEGER NOT NULL, + definition TEXT NOT NULL, + -- What privileges should be granted on pages where this + -- resource/mapping is used. + eval_allowed BOOLEAN NOT NULL, + cors_bypass_allowed BOOLEAN NOT NULL, + + UNIQUE (item_id, version, repo_iteration_id), + -- Allow foreign key from "mapping_statuses". + UNIQUE (item_version_id, item_id), + + FOREIGN KEY (item_id) + REFERENCES items (item_id), + FOREIGN KEY (repo_iteration_id) + REFERENCES repo_iterations (repo_iteration_id) +); + +CREATE TABLE payloads( + payload_id INTEGER PRIMARY KEY, + + mapping_item_id INTEGER NOT NULL, + pattern VARCHAR NOT NULL, + + UNIQUE (mapping_item_id, pattern), + + FOREIGN KEY (mapping_item_id) + REFERENCES item_versions (versioned_item_id) + ON DELETE CASCADE +); + +CREATE TABLE resolved_depended_resources( + payload_id INTEGER, + resource_item_id INTEGER, + + -- "idx" determines the ordering of resources. + idx INTEGER, + + PRIMARY KEY (payload_id, resource_item_id), + + FOREIGN KEY (payload_id) + REFERENCES payloads (payload_id), + FOREIGN KEY (resource_item_id) + REFERENCES item_versions (item_version_id) +) WITHOUT ROWID; + +-- CREATE TABLE resolved_required_mappings( +-- requiring_item_id INTEGER, +-- required_mapping_item_id INTEGER, + +-- PRIMARY KEY (requiring_item_id, required_mapping_item_id), + +-- CHECK (requiring_item_id != required_mapping_item_id), + +-- FOREIGN KEY (requiring_item_id) +-- REFERENCES items (item_id), +-- -- Note: the referenced mapping shall have installed=TRUE. +-- FOREIGN KEY (required_mapping_item_id) +-- REFERENCES mappings (item_id), +-- FOREIGN KEY (required_mapping_item_id) +-- REFERENCES items (item_id) +-- ); + +CREATE TABLE files( + file_id INTEGER PRIMARY KEY, + + sha256 CHAR(64) NOT NULL, + data BLOB NULL, + + UNIQUE (sha256) +); + +CREATE TABLE file_uses( + file_use_id INTEGER PRIMARY KEY, + + item_version_id INTEGER NOT NULL, + file_id INTEGER NOT NULL, + name VARCHAR NOT NULL, + -- "type" determines whether it's license file or web resource. + type CHAR(1) NOT NULL, + mime_type VARCHAR NOT NULL, + -- "idx" determines the ordering of item's files of given type. + idx INTEGER NOT NULL, + + CHECK (type IN ('L', 'W')), + UNIQUE(item_version_id, type, idx), + + FOREIGN KEY (item_version_id) + REFERENCES item_versions(item_version_id) + ON DELETE CASCADE, + FOREIGN KEY (file_id) + REFERENCES files(file_id) +); + +COMMIT TRANSACTION; diff --git a/src/hydrilla/url_patterns.py b/src/hydrilla/url_patterns.py index 8e80379..0a242e3 100644 --- a/src/hydrilla/url_patterns.py +++ b/src/hydrilla/url_patterns.py @@ -41,36 +41,73 @@ import dataclasses as dc from immutables import Map -from hydrilla.translations import smart_gettext as _ -from hydrilla.exceptions import HaketiloException +from .translations import smart_gettext as _ +from .exceptions import HaketiloException default_ports: t.Mapping[str, int] = Map(http=80, https=443, ftp=21) -@dc.dataclass(frozen=True, unsafe_hash=True) +ParsedUrlType = t.TypeVar('ParsedUrlType', bound='ParsedUrl') + +@dc.dataclass(frozen=True, unsafe_hash=True, order=True) class ParsedUrl: """....""" - orig_url: str # orig_url used in __hash__() - scheme: str = dc.field(hash=False) - domain_labels: tuple[str, ...] = dc.field(hash=False) - path_segments: tuple[str, ...] = dc.field(hash=False) - has_trailing_slash: bool = dc.field(hash=False) - port: int = dc.field(hash=False) - - # def reconstruct_url(self) -> str: - # """....""" - # scheme = self.orig_scheme - - # netloc = '.'.join(reversed(self.domain_labels)) - # if scheme == self.scheme and \ - # self.port is not None and \ - # default_ports[scheme] != self.port: - # netloc += f':{self.port}' - - # path = '/'.join(('', *self.path_segments)) - # if self.has_trailing_slash: - # path += '/' + orig_url: str # used in __hash__() and __lt__() + scheme: str = dc.field(hash=False, compare=False) + domain_labels: tuple[str, ...] = dc.field(hash=False, compare=False) + path_segments: tuple[str, ...] = dc.field(hash=False, compare=False) + has_trailing_slash: bool = dc.field(hash=False, compare=False) + port: int = dc.field(hash=False, compare=False) + + @property + def url_without_path(self) -> str: + """....""" + scheme = self.scheme + + netloc = '.'.join(reversed(self.domain_labels)) + + if self.port is not None and \ + default_ports[scheme] != self.port: + netloc += f':{self.port}' + + return f'{scheme}://{netloc}' + + def _reconstruct_url(self) -> str: + """....""" + path = '/'.join(('', *self.path_segments)) + if self.has_trailing_slash: + path += '/' + + return self.url_without_path + path + + def path_append(self: ParsedUrlType, *new_segments: str) -> ParsedUrlType: + """....""" + new_url = self._reconstruct_url() + if not self.has_trailing_slash: + new_url += '/' + + new_url += '/'.join(new_segments) + + return dc.replace( + self, + orig_url = new_url, + path_segments = tuple((*self.path_segments, *new_segments)), + has_trailing_slash = False + ) + +ParsedPattern = t.NewType('ParsedPattern', ParsedUrl) + +# # We sometimes need a dummy pattern that means "match everything". +# catchall_pattern = ParsedPattern( +# ParsedUrl( +# orig_url = '<dummy_catchall_url_pattern>' +# scheme = '<dummy_all-scheme>' +# domain_labels = ('***',) +# path_segments = ('***',) +# has_trailing_slash = False +# port = 0 +# ) +# ) - # return f'{scheme}://{netloc}{path}' # URLs with those schemes will be recognized but not all of them have to be # actually supported by Hydrilla server and Haketilo proxy. @@ -163,7 +200,7 @@ def _parse_pattern_or_url(url: str, orig_url: str, is_pattern: bool = False) \ replace_scheme_regex = re.compile(r'^[^:]*') -def parse_pattern(url_pattern: str) -> t.Sequence[ParsedUrl]: +def parse_pattern(url_pattern: str) -> t.Iterator[ParsedPattern]: """....""" if url_pattern.startswith('http*:'): patterns = [ @@ -173,8 +210,8 @@ def parse_pattern(url_pattern: str) -> t.Sequence[ParsedUrl]: else: patterns = [url_pattern] - return tuple(_parse_pattern_or_url(pat, url_pattern, True) - for pat in patterns) + for pat in patterns: + yield ParsedPattern(_parse_pattern_or_url(pat, url_pattern, True)) def parse_url(url: str) -> ParsedUrl: """....""" diff --git a/src/hydrilla/versions.py b/src/hydrilla/versions.py index a7a9f29..7474d98 100644 --- a/src/hydrilla/versions.py +++ b/src/hydrilla/versions.py @@ -34,14 +34,16 @@ from __future__ import annotations import typing as t -def normalize_version(ver: t.Sequence[int]) -> tuple[int, ...]: - """Strip right-most zeroes from 'ver'. The original list is not modified.""" +VerTuple = t.NewType('VerTuple', 'tuple[int, ...]') + +def normalize_version(ver: t.Sequence[int]) -> VerTuple: + """Strip rightmost zeroes from 'ver'.""" new_len = 0 for i, num in enumerate(ver): if num != 0: new_len = i + 1 - return tuple(ver[:new_len]) + return VerTuple(tuple(ver[:new_len])) def parse_version(ver_str: str) -> tuple[int, ...]: """ @@ -50,10 +52,16 @@ def parse_version(ver_str: str) -> tuple[int, ...]: """ return tuple(int(num) for num in ver_str.split('.')) -def version_string(ver: t.Sequence[int], rev: t.Optional[int] = None) -> str: +def parse_normalize_version(ver_str: str) -> VerTuple: + """ + Convert 'ver_str' into a VerTuple representation, e.g. for + ver_str="4.6.13.0" return (4, 6, 13). + """ + return normalize_version(parse_version(ver_str)) + +def version_string(ver: VerTuple, rev: t.Optional[int] = None) -> str: """ Produce version's string representation (optionally with revision), like: 1.2.3-5 - No version normalization is performed. """ return '.'.join(str(n) for n in ver) + ('' if rev is None else f'-{rev}') |