aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla
diff options
context:
space:
mode:
authorWojtek Kosior <koszko@koszko.org>2022-07-27 15:56:24 +0200
committerWojtek Kosior <koszko@koszko.org>2022-08-10 17:25:05 +0200
commit879c41927171efc8d77d1de2739b18e2eb57580f (patch)
treede0e78afe2ea49e58c9bf2c662657392a00139ee /src/hydrilla
parent52d12a4fa124daa1595529e3e7008276a7986d95 (diff)
downloadhaketilo-hydrilla-879c41927171efc8d77d1de2739b18e2eb57580f.tar.gz
haketilo-hydrilla-879c41927171efc8d77d1de2739b18e2eb57580f.zip
unfinished partial work
Diffstat (limited to 'src/hydrilla')
-rw-r--r--src/hydrilla/item_infos.py529
-rw-r--r--src/hydrilla/json_instances.py14
-rw-r--r--src/hydrilla/mitmproxy_launcher/__init__.py5
-rw-r--r--src/hydrilla/mitmproxy_launcher/launch.py16
-rw-r--r--src/hydrilla/pattern_tree.py110
-rw-r--r--src/hydrilla/proxy/addon.py239
-rw-r--r--src/hydrilla/proxy/csp.py125
-rw-r--r--src/hydrilla/proxy/flow_handlers.py383
-rw-r--r--src/hydrilla/proxy/oversimplified_state_impl.py392
-rw-r--r--src/hydrilla/proxy/policies/__init__.py15
-rw-r--r--src/hydrilla/proxy/policies/base.py177
-rw-r--r--src/hydrilla/proxy/policies/fallback.py (renamed from src/hydrilla/proxy/policies.py)60
-rw-r--r--src/hydrilla/proxy/policies/payload.py315
-rw-r--r--src/hydrilla/proxy/policies/payload_resource.py160
-rw-r--r--src/hydrilla/proxy/policies/rule.py134
-rw-r--r--src/hydrilla/proxy/simple_dependency_satisfying.py189
-rw-r--r--src/hydrilla/proxy/state.py279
-rw-r--r--src/hydrilla/proxy/state_impl.py288
-rw-r--r--src/hydrilla/proxy/state_impl/__init__.py7
-rw-r--r--src/hydrilla/proxy/state_impl/base.py112
-rw-r--r--src/hydrilla/proxy/state_impl/concrete_state.py704
-rw-r--r--src/hydrilla/proxy/store.py143
-rw-r--r--src/hydrilla/proxy/tables.sql235
-rw-r--r--src/hydrilla/url_patterns.py91
-rw-r--r--src/hydrilla/versions.py18
25 files changed, 3946 insertions, 794 deletions
diff --git a/src/hydrilla/item_infos.py b/src/hydrilla/item_infos.py
index c366ab5..9ba47bd 100644
--- a/src/hydrilla/item_infos.py
+++ b/src/hydrilla/item_infos.py
@@ -31,48 +31,75 @@
# Enable using with Python 3.7.
from __future__ import annotations
+import sys
+
+if sys.version_info >= (3, 8):
+ from typing import Protocol
+else:
+ from typing_extensions import Protocol
+
import typing as t
import dataclasses as dc
-from pathlib import Path, PurePath
+from pathlib import Path, PurePosixPath
+from abc import ABC
-from immutables import Map, MapMutation
+from immutables import Map
from . import versions, json_instances
-from .url_patterns import parse_pattern, ParsedUrl
+from .url_patterns import parse_pattern, ParsedUrl, ParsedPattern
from .exceptions import HaketiloException
from .translations import smart_gettext as _
VerTuple = t.Tuple[int, ...]
@dc.dataclass(frozen=True, unsafe_hash=True)
-class ItemRef:
+class ItemSpecifier:
"""...."""
identifier: str
-RefObjs = t.Sequence[t.Mapping[str, t.Any]]
+SpecifierObjs = t.Sequence[t.Mapping[str, t.Any]]
-def make_item_refs_seq(ref_objs: RefObjs) -> tuple[ItemRef, ...]:
+def make_item_specifiers_seq(spec_objs: SpecifierObjs) \
+ -> tuple[ItemSpecifier, ...]:
"""...."""
- return tuple(ItemRef(ref['identifier']) for ref in ref_objs)
+ return tuple(ItemSpecifier(obj['identifier']) for obj in spec_objs)
-def make_required_mappings(refs_objs: t.Any, schema_compat: int) \
- -> tuple[ItemRef, ...]:
+def make_required_mappings(spec_objs: t.Any, schema_compat: int) \
+ -> tuple[ItemSpecifier, ...]:
"""...."""
if schema_compat < 2:
return ()
- return make_item_refs_seq(refs_objs)
+ return make_item_specifiers_seq(spec_objs)
@dc.dataclass(frozen=True, unsafe_hash=True)
-class FileRef:
+class FileSpecifier:
"""...."""
name: str
sha256: str
-def make_file_refs_seq(ref_objs: RefObjs) -> tuple[FileRef, ...]:
+def normalize_filename(name: str):
+ """
+ This function eliminated double slashes in file name and ensures it does not
+ try to reference parent directories.
+ """
+ path = PurePosixPath(name)
+
+ if '.' in path.parts or '..' in path.parts:
+ msg = _('err.item_info.filename_invalid_{}').format(name)
+ raise HaketiloException(msg)
+
+ return str(path)
+
+def make_file_specifiers_seq(spec_objs: SpecifierObjs) \
+ -> tuple[FileSpecifier, ...]:
"""...."""
- return tuple(FileRef(ref['file'], ref['sha256']) for ref in ref_objs)
+ return tuple(
+ FileSpecifier(normalize_filename(obj['file']), obj['sha256'])
+ for obj
+ in spec_objs
+ )
@dc.dataclass(frozen=True, unsafe_hash=True)
class GeneratedBy:
@@ -81,60 +108,107 @@ class GeneratedBy:
version: t.Optional[str]
@staticmethod
- def make(generated_obj: t.Optional[t.Mapping[str, t.Any]]) -> \
+ def make(generated_by_obj: t.Optional[t.Mapping[str, t.Any]]) -> \
t.Optional['GeneratedBy']:
"""...."""
- if generated_obj is None:
+ if generated_by_obj is None:
return None
return GeneratedBy(
- name = generated_obj['name'],
- version = generated_obj.get('version')
+ name = generated_by_obj['name'],
+ version = generated_by_obj.get('version')
)
-@dc.dataclass(frozen=True, unsafe_hash=True)
-class ItemInfoBase:
+
+def make_eval_permission(perms_obj: t.Any, schema_compat: int) -> bool:
+ if schema_compat < 2:
+ return False
+
+ return perms_obj.get('eval', False)
+
+
+def make_cors_bypass_permission(perms_obj: t.Any, schema_compat: int) -> bool:
+ if schema_compat < 2:
+ return False
+
+ return perms_obj.get('cors_bypass', False)
+
+
+class Categorizable(Protocol):
"""...."""
- repository: str # repository used in __hash__()
- source_name: str = dc.field(hash=False)
- source_copyright: tuple[FileRef, ...] = dc.field(hash=False)
- version: VerTuple # version used in __hash__()
- identifier: str # identifier used in __hash__()
- uuid: t.Optional[str] = dc.field(hash=False)
- long_name: str = dc.field(hash=False)
- required_mappings: tuple[ItemRef, ...] = dc.field(hash=False)
- generated_by: t.Optional[GeneratedBy] = dc.field(hash=False)
-
- def path_relative_to_type(self) -> str:
- """
- Get a relative path to this item's JSON definition with respect to
- directory containing items of this type.
- """
- return f'{self.identifier}/{versions.version_string(self.version)}'
-
- def path(self) -> str:
- """
- Get a relative path to this item's JSON definition with respect to
- malcontent directory containing loadable items.
- """
- return f'{self.type_name}/{self.path_relative_to_type()}'
+ uuid: t.Optional[str]
+ identifier: str
- @property
- def versioned_identifier(self):
- """...."""
- return f'{self.identifier}-{versions.version_string(self.version)}'
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ItemIdentity:
+ repo: str
+ repo_iteration: int
+ version: versions.VerTuple
+ identifier: str
+
+# mypy needs to be corrected:
+# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
+@dc.dataclass(frozen=True) # type: ignore[misc]
+class ItemInfoBase(ABC, ItemIdentity, Categorizable):
+ """...."""
+ type_name: t.ClassVar[str] = '!INVALID!'
+
+ source_name: str = dc.field(hash=False)
+ source_copyright: tuple[FileSpecifier, ...] = dc.field(hash=False)
+ uuid: t.Optional[str] = dc.field(hash=False)
+ long_name: str = dc.field(hash=False)
+ allows_eval: bool = dc.field(hash=False)
+ allows_cors_bypass: bool = dc.field(hash=False)
+ required_mappings: tuple[ItemSpecifier, ...] = dc.field(hash=False)
+ generated_by: t.Optional[GeneratedBy] = dc.field(hash=False)
+
+ # def path_relative_to_type(self) -> str:
+ # """
+ # Get a relative path to this item's JSON definition with respect to
+ # directory containing items of this type.
+ # """
+ # return f'{self.identifier}/{versions.version_string(self.version)}'
+
+ # def path(self) -> str:
+ # """
+ # Get a relative path to this item's JSON definition with respect to
+ # malcontent directory containing loadable items.
+ # """
+ # return f'{self.type_name}/{self.path_relative_to_type()}'
+
+ # @property
+ # def identity(self):
+ # """...."""
+ # return ItemIdentity(
+ # repository = self.repository,
+ # version = self.version,
+ # identifier = self.identifier
+ # )
+
+ # @property
+ # def versioned_identifier(self):
+ # """...."""
+ # return f'{self.identifier}-{versions.version_string(self.version)}'
@staticmethod
def _get_base_init_kwargs(
- item_obj: t.Mapping[str, t.Any],
- schema_compat: int,
- repository: str
+ item_obj: t.Mapping[str, t.Any],
+ schema_compat: int,
+ repo: str,
+ repo_iteration: int
) -> t.Mapping[str, t.Any]:
"""...."""
- source_copyright = make_file_refs_seq(item_obj['source_copyright'])
+ source_copyright = make_file_specifiers_seq(
+ item_obj['source_copyright']
+ )
version = versions.normalize_version(item_obj['version'])
+ perms_obj = item_obj.get('permissions', {})
+
+ eval_perm = make_eval_permission(perms_obj, schema_compat)
+ cors_bypass_perm = make_cors_bypass_permission(perms_obj, schema_compat)
+
required_mappings = make_required_mappings(
item_obj.get('required_mappings', []),
schema_compat
@@ -143,28 +217,29 @@ class ItemInfoBase:
generated_by = GeneratedBy.make(item_obj.get('generated_by'))
return Map(
- repository = repository,
- source_name = item_obj['source_name'],
- source_copyright = source_copyright,
- version = version,
- identifier = item_obj['identifier'],
- uuid = item_obj.get('uuid'),
- long_name = item_obj['long_name'],
- required_mappings = required_mappings,
- generated_by = generated_by
+ repo = repo,
+ repo_iteration = repo_iteration,
+ source_name = item_obj['source_name'],
+ source_copyright = source_copyright,
+ version = version,
+ identifier = item_obj['identifier'],
+ uuid = item_obj.get('uuid'),
+ long_name = item_obj['long_name'],
+ allows_eval = eval_perm,
+ allows_cors_bypass = cors_bypass_perm,
+ required_mappings = required_mappings,
+ generated_by = generated_by
)
- # class property
- type_name = '!INVALID!'
-
-InstanceOrPath = t.Union[Path, str, dict[str, t.Any]]
@dc.dataclass(frozen=True, unsafe_hash=True)
class ResourceInfo(ItemInfoBase):
"""...."""
- revision: int = dc.field(hash=False)
- dependencies: tuple[ItemRef, ...] = dc.field(hash=False)
- scripts: tuple[FileRef, ...] = dc.field(hash=False)
+ type_name: t.ClassVar[str] = 'resource'
+
+ revision: int = dc.field(hash=False)
+ dependencies: tuple[ItemSpecifier, ...] = dc.field(hash=False)
+ scripts: tuple[FileSpecifier, ...] = dc.field(hash=False)
@property
def versioned_identifier(self):
@@ -173,41 +248,70 @@ class ResourceInfo(ItemInfoBase):
@staticmethod
def make(
- item_obj: t.Mapping[str, t.Any],
- schema_compat: int,
- repository: str
+ item_obj: t.Mapping[str, t.Any],
+ schema_compat: int,
+ repo: str,
+ repo_iteration: int
) -> 'ResourceInfo':
"""...."""
base_init_kwargs = ItemInfoBase._get_base_init_kwargs(
item_obj,
schema_compat,
- repository
+ repo,
+ repo_iteration
+ )
+
+ dependencies = make_item_specifiers_seq(
+ item_obj.get('dependencies', [])
+ )
+
+ scripts = make_file_specifiers_seq(
+ item_obj.get('scripts', [])
)
return ResourceInfo(
**base_init_kwargs,
revision = item_obj['revision'],
- dependencies = make_item_refs_seq(item_obj.get('dependencies', [])),
- scripts = make_file_refs_seq(item_obj.get('scripts', [])),
+ dependencies = dependencies,
+ scripts = scripts
)
@staticmethod
- def load(instance_or_path: 'InstanceOrPath', repository: str) \
- -> 'ResourceInfo':
+ def load(
+ instance_or_path: json_instances.InstanceOrPathOrIO,
+ repo: str = '<dummyrepo>',
+ repo_iteration: int = -1
+ ) -> 'ResourceInfo':
"""...."""
- return _load_item_info(ResourceInfo, instance_or_path, repository)
+ return _load_item_info(
+ ResourceInfo,
+ instance_or_path,
+ repo,
+ repo_iteration
+ )
- # class property
- type_name = 'resource'
+ # def __lt__(self, other: 'ResourceInfo') -> bool:
+ # """...."""
+ # return (
+ # self.identifier,
+ # self.version,
+ # self.revision,
+ # self.repository
+ # ) < (
+ # other.identifier,
+ # other.version,
+ # other.revision,
+ # other.repository
+ # )
def make_payloads(payloads_obj: t.Mapping[str, t.Any]) \
- -> t.Mapping[ParsedUrl, ItemRef]:
+ -> t.Mapping[ParsedPattern, ItemSpecifier]:
"""...."""
- mapping: list[tuple[ParsedUrl, ItemRef]] = []
+ mapping: list[tuple[ParsedPattern, ItemSpecifier]] = []
- for pattern, ref_obj in payloads_obj.items():
- ref = ItemRef(ref_obj['identifier'])
+ for pattern, spec_obj in payloads_obj.items():
+ ref = ItemSpecifier(spec_obj['identifier'])
mapping.extend((parsed, ref) for parsed in parse_pattern(pattern))
return Map(mapping)
@@ -215,19 +319,23 @@ def make_payloads(payloads_obj: t.Mapping[str, t.Any]) \
@dc.dataclass(frozen=True, unsafe_hash=True)
class MappingInfo(ItemInfoBase):
"""...."""
- payloads: t.Mapping[ParsedUrl, ItemRef] = dc.field(hash=False)
+ type_name: t.ClassVar[str] = 'mapping'
+
+ payloads: t.Mapping[ParsedPattern, ItemSpecifier] = dc.field(hash=False)
@staticmethod
def make(
- item_obj: t.Mapping[str, t.Any],
- schema_compat: int,
- repository: str
+ item_obj: t.Mapping[str, t.Any],
+ schema_compat: int,
+ repo: str,
+ repo_iteration: int
) -> 'MappingInfo':
"""...."""
base_init_kwargs = ItemInfoBase._get_base_init_kwargs(
item_obj,
schema_compat,
- repository
+ repo,
+ repo_iteration
)
return MappingInfo(
@@ -237,10 +345,23 @@ class MappingInfo(ItemInfoBase):
)
@staticmethod
- def load(instance_or_path: 'InstanceOrPath', repository: str) \
- -> 'MappingInfo':
+ def load(
+ instance_or_path: json_instances.InstanceOrPathOrIO,
+ repo: str = '<dummyrepo>',
+ repo_iteration: int = -1
+ ) -> 'MappingInfo':
"""...."""
- return _load_item_info(MappingInfo, instance_or_path, repository)
+ return _load_item_info(
+ MappingInfo,
+ instance_or_path,
+ repo,
+ repo_iteration
+ )
+
+ # def __lt__(self, other: 'MappingInfo') -> bool:
+ # """...."""
+ # return (self.identifier, self.version, self.repository) < \
+ # (other.identifier, other.version, other.repository)
# class property
type_name = 'mapping'
@@ -250,8 +371,9 @@ LoadedType = t.TypeVar('LoadedType', ResourceInfo, MappingInfo)
def _load_item_info(
info_type: t.Type[LoadedType],
- instance_or_path: InstanceOrPath,
- repository: str
+ instance_or_path: json_instances.InstanceOrPathOrIO,
+ repo: str,
+ repo_iteration: int
) -> LoadedType:
"""Read, validate and autocomplete a mapping/resource description."""
instance = json_instances.read_instance(instance_or_path)
@@ -264,81 +386,156 @@ def _load_item_info(
return info_type.make(
t.cast('dict[str, t.Any]', instance),
schema_compat,
- repository
+ repo,
+ repo_iteration
)
-VersionedType = t.TypeVar('VersionedType', ResourceInfo, MappingInfo)
-
-@dc.dataclass(frozen=True)
-class VersionedItemInfo(t.Generic[VersionedType]):
- """Stores data of multiple versions of given resource/mapping."""
- uuid: t.Optional[str] = None
- identifier: str = '<dummy>'
- _by_version: Map[VerTuple, VersionedType] = Map()
- _initialized: bool = False
-
- def register(self, item_info: VersionedType) -> 'VersionedInfoSelfType':
- """
- Make item info queryable by version. Perform sanity checks for uuid.
- """
- identifier = item_info.identifier
- if self._initialized:
- assert identifier == self.identifier
-
- if self.uuid is not None:
- uuid: t.Optional[str] = self.uuid
- if item_info.uuid is not None and self.uuid != item_info.uuid:
- raise HaketiloException(_('uuid_mismatch_{identifier}')
- .format(identifier=identifier))
- else:
- uuid = item_info.uuid
-
- by_version = self._by_version.set(item_info.version, item_info)
-
- return VersionedItemInfo(
- identifier = identifier,
- uuid = uuid,
- _by_version = by_version,
- _initialized = True
- )
-
- def unregister(self, version: VerTuple) -> 'VersionedInfoSelfType':
- """...."""
- try:
- by_version = self._by_version.delete(version)
- except KeyError:
- by_version = self._by_version
-
- return dc.replace(self, _by_version=by_version)
-
- def is_empty(self) -> bool:
- """...."""
- return len(self._by_version) == 0
-
- def newest_version(self) -> VerTuple:
- """...."""
- assert not self.is_empty()
-
- return max(self._by_version.keys())
-
- def get_newest(self) -> VersionedType:
- """Find and return info of the newest version of item."""
- newest = self._by_version[self.newest_version()]
- assert newest is not None
- return newest
-
- def get_by_ver(self, ver: t.Iterable[int]) -> t.Optional[VersionedType]:
- """
- Find and return info of the specified version of the item (or None if
- absent).
- """
- return self._by_version.get(tuple(ver))
-
- def get_all(self) -> t.Iterator[VersionedType]:
- """Generate item info for all its versions, from oldest ot newest."""
- for version in sorted(self._by_version.keys()):
- yield self._by_version[version]
-
-# Below we define 1 type used by recursively-typed VersionedItemInfo.
-VersionedInfoSelfType = VersionedItemInfo[VersionedType]
+# CategorizedType = t.TypeVar(
+# 'CategorizedType',
+# bound=Categorizable
+# )
+
+# CategorizedUpdater = t.Callable[
+# [t.Optional[CategorizedType]],
+# t.Optional[CategorizedType]
+# ]
+
+# CategoryKeyType = t.TypeVar('CategoryKeyType', bound=t.Hashable)
+
+# @dc.dataclass(frozen=True)
+# class CategorizedItemInfo(Categorizable, t.Generic[CategorizedType, CategoryKeyType]):
+# """...."""
+# SelfType = t.TypeVar(
+# 'SelfType',
+# bound = 'CategorizedItemInfo[CategorizedType, CategoryKeyType]'
+# )
+
+# uuid: t.Optional[str] = None
+# identifier: str = '<dummy>'
+# items: Map[CategoryKeyType, CategorizedType] = Map()
+# _initialized: bool = False
+
+# def _update(
+# self: 'SelfType',
+# key: CategoryKeyType,
+# updater: CategorizedUpdater
+# ) -> 'SelfType':
+# """...... Perform sanity checks for uuid."""
+# uuid = self.uuid
+
+# items = self.items.mutate()
+
+# updated = updater(items.get(key))
+# if updated is None:
+# items.pop(key, None)
+
+# identifier = self.identifier
+# else:
+# items[key] = updated
+
+# identifier = updated.identifier
+# if self._initialized:
+# assert identifier == self.identifier
+
+# if uuid is not None:
+# if updated.uuid is not None and uuid != updated.uuid:
+# raise HaketiloException(_('uuid_mismatch_{identifier}')
+# .format(identifier=identifier))
+# else:
+# uuid = updated.uuid
+
+# return dc.replace(
+# self,
+# identifier = identifier,
+# uuid = uuid,
+# items = items.finish(),
+# _initialized = self._initialized or updated is not None
+# )
+
+# def is_empty(self) -> bool:
+# """...."""
+# return len(self.items) == 0
+
+
+# VersionedType = t.TypeVar('VersionedType', ResourceInfo, MappingInfo)
+
+# class VersionedItemInfo(
+# CategorizedItemInfo[VersionedType, VerTuple],
+# t.Generic[VersionedType]
+# ):
+# """Stores data of multiple versions of given resource/mapping."""
+# SelfType = t.TypeVar('SelfType', bound='VersionedItemInfo[VersionedType]')
+
+# def register(self: 'SelfType', item_info: VersionedType) -> 'SelfType':
+# """
+# Make item info queryable by version. Perform sanity checks for uuid.
+# """
+# return self._update(item_info.version, lambda old_info: item_info)
+
+# def unregister(self: 'SelfType', version: VerTuple) -> 'SelfType':
+# """...."""
+# return self._update(version, lambda old_info: None)
+
+# def newest_version(self) -> VerTuple:
+# """...."""
+# assert not self.is_empty()
+
+# return max(self.items.keys())
+
+# def get_newest(self) -> VersionedType:
+# """Find and return info of the newest version of item."""
+# newest = self.items[self.newest_version()]
+# assert newest is not None
+# return newest
+
+# def get_by_ver(self, ver: t.Iterable[int]) -> t.Optional[VersionedType]:
+# """
+# Find and return info of the specified version of the item (or None if
+# absent).
+# """
+# return self.items.get(tuple(ver))
+
+# def get_all(self) -> t.Iterator[VersionedType]:
+# """Generate item info for all its versions, from oldest to newest."""
+# for version in sorted(self.items.keys()):
+# yield self.items[version]
+
+
+# MultiRepoType = t.TypeVar('MultiRepoType', ResourceInfo, MappingInfo)
+# MultiRepoVersioned = VersionedItemInfo[MultiRepoType]
+
+# class MultiRepoItemInfo(
+# CategorizedItemInfo[MultiRepoVersioned, str],
+# t.Generic[MultiRepoType]
+# ):
+# SelfType = t.TypeVar('SelfType', bound='MultiRepoItemInfo[MultiRepoType]')
+
+# def register(self: 'SelfType', item_info: MultiRepoType) -> 'SelfType':
+# """
+# Make item info queryable by version. Perform sanity checks for uuid.
+# """
+# def updater(old_item: t.Optional[MultiRepoVersioned]) \
+# -> MultiRepoVersioned:
+# """...."""
+# if old_item is None:
+# old_item = VersionedItemInfo()
+
+# return old_item.register(item_info)
+
+# return self._update(item_info.repository, updater)
+
+# def unregister(self: 'SelfType', version: VerTuple, repository: str) \
+# -> 'SelfType':
+# """...."""
+# def updater(old_item: t.Optional[MultiRepoVersioned]) -> \
+# t.Optional[MultiRepoVersioned]:
+# """...."""
+# if old_item is None:
+# return None
+
+# new_item = old_item.unregister(version)
+
+# return None if new_item.is_empty() else new_item
+
+# return self._update(repository, updater)
diff --git a/src/hydrilla/json_instances.py b/src/hydrilla/json_instances.py
index 40b213b..33a3785 100644
--- a/src/hydrilla/json_instances.py
+++ b/src/hydrilla/json_instances.py
@@ -34,6 +34,7 @@ from __future__ import annotations
import re
import json
import os
+import io
import typing as t
from pathlib import Path, PurePath
@@ -158,15 +159,22 @@ def parse_instance(text: str) -> object:
"""Parse 'text' as JSON with additional '//' comments support."""
return json.loads(strip_json_comments(text))
-InstanceOrPath = t.Union[Path, str, dict[str, t.Any]]
+InstanceOrPathOrIO = t.Union[Path, str, io.TextIOBase, dict[str, t.Any]]
-def read_instance(instance_or_path: InstanceOrPath) -> object:
+def read_instance(instance_or_path: InstanceOrPathOrIO) -> object:
"""...."""
if isinstance(instance_or_path, dict):
return instance_or_path
- with open(instance_or_path, 'rt') as handle:
+ if isinstance(instance_or_path, io.TextIOBase):
+ handle = instance_or_path
+ else:
+ handle = t.cast(io.TextIOBase, open(instance_or_path, 'rt'))
+
+ try:
text = handle.read()
+ finally:
+ handle.close()
try:
return parse_instance(text)
diff --git a/src/hydrilla/mitmproxy_launcher/__init__.py b/src/hydrilla/mitmproxy_launcher/__init__.py
new file mode 100644
index 0000000..d382ead
--- /dev/null
+++ b/src/hydrilla/mitmproxy_launcher/__init__.py
@@ -0,0 +1,5 @@
+# SPDX-License-Identifier: CC0-1.0
+
+# Copyright (C) 2022 Wojtek Kosior <koszko@koszko.org>
+#
+# Available under the terms of Creative Commons Zero v1.0 Universal.
diff --git a/src/hydrilla/mitmproxy_launcher/launch.py b/src/hydrilla/mitmproxy_launcher/launch.py
index c826598..765a9ce 100644
--- a/src/hydrilla/mitmproxy_launcher/launch.py
+++ b/src/hydrilla/mitmproxy_launcher/launch.py
@@ -42,6 +42,12 @@ import click
from .. import _version
from ..translations import smart_gettext as _
+addon_script_text = '''
+from hydrilla.proxy.addon import HaketiloAddon
+
+addons = [HaketiloAddon()]
+'''
+
@click.command(help=_('cli_help.haketilo'))
@click.option('-p', '--port', default=8080, type=click.IntRange(0, 65535),
help=_('cli_opt.haketilo.port'))
@@ -61,17 +67,13 @@ def launch(port: int, directory: str):
script_path = directory_path / 'addon.py'
- script_path.write_text('''
-from hydrilla.mitmproxy_addon.addon import Haketilo
-
-addons = [Haketilo()]
-''')
+ script_path.write_text(addon_script_text)
code = sp.call(['mitmdump',
'-p', str(port),
- '--set', f'confdir={directory_path / "mitmproxy"}'
+ '--set', f'confdir={directory_path / "mitmproxy"}',
'--set', 'upstream_cert=false',
- '--set', f'haketilo_dir={directory_path}'
+ '--set', f'haketilo_dir={directory_path}',
'--scripts', str(script_path)])
sys.exit(code)
diff --git a/src/hydrilla/pattern_tree.py b/src/hydrilla/pattern_tree.py
index 1128a06..99f45a5 100644
--- a/src/hydrilla/pattern_tree.py
+++ b/src/hydrilla/pattern_tree.py
@@ -31,44 +31,37 @@ This module defines data structures for querying data using URL patterns.
# Enable using with Python 3.7.
from __future__ import annotations
-import sys
import typing as t
import dataclasses as dc
from immutables import Map
-from .url_patterns import ParsedUrl, parse_url
+from .url_patterns import ParsedPattern, ParsedUrl, parse_url#, catchall_pattern
from .translations import smart_gettext as _
WrapperStoredType = t.TypeVar('WrapperStoredType', bound=t.Hashable)
-@dc.dataclass(frozen=True, unsafe_hash=True)
+@dc.dataclass(frozen=True, unsafe_hash=True, order=True)
class StoredTreeItem(t.Generic[WrapperStoredType]):
"""
In the Pattern Tree, each item is stored together with the pattern used to
register it.
"""
- pattern: ParsedUrl
item: WrapperStoredType
+ pattern: ParsedPattern
-# if sys.version_info >= (3, 8):
-# CopyableType = t.TypeVar('CopyableType', bound='Copyable')
-
-# class Copyable(t.Protocol):
-# """Certain classes in Pattern Tree depend on this interface."""
-# def copy(self: CopyableType) -> CopyableType:
-# """Make a distinct instance with the same properties as this one."""
-# ...
-# else:
-# Copyable = t.Any
NodeStoredType = t.TypeVar('NodeStoredType')
@dc.dataclass(frozen=True)
class PatternTreeNode(t.Generic[NodeStoredType]):
"""...."""
- children: 'NodeChildrenType' = Map()
+ SelfType = t.TypeVar('SelfType', bound='PatternTreeNode[NodeStoredType]')
+
+ ChildrenType = Map[str, SelfType]
+
+ children: 'ChildrenType' = Map()
literal_match: t.Optional[NodeStoredType] = None
def is_empty(self) -> bool:
@@ -76,17 +69,17 @@ class PatternTreeNode(t.Generic[NodeStoredType]):
return len(self.children) == 0 and self.literal_match is None
def update_literal_match(
- self,
+ self: 'SelfType',
new_match_item: t.Optional[NodeStoredType]
- ) -> 'NodeSelfType':
+ ) -> 'SelfType':
"""...."""
return dc.replace(self, literal_match=new_match_item)
- def get_child(self, child_key: str) -> t.Optional['NodeSelfType']:
+ def get_child(self: 'SelfType', child_key: str) -> t.Optional['SelfType']:
"""...."""
return self.children.get(child_key)
- def remove_child(self, child_key: str) -> 'NodeSelfType':
+ def remove_child(self: 'SelfType', child_key: str) -> 'SelfType':
"""...."""
try:
children = self.children.delete(child_key)
@@ -95,19 +88,15 @@ class PatternTreeNode(t.Generic[NodeStoredType]):
return dc.replace(self, children=children)
- def set_child(self, child_key: str, child: 'NodeSelfType') \
- -> 'NodeSelfType':
+ def set_child(self: 'SelfType', child_key: str, child: 'SelfType') \
+ -> 'SelfType':
"""...."""
return dc.replace(self, children=self.children.set(child_key, child))
-# Below we define 2 types used by recursively-typed PatternTreeNode.
-NodeSelfType = PatternTreeNode[NodeStoredType]
-NodeChildrenType = Map[str, NodeSelfType]
-
BranchStoredType = t.TypeVar('BranchStoredType')
-ItemUpdater = t.Callable[
+BranchItemUpdater = t.Callable[
[t.Optional[BranchStoredType]],
t.Optional[BranchStoredType]
]
@@ -115,18 +104,22 @@ ItemUpdater = t.Callable[
@dc.dataclass(frozen=True)
class PatternTreeBranch(t.Generic[BranchStoredType]):
"""...."""
+ SelfType = t.TypeVar(
+ 'SelfType',
+ bound = 'PatternTreeBranch[BranchStoredType]'
+ )
+
root_node: PatternTreeNode[BranchStoredType] = PatternTreeNode()
def is_empty(self) -> bool:
"""...."""
return self.root_node.is_empty()
- # def copy(self) -> 'BranchSelfType':
- # """...."""
- # return dc.replace(self)
-
- def update(self, segments: t.Iterable[str], item_updater: ItemUpdater) \
- -> 'BranchSelfType':
+ def update(
+ self: 'SelfType',
+ segments: t.Iterable[str],
+ item_updater: BranchItemUpdater
+ ) -> 'SelfType':
"""
.......
"""
@@ -188,9 +181,6 @@ class PatternTreeBranch(t.Generic[BranchStoredType]):
if condition():
yield match_node.literal_match
-# Below we define 1 type used by recursively-typed PatternTreeBranch.
-BranchSelfType = PatternTreeBranch[BranchStoredType]
-
FilterStoredType = t.TypeVar('FilterStoredType', bound=t.Hashable)
FilterWrappedType = StoredTreeItem[FilterStoredType]
@@ -218,19 +208,21 @@ class PatternTree(t.Generic[TreeStoredType]):
is to make it possible to quickly retrieve all known patterns that match
a given URL.
"""
- _by_scheme_and_port: TreeRoot = Map()
+ SelfType = t.TypeVar('SelfType', bound='PatternTree[TreeStoredType]')
+
+ _by_scheme_and_port: TreeRoot = Map()
def _register(
- self,
- parsed_pattern: ParsedUrl,
+ self: 'SelfType',
+ parsed_pattern: ParsedPattern,
item: TreeStoredType,
register: bool = True
- ) -> 'TreeSelfType':
+ ) -> 'SelfType':
"""
Make an item wrapped in StoredTreeItem object queryable through the
Pattern Tree by the given parsed URL pattern.
"""
- wrapped_item = StoredTreeItem(parsed_pattern, item)
+ wrapped_item = StoredTreeItem(item, parsed_pattern)
def item_updater(item_set: t.Optional[StoredSet]) \
-> t.Optional[StoredSet]:
@@ -276,36 +268,21 @@ class PatternTree(t.Generic[TreeStoredType]):
return dc.replace(self, _by_scheme_and_port=new_root)
- # def _register(
- # self,
- # url_pattern: str,
- # item: TreeStoredType,
- # register: bool = True
- # ) -> 'TreeSelfType':
- # """
- # ....
- # """
- # tree = self
-
- # for parsed_pat in parse_pattern(url_pattern):
- # wrapped_item = StoredTreeItem(parsed_pat, item)
- # tree = tree._register_with_parsed_pattern(
- # parsed_pat,
- # wrapped_item,
- # register
- # )
-
- # return tree
-
- def register(self, parsed_pattern: ParsedUrl, item: TreeStoredType) \
- -> 'TreeSelfType':
+ def register(
+ self: 'SelfType',
+ parsed_pattern: ParsedPattern,
+ item: TreeStoredType
+ ) -> 'SelfType':
"""
Make item queryable through the Pattern Tree by the given URL pattern.
"""
return self._register(parsed_pattern, item)
- def deregister(self, parsed_pattern: ParsedUrl, item: TreeStoredType) \
- -> 'TreeSelfType':
+ def deregister(
+ self: 'SelfType',
+ parsed_pattern: ParsedPattern,
+ item: TreeStoredType
+ ) -> 'SelfType':
"""
Make item no longer queryable through the Pattern Tree by the given URL
pattern.
@@ -334,6 +311,3 @@ class PatternTree(t.Generic[TreeStoredType]):
items = filter_by_trailing_slash(item_set, with_slash)
if len(items) > 0:
yield items
-
-# Below we define 1 type used by recursively-typed PatternTree.
-TreeSelfType = PatternTree[TreeStoredType]
diff --git a/src/hydrilla/proxy/addon.py b/src/hydrilla/proxy/addon.py
index 7d6487b..16c2841 100644
--- a/src/hydrilla/proxy/addon.py
+++ b/src/hydrilla/proxy/addon.py
@@ -32,58 +32,70 @@ from addon script.
# Enable using with Python 3.7.
from __future__ import annotations
-import os.path
+import sys
import typing as t
import dataclasses as dc
+import traceback as tb
from threading import Lock
from pathlib import Path
from contextlib import contextmanager
-from mitmproxy import http, addonmanager, ctx
+# for mitmproxy 6.*
+from mitmproxy.net import http
+if not hasattr(http, 'Headers'):
+ # for mitmproxy 8.*
+ from mitmproxy import http # type: ignore
+
+from mitmproxy.http import HTTPFlow
+
+from mitmproxy import addonmanager, ctx
from mitmproxy.script import concurrent
-from .flow_handlers import make_flow_handler, FlowHandler
-from .state import HaketiloState
+from ..exceptions import HaketiloException
from ..translations import smart_gettext as _
+from ..url_patterns import parse_url
+from .state_impl import ConcreteHaketiloState
+from . import policies
-FlowHandlers = dict[int, FlowHandler]
-StateUpdater = t.Callable[[HaketiloState], None]
+DefaultGetValue = t.TypeVar('DefaultGetValue', object, None)
-HTTPHandlerFun = t.Callable[
- ['HaketiloAddon', http.HTTPFlow],
- t.Optional[StateUpdater]
-]
-
-def http_event_handler(handler_fun: HTTPHandlerFun):
- """....decorator"""
- def wrapped_handler(self: 'HaketiloAddon', flow: http.HTTPFlow):
+class MitmproxyHeadersWrapper():
+ """...."""
+ def __init__(self, headers: http.Headers) -> None:
"""...."""
- with self.configured_lock:
- assert self.configured
+ self.headers = headers
- assert self.state is not None
+ __getitem__ = lambda self, key: self.headers[key]
+ get_all = lambda self, key: self.headers.get_all(key)
- state_updater = handler_fun(self, flow)
+ def get(self, key: str, default: DefaultGetValue = None) \
+ -> t.Union[str, DefaultGetValue]:
+ """...."""
+ value = self.headers.get(key)
- if state_updater is not None:
- state_updater(self.state)
+ if value is None:
+ return default
+ else:
+ return t.cast(str, value)
- return wrapped_handler
+ def items(self) -> t.Iterable[tuple[str, str]]:
+ """...."""
+ return self.headers.items(multi=True)
@dc.dataclass
class HaketiloAddon:
"""
.......
"""
- configured: bool = False
- configured_lock: Lock = dc.field(default_factory=Lock)
+ configured: bool = False
+ configured_lock: Lock = dc.field(default_factory=Lock)
- state: t.Optional[HaketiloState] = None
+ flow_policies: dict[int, policies.Policy] = dc.field(default_factory=dict)
+ policies_lock: Lock = dc.field(default_factory=Lock)
- flow_handlers: FlowHandlers = dc.field(default_factory=dict)
- handlers_lock: Lock = dc.field(default_factory=Lock)
+ state: t.Optional[ConcreteHaketiloState] = None
def load(self, loader: addonmanager.Loader) -> None:
"""...."""
@@ -104,74 +116,165 @@ class HaketiloAddon:
ctx.log.warn(_('haketilo_dir_already_configured'))
return
- haketilo_dir = Path(ctx.options.haketilo_dir)
- self.state = HaketiloState(haketilo_dir / 'store')
+ try:
+ haketilo_dir = Path(ctx.options.haketilo_dir)
+
+ self.state = ConcreteHaketiloState.make(haketilo_dir / 'store')
+ except Exception as e:
+ tb.print_exception(None, e, e.__traceback__)
+ sys.exit(1)
- def assign_handler(self, flow: http.HTTPFlow, flow_handler: FlowHandler) \
- -> None:
+ self.configured = True
+
+ def try_get_policy(self, flow: HTTPFlow, fail_ok: bool = True) -> \
+ t.Optional[policies.Policy]:
"""...."""
- with self.handlers_lock:
- self.flow_handlers[id(flow)] = flow_handler
+ with self.policies_lock:
+ policy = self.flow_policies.get(id(flow))
+
+ if policy is None:
+ try:
+ parsed_url = parse_url(flow.request.url)
+ except HaketiloException:
+ if fail_ok:
+ return None
+ else:
+ raise
+
+ assert self.state is not None
+
+ policy = self.state.select_policy(parsed_url)
+
+ with self.policies_lock:
+ self.flow_policies[id(flow)] = policy
+
+ return policy
+
+ def get_policy(self, flow: HTTPFlow) -> policies.Policy:
+ return t.cast(policies.Policy, self.try_get_policy(flow, fail_ok=False))
- def lookup_handler(self, flow: http.HTTPFlow) -> FlowHandler:
+ def forget_policy(self, flow: HTTPFlow) -> None:
"""...."""
- with self.handlers_lock:
- return self.flow_handlers[id(flow)]
+ with self.policies_lock:
+ self.flow_policies.pop(id(flow), None)
- def forget_handler(self, flow: http.HTTPFlow) -> None:
+ @contextmanager
+ def http_safe_event_handling(self, flow: HTTPFlow) -> t.Iterator:
"""...."""
- with self.handlers_lock:
- self.flow_handlers.pop(id(flow), None)
+ with self.configured_lock:
+ assert self.configured
+
+ try:
+ yield
+ except Exception as e:
+ tb_string = ''.join(tb.format_exception(None, e, e.__traceback__))
+ error_text = _('err.proxy.unknown_error_{}_try_again')\
+ .format(tb_string)\
+ .encode()
+ flow.response = http.Response.make(
+ status_code = 500,
+ content = error_text,
+ headers = [(b'Content-Type', b'text/plain; charset=utf-8')]
+ )
+
+ self.forget_policy(flow)
@concurrent
- @http_event_handler
- def requestheaders(self, flow: http.HTTPFlow) -> t.Optional[StateUpdater]:
+ def requestheaders(self, flow: HTTPFlow) -> None:
+ # TODO: don't account for mitmproxy 6 in the code
+ # Mitmproxy 6 causes even more strange behavior than described below.
+ # This cannot be easily worked around. Let's just use version 8 and
+ # make an APT package for it.
"""
- .....
+ Under mitmproxy 8 this handler deduces an appropriate policy for flow's
+ URL and assigns it to the flow. Under mitmproxy 6 the URL is not yet
+ available at this point, so the handler effectively does nothing.
"""
- assert self.state is not None
-
- policy = self.state.select_policy(flow.request.url)
+ with self.http_safe_event_handling(flow):
+ policy = self.try_get_policy(flow)
- flow_handler = make_flow_handler(flow, policy)
-
- self.assign_handler(flow, flow_handler)
-
- return flow_handler.on_requestheaders()
+ if policy is not None:
+ if not policy.process_request:
+ flow.request.stream = True
+ if policy.anticache:
+ flow.request.anticache()
@concurrent
- @http_event_handler
- def request(self, flow: http.HTTPFlow) -> t.Optional[StateUpdater]:
+ def request(self, flow: HTTPFlow) -> None:
"""
....
"""
- return self.lookup_handler(flow).on_request()
+ if flow.request.stream:
+ return
- @concurrent
- @http_event_handler
- def responseheaders(self, flow: http.HTTPFlow) -> t.Optional[StateUpdater]:
+ with self.http_safe_event_handling(flow):
+ policy = self.get_policy(flow)
+
+ request_info = policies.RequestInfo(
+ url = parse_url(flow.request.url),
+ method = flow.request.method,
+ headers = MitmproxyHeadersWrapper(flow.request.headers),
+ body = flow.request.get_content(strict=False) or b''
+ )
+
+ result = policy.consume_request(request_info)
+
+ if result is not None:
+ if isinstance(result, policies.ProducedRequest):
+ flow.request = http.Request.make(
+ url = result.url,
+ method = result.method,
+ headers = http.Headers(result.headers),
+ content = result.body
+ )
+ else:
+ # isinstance(result, policies.ProducedResponse)
+ flow.response = http.Response.make(
+ status_code = result.status_code,
+ headers = http.Headers(result.headers),
+ content = result.body
+ )
+
+ def responseheaders(self, flow: HTTPFlow) -> None:
"""
......
"""
- return self.lookup_handler(flow).on_responseheaders()
+ assert flow.response is not None
+
+ with self.http_safe_event_handling(flow):
+ policy = self.get_policy(flow)
+
+ if not policy.process_response:
+ flow.response.stream = True
@concurrent
- @http_event_handler
- def response(self, flow: http.HTTPFlow) -> t.Optional[StateUpdater]:
+ def response(self, flow: HTTPFlow) -> None:
"""
......
"""
- updater = self.lookup_handler(flow).on_response()
+ assert flow.response is not None
- self.forget_handler(flow)
+ if flow.response.stream:
+ return
- return updater
+ with self.http_safe_event_handling(flow):
+ policy = self.get_policy(flow)
- @http_event_handler
- def error(self, flow: http.HTTPFlow) -> None:
- """...."""
- self.forget_handler(flow)
+ response_info = policies.ResponseInfo(
+ url = parse_url(flow.request.url),
+ status_code = flow.response.status_code,
+ headers = MitmproxyHeadersWrapper(flow.response.headers),
+ body = flow.response.get_content(strict=False) or b''
+ )
-addons = [
- HaketiloAddon()
-]
+ result = policy.consume_response(response_info)
+ if result is not None:
+ flow.response.status_code = result.status_code
+ flow.response.headers = http.Headers(result.headers)
+ flow.response.set_content(result.body)
+
+ self.forget_policy(flow)
+
+ def error(self, flow: HTTPFlow) -> None:
+ """...."""
+ self.forget_policy(flow)
diff --git a/src/hydrilla/proxy/csp.py b/src/hydrilla/proxy/csp.py
new file mode 100644
index 0000000..59d93f2
--- /dev/null
+++ b/src/hydrilla/proxy/csp.py
@@ -0,0 +1,125 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Tools for working with Content Security Policy headers.
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+.....
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import re
+import typing as t
+import dataclasses as dc
+
+from immutables import Map, MapMutation
+
+from .policies.base import IHeaders
+
+
+header_names_and_dispositions = (
+ ('content-security-policy', 'enforce'),
+ ('content-security-policy-report-only', 'report'),
+ ('x-content-security-policy', 'enforce'),
+ ('x-content-security-policy', 'report'),
+ ('x-webkit-csp', 'enforce'),
+ ('x-webkit-csp', 'report')
+)
+
+enforce_header_names_set = {
+ name for name, disposition in header_names_and_dispositions
+ if disposition == 'enforce'
+}
+
+@dc.dataclass
+class ContentSecurityPolicy:
+ directives: Map[str, t.Sequence[str]]
+ header_name: str
+ disposition: str
+
+ def serialize(self) -> str:
+ """...."""
+ serialized_directives = []
+ for name, value_list in self.directives.items():
+ serialized_directives.append(f'{name} {" ".join(value_list)}')
+
+ return ';'.join(serialized_directives)
+
+ @staticmethod
+ def deserialize(
+ serialized: str,
+ header_name: str,
+ disposition: str = 'enforce'
+ ) -> 'ContentSecurityPolicy':
+ """...."""
+ # For more info, see:
+ # https://www.w3.org/TR/CSP3/#parse-serialized-policy
+ empty_directives: Map[str, t.Sequence[str]] = Map()
+
+ directives = empty_directives.mutate()
+
+ for serialized_directive in serialized.split(';'):
+ if not serialized_directive.isascii():
+ continue
+
+ serialized_directive = serialized_directive.strip()
+ if len(serialized_directive) == 0:
+ continue
+
+ tokens = serialized_directive.split()
+ directive_name = tokens.pop(0).lower()
+ directive_value = tokens
+
+ # Specs mention giving warnings for duplicate directive names but
+ # from our proxy's perspective this is not important right now.
+ if directive_name in directives:
+ continue
+
+ directives[directive_name] = directive_value
+
+ return ContentSecurityPolicy(
+ directives = directives.finish(),
+ header_name = header_name,
+ disposition = disposition
+ )
+
+def extract(headers: IHeaders) -> tuple[ContentSecurityPolicy, ...]:
+ """...."""
+ csp_policies = []
+
+ for header_name, disposition in header_names_and_dispositions:
+ for serialized_list in headers.get_all(header_name):
+ for serialized in serialized_list.split(','):
+ policy = ContentSecurityPolicy.deserialize(
+ serialized,
+ header_name,
+ disposition
+ )
+
+ if policy.directives != Map():
+ csp_policies.append(policy)
+
+ return tuple(csp_policies)
diff --git a/src/hydrilla/proxy/flow_handlers.py b/src/hydrilla/proxy/flow_handlers.py
deleted file mode 100644
index 605c7f9..0000000
--- a/src/hydrilla/proxy/flow_handlers.py
+++ /dev/null
@@ -1,383 +0,0 @@
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-# Logic for modifying mitmproxy's HTTP flows.
-#
-# This file is part of Hydrilla&Haketilo.
-#
-# Copyright (C) 2022 Wojtek Kosior
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-#
-#
-# I, Wojtek Kosior, thereby promise not to sue for violation of this
-# file's license. Although I request that you do not make use this code
-# in a proprietary program, I am not going to enforce this in court.
-
-"""
-This module's file gets passed to Mitmproxy as addon script and makes it serve
-as Haketilo proxy.
-"""
-
-# Enable using with Python 3.7.
-from __future__ import annotations
-
-import re
-import typing as t
-import dataclasses as dc
-
-import bs4 # type: ignore
-
-from mitmproxy import http
-from mitmproxy.net.http import Headers
-from mitmproxy.script import concurrent
-
-from .state import HaketiloState
-from . import policies
-
-StateUpdater = t.Callable[[HaketiloState], None]
-
-@dc.dataclass(frozen=True)
-class FlowHandler:
- """...."""
- flow: http.HTTPFlow
- policy: policies.Policy
-
- stream_request: bool = False
- stream_response: bool = False
-
- def on_requestheaders(self) -> t.Optional[StateUpdater]:
- """...."""
- if self.stream_request:
- self.flow.request.stream = True
-
- return None
-
- def on_request(self) -> t.Optional[StateUpdater]:
- """...."""
- return None
-
- def on_responseheaders(self) -> t.Optional[StateUpdater]:
- """...."""
- assert self.flow.response is not None
-
- if self.stream_response:
- self.flow.response.stream = True
-
- return None
-
- def on_response(self) -> t.Optional[StateUpdater]:
- """...."""
- return None
-
-@dc.dataclass(frozen=True)
-class FlowHandlerAllowScripts(FlowHandler):
- """...."""
- policy: policies.AllowPolicy
-
- stream_request: bool = True
- stream_response: bool = True
-
-csp_header_names_and_dispositions = (
- ('content-security-policy', 'enforce'),
- ('content-security-policy-report-only', 'report'),
- ('x-content-security-policy', 'enforce'),
- ('x-content-security-policy', 'report'),
- ('x-webkit-csp', 'enforce'),
- ('x-webkit-csp', 'report')
-)
-
-csp_enforce_header_names_set = {
- name for name, disposition in csp_header_names_and_dispositions
- if disposition == 'enforce'
-}
-
-@dc.dataclass
-class ContentSecurityPolicy:
- directives: dict[str, list[str]]
- header_name: str
- disposition: str
-
- @staticmethod
- def deserialize(
- serialized: str,
- header_name: str,
- disposition: str = 'enforce'
- ) -> 'ContentSecurityPolicy':
- """...."""
- # For more info, see:
- # https://www.w3.org/TR/CSP3/#parse-serialized-policy
- directives = {}
-
- for serialized_directive in serialized.split(';'):
- if not serialized_directive.isascii():
- continue
-
- serialized_directive = serialized_directive.strip()
- if len(serialized_directive) == 0:
- continue
-
- tokens = serialized_directive.split()
- directive_name = tokens.pop(0).lower()
- directive_value = tokens
-
- # Specs mention giving warnings for duplicate directive names but
- # from our proxy's perspective this is not important right now.
- if directive_name in directives:
- continue
-
- directives[directive_name] = directive_value
-
- return ContentSecurityPolicy(directives, header_name, disposition)
-
- def serialize(self) -> str:
- """...."""
- serialized_directives = []
- for name, value_list in self.directives.items():
- serialized_directives.append(f'{name} {" ".join(value_list)}')
-
- return ';'.join(serialized_directives)
-
-def extract_csp(headers: Headers) -> tuple[ContentSecurityPolicy, ...]:
- """...."""
- csp_policies = []
-
- for header_name, disposition in csp_header_names_and_dispositions:
- for serialized_list in headers.get(header_name, ''):
- for serialized in serialized_list.split(','):
- policy = ContentSecurityPolicy.deserialize(
- serialized,
- header_name,
- disposition
- )
-
- if policy.directives != {}:
- csp_policies.append(policy)
-
- return tuple(csp_policies)
-
-csp_script_directive_names = (
- 'script-src',
- 'script-src-elem',
- 'script-src-attr'
-)
-
-@dc.dataclass(frozen=True)
-class FlowHandlerBlockScripts(FlowHandler):
- policy: policies.BlockPolicy
-
- stream_request: bool = True
- stream_response: bool = True
-
- def on_responseheaders(self) -> t.Optional[StateUpdater]:
- """...."""
- super().on_responseheaders()
-
- assert self.flow.response is not None
-
- csp_policies = extract_csp(self.flow.response.headers)
-
- for header_name, _ in csp_header_names_and_dispositions:
- del self.flow.response.headers[header_name]
-
- for policy in csp_policies:
- if policy.disposition != 'enforce':
- continue
-
- policy.directives.pop('report-to')
- policy.directives.pop('report-uri')
-
- self.flow.response.headers.add(
- policy.header_name,
- policy.serialize()
- )
-
- extra_csp = ';'.join((
- "script-src 'none'",
- "script-src-elem 'none'",
- "script-src-attr 'none'"
- ))
-
- self.flow.response.headers.add('Content-Security-Policy', extra_csp)
-
- return None
-
-# For details of 'Content-Type' header's structure, see:
-# https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.1
-content_type_reg = re.compile(r'''
-^
-(?P<mime>[\w-]+/[\w-]+)
-\s*
-(?:
- ;
- (?:[^;]*;)* # match possible parameter other than "charset"
-)
-\s*
-charset= # no whitespace allowed in parameter as per RFC
-(?P<encoding>
- [\w-]+
- |
- "[\w-]+" # quotes are optional per RFC
-)
-(?:;[^;]+)* # match possible parameter other than "charset"
-$ # forbid possible dangling characters after closing '"'
-''', re.VERBOSE | re.IGNORECASE)
-
-def deduce_content_type(headers: Headers) \
- -> tuple[t.Optional[str], t.Optional[str]]:
- """...."""
- content_type = headers.get('content-type')
- if content_type is None:
- return (None, None)
-
- match = content_type_reg.match(content_type)
- if match is None:
- return (None, None)
-
- mime, encoding = match.group('mime'), match.group('encoding')
-
- if encoding is not None:
- encoding = encoding.lower()
-
- return mime, encoding
-
-UTF8_BOM = b'\xEF\xBB\xBF'
-BOMs = (
- (UTF8_BOM, 'utf-8'),
- (b'\xFE\xFF', 'utf-16be'),
- (b'\xFF\xFE', 'utf-16le')
-)
-
-def block_attr(element: bs4.PageElement, atrr_name: str) -> None:
- """...."""
- # TODO: implement
- pass
-
-@dc.dataclass(frozen=True)
-class FlowHandlerInjectPayload(FlowHandler):
- """...."""
- policy: policies.PayloadPolicy
-
- stream_request: bool = True
-
- def __post_init__(self) -> None:
- """...."""
- script_src = f"script-src {self.policy.assets_base_url()}"
- if self.policy.is_eval_allowed():
- script_src = f"{script_src} 'unsafe-eval'"
-
- self.new_csp = '; '.join((
- script_src,
- "script-src-elem 'none'",
- "script-src-attr 'none'"
- ))
-
- def on_responseheaders(self) -> t.Optional[StateUpdater]:
- """...."""
- super().on_responseheaders()
-
- assert self.flow.response is not None
-
- for header_name, _ in csp_header_names_and_dispositions:
- del self.flow.response.headers[header_name]
-
- self.flow.response.headers.add('Content-Security-Policy', self.new_csp)
-
- return None
-
- def on_response(self) -> t.Optional[StateUpdater]:
- """...."""
- super().on_response()
-
- assert self.flow.response is not None
-
- if self.flow.response.content is None:
- return None
-
- mime, encoding = deduce_content_type(self.flow.response.headers)
- if mime is None or 'html' not in mime:
- return None
-
- # A UTF BOM overrides encoding specified by the header.
- for bom, encoding_name in BOMs:
- if self.flow.response.content.startswith(bom):
- encoding = encoding_name
-
- soup = bs4.BeautifulSoup(
- markup = self.flow.response.content,
- from_encoding = encoding,
- features = 'html5lib'
- )
-
- # Inject scripts.
- script_parent = soup.find('body') or soup.find('html')
- if script_parent is None:
- return None
-
- for url in self.policy.script_urls():
- script_parent.append(bs4.Tag(name='script', attrs={'src': url}))
-
- # Remove Content Security Policy that could possibly block injected
- # scripts.
- for meta in soup.select('head meta[http-equiv]'):
- header_name = meta.attrs.get('http-equiv', '').lower().strip()
- if header_name in csp_enforce_header_names_set:
- block_attr(meta, 'http-equiv')
- block_attr(meta, 'content')
-
- # Appending a three-byte Byte Order Mark (BOM) will force the browser to
- # decode this as UTF-8 regardless of the 'Content-Type' header. See:
- # https://www.w3.org/International/tests/repository/html5/the-input-byte-stream/results-basics#precedence
- self.flow.response.content = UTF8_BOM + soup.encode()
-
- return None
-
-@dc.dataclass(frozen=True)
-class FlowHandlerMetaResource(FlowHandler):
- """...."""
- policy: policies.MetaResourcePolicy
-
- def on_request(self) -> t.Optional[StateUpdater]:
- """...."""
- super().on_request()
- # TODO: implement
- #self.flow.response = ....
-
- return None
-
-def make_flow_handler(flow: http.HTTPFlow, policy: policies.Policy) \
- -> FlowHandler:
- """...."""
- if isinstance(policy, policies.BlockPolicy):
- return FlowHandlerBlockScripts(flow, policy)
-
- if isinstance(policy, policies.AllowPolicy):
- return FlowHandlerAllowScripts(flow, policy)
-
- if isinstance(policy, policies.PayloadPolicy):
- return FlowHandlerInjectPayload(flow, policy)
-
- assert isinstance(policy, policies.MetaResourcePolicy)
- # def response_creator(request: http.HTTPRequest) -> http.HTTPResponse:
- # """...."""
- # replacement_details = make_replacement_resource(
- # policy.replacement,
- # request.path
- # )
-
- # return http.HTTPResponse.make(
- # replacement_details.status_code,
- # replacement_details.content,
- # replacement_details.content_type
- # )
- return FlowHandlerMetaResource(flow, policy)
diff --git a/src/hydrilla/proxy/oversimplified_state_impl.py b/src/hydrilla/proxy/oversimplified_state_impl.py
new file mode 100644
index 0000000..c082add
--- /dev/null
+++ b/src/hydrilla/proxy/oversimplified_state_impl.py
@@ -0,0 +1,392 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy data and configuration (instantiatable HaketiloState subtype).
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+This module contains logic for keeping track of all settings, rules, mappings
+and resources.
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import secrets
+import threading
+import typing as t
+import dataclasses as dc
+
+from pathlib import Path
+
+from ..pattern_tree import PatternTree
+from .. import url_patterns
+from .. import versions
+from .. import item_infos
+from .simple_dependency_satisfying import ItemsCollection, ComputedPayload
+from . import state as st
+from . import policies
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class ConcreteRepoRef(st.RepoRef):
+ def remove(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+ def update(
+ self,
+ state: st.HaketiloState,
+ *,
+ name: t.Optional[str] = None,
+ url: t.Optional[str] = None
+ ) -> ConcreteRepoRef:
+ raise NotImplementedError()
+
+ def refresh(self, state: st.HaketiloState) -> ConcreteRepoIterationRef:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteRepoIterationRef(st.RepoIterationRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteMappingRef(st.MappingRef):
+ def disable(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+ def forget_enabled(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteMappingVersionRef(st.MappingVersionRef):
+ def enable(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteResourceRef(st.ResourceRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteResourceVersionRef(st.ResourceVersionRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcretePayloadRef(st.PayloadRef):
+ computed_payload: ComputedPayload = dc.field(hash=False, compare=False)
+
+ def get_data(self, state: st.HaketiloState) -> st.PayloadData:
+ return t.cast(ConcreteHaketiloState, state).payloads_data[self.id]
+
+ def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
+ return 'to implement'
+
+ def get_script_paths(self, state: st.HaketiloState) \
+ -> t.Iterator[t.Sequence[str]]:
+ for resource_info in self.computed_payload.resources:
+ for file_spec in resource_info.scripts:
+ yield (resource_info.identifier, *file_spec.name.split('/'))
+
+ def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
+ -> t.Optional[st.FileData]:
+ if len(path) == 0:
+ raise st.MissingItemError()
+
+ resource_identifier, *file_name_segments = path
+
+ file_name = '/'.join(file_name_segments)
+
+ script_sha256 = ''
+
+ matched_resource_info = False
+
+ for resource_info in self.computed_payload.resources:
+ if resource_info.identifier == resource_identifier:
+ matched_resource_info = True
+
+ for script_spec in resource_info.scripts:
+ if script_spec.name == file_name:
+ script_sha256 = script_spec.sha256
+
+ break
+
+ if not matched_resource_info:
+ raise st.MissingItemError(resource_identifier)
+
+ if script_sha256 == '':
+ return None
+
+ store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir
+ files_dir_path = store_dir_path / 'temporary_malcontent' / 'file'
+ file_path = files_dir_path / 'sha256' / script_sha256
+
+ return st.FileData(
+ type = 'application/javascript',
+ name = file_name,
+ contents = file_path.read_bytes()
+ )
+
+
+# @dc.dataclass(frozen=True, unsafe_hash=True)
+# class DummyPayloadRef(ConcretePayloadRef):
+# paths = {
+# ('someresource', 'somefolder', 'somescript.js'): st.FileData(
+# type = 'application/javascript',
+# name = 'somefolder/somescript.js',
+# contents = b'console.log("hello, mitmproxy")'
+# )
+# }
+
+# def get_data(self, state: st.HaketiloState) -> st.PayloadData:
+# parsed_pattern = next(url_patterns.parse_pattern('https://example.com'))
+
+# return st.PayloadData(
+# payload_ref = self,
+# mapping_installed = True,
+# explicitly_enabled = True,
+# unique_token = 'g54v45g456h4r',
+# pattern = parsed_pattern,
+# eval_allowed = True,
+# cors_bypass_allowed = True
+# )
+
+# def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
+# return ConcreteMappingVersionRef('somemapping')
+
+# def get_file_paths(self, state: st.HaketiloState) \
+# -> t.Iterable[t.Sequence[str]]:
+# return tuple(self.paths.keys())
+
+# def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
+# -> t.Optional[st.FileData]:
+# return self.paths[tuple(path)]
+
+
+PolicyTree = PatternTree[policies.PolicyFactory]
+
+def register_payload(
+ policy_tree: PolicyTree,
+ payload_key: st.PayloadKey,
+ token: str
+) -> PolicyTree:
+ """...."""
+ payload_policy_factory = policies.PayloadPolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(
+ payload_key.pattern,
+ payload_policy_factory
+ )
+
+ resource_policy_factory = policies.PayloadResourcePolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(
+ payload_key.pattern.path_append(token, '***'),
+ resource_policy_factory
+ )
+
+ return policy_tree
+
+DataById = t.Mapping[str, st.PayloadData]
+
+AnyInfo = t.TypeVar('AnyInfo', item_infos.ResourceInfo, item_infos.MappingInfo)
+
+@dc.dataclass
+class ConcreteHaketiloState(st.HaketiloState):
+ store_dir: Path
+ # settings: state.HaketiloGlobalSettings
+ policy_tree: PolicyTree = PatternTree()
+ payloads_data: DataById = dc.field(default_factory=dict)
+
+ lock: threading.RLock = dc.field(default_factory=threading.RLock)
+
+ def __post_init__(self) -> None:
+ def newest_item_path(item_dir: Path) -> t.Optional[Path]:
+ available_versions = tuple(
+ versions.parse_normalize_version(ver_path.name)
+ for ver_path in item_dir.iterdir()
+ if ver_path.is_file()
+ )
+
+ if available_versions == ():
+ return None
+
+ newest_version = max(available_versions)
+
+ version_path = item_dir / versions.version_string(newest_version)
+
+ assert version_path.is_file()
+
+ return version_path
+
+ def read_items(dir_path: Path, item_class: t.Type[AnyInfo]) \
+ -> t.Mapping[str, AnyInfo]:
+ items: dict[str, AnyInfo] = {}
+
+ for resource_dir in dir_path.iterdir():
+ if not resource_dir.is_dir():
+ continue
+
+ item_path = newest_item_path(resource_dir)
+ if item_path is None:
+ continue
+
+ item = item_class.load(item_path)
+
+ assert versions.version_string(item.version) == item_path.name
+ assert item.identifier == resource_dir.name
+
+ items[item.identifier] = item
+
+ return items
+
+ malcontent_dir = self.store_dir / 'temporary_malcontent'
+
+ items_collection = ItemsCollection(
+ read_items(malcontent_dir / 'resource', item_infos.ResourceInfo),
+ read_items(malcontent_dir / 'mapping', item_infos.MappingInfo)
+ )
+ computed_payloads = items_collection.compute_payloads()
+
+ payloads_data = {}
+
+ for mapping_info, by_pattern in computed_payloads.items():
+ for num, (pattern, payload) in enumerate(by_pattern.items()):
+ payload_id = f'{num}@{mapping_info.identifier}'
+
+ ref = ConcretePayloadRef(payload_id, payload)
+
+ data = st.PayloadData(
+ payload_ref = ref,
+ mapping_installed = True,
+ explicitly_enabled = True,
+ unique_token = secrets.token_urlsafe(16),
+ pattern = pattern,
+ eval_allowed = payload.allows_eval,
+ cors_bypass_allowed = payload.allows_cors_bypass
+ )
+
+ payloads_data[payload_id] = data
+
+ key = st.PayloadKey(
+ payload_ref = ref,
+ mapping_identifier = mapping_info.identifier,
+ pattern = pattern
+ )
+
+ self.policy_tree = register_payload(
+ self.policy_tree,
+ key,
+ data.unique_token
+ )
+
+ self.payloads_data = payloads_data
+
+ def get_repo(self, repo_id: str) -> st.RepoRef:
+ return ConcreteRepoRef(repo_id)
+
+ def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef:
+ return ConcreteRepoIterationRef(repo_iteration_id)
+
+ def get_mapping(self, mapping_id: str) -> st.MappingRef:
+ return ConcreteMappingRef(mapping_id)
+
+ def get_mapping_version(self, mapping_version_id: str) \
+ -> st.MappingVersionRef:
+ return ConcreteMappingVersionRef(mapping_version_id)
+
+ def get_resource(self, resource_id: str) -> st.ResourceRef:
+ return ConcreteResourceRef(resource_id)
+
+ def get_resource_version(self, resource_version_id: str) \
+ -> st.ResourceVersionRef:
+ return ConcreteResourceVersionRef(resource_version_id)
+
+ def get_payload(self, payload_id: str) -> st.PayloadRef:
+ return 'not implemented'
+
+ def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \
+ -> st.RepoRef:
+ raise NotImplementedError()
+
+ def get_settings(self) -> st.HaketiloGlobalSettings:
+ return st.HaketiloGlobalSettings(
+ mapping_use_mode = st.MappingUseMode.AUTO,
+ default_allow_scripts = True,
+ repo_refresh_seconds = 0
+ )
+
+ def update_settings(
+ self,
+ *,
+ mapping_use_mode: t.Optional[st.MappingUseMode] = None,
+ default_allow_scripts: t.Optional[bool] = None,
+ repo_refresh_seconds: t.Optional[int] = None
+ ) -> None:
+ raise NotImplementedError()
+
+ def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy:
+ """...."""
+ with self.lock:
+ policy_tree = self.policy_tree
+
+ try:
+ best_priority: int = 0
+ best_policy: t.Optional[policies.Policy] = None
+
+ for factories_set in policy_tree.search(url):
+ for stored_factory in sorted(factories_set):
+ factory = stored_factory.item
+
+ policy = factory.make_policy(self)
+
+ if policy.priority > best_priority:
+ best_priority = policy.priority
+ best_policy = policy
+ except Exception as e:
+ return policies.ErrorBlockPolicy(
+ builtin = True,
+ error = e
+ )
+
+ if best_policy is not None:
+ return best_policy
+
+ if self.get_settings().default_allow_scripts:
+ return policies.FallbackAllowPolicy()
+ else:
+ return policies.FallbackBlockPolicy()
+
+ @staticmethod
+ def make(store_dir: Path) -> 'ConcreteHaketiloState':
+ return ConcreteHaketiloState(store_dir=store_dir)
diff --git a/src/hydrilla/proxy/policies/__init__.py b/src/hydrilla/proxy/policies/__init__.py
new file mode 100644
index 0000000..66e07ee
--- /dev/null
+++ b/src/hydrilla/proxy/policies/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-License-Identifier: CC0-1.0
+
+# Copyright (C) 2022 Wojtek Kosior <koszko@koszko.org>
+#
+# Available under the terms of Creative Commons Zero v1.0 Universal.
+
+from .base import *
+
+from .payload import PayloadPolicyFactory
+
+from .payload_resource import PayloadResourcePolicyFactory
+
+from .rule import RuleBlockPolicyFactory, RuleAllowPolicyFactory
+
+from .fallback import FallbackAllowPolicy, FallbackBlockPolicy, ErrorBlockPolicy
diff --git a/src/hydrilla/proxy/policies/base.py b/src/hydrilla/proxy/policies/base.py
new file mode 100644
index 0000000..3bde6f2
--- /dev/null
+++ b/src/hydrilla/proxy/policies/base.py
@@ -0,0 +1,177 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Base defintions for policies for altering HTTP requests.
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+.....
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import sys
+
+if sys.version_info >= (3, 8):
+ from typing import Protocol
+else:
+ from typing_extensions import Protocol
+
+import dataclasses as dc
+import typing as t
+import enum
+
+from abc import ABC, abstractmethod
+
+from immutables import Map
+
+from ...url_patterns import ParsedUrl
+from .. import state
+
+
+class PolicyPriority(int, enum.Enum):
+ """...."""
+ _ONE = 1
+ _TWO = 2
+ _THREE = 3
+
+DefaultGetValue = t.TypeVar('DefaultGetValue', object, None)
+
+class IHeaders(Protocol):
+ """...."""
+ def __getitem__(self, key: str) -> str: ...
+
+ def get_all(self, key: str) -> t.Iterable[str]: ...
+
+ def get(self, key: str, default: DefaultGetValue = None) \
+ -> t.Union[str, DefaultGetValue]: ...
+
+ def items(self) -> t.Iterable[tuple[str, str]]: ...
+
+def encode_headers_items(headers: t.Iterable[tuple[str, str]]) \
+ -> t.Iterable[tuple[bytes, bytes]]:
+ """...."""
+ for name, value in headers:
+ yield name.encode(), value.encode()
+
+@dc.dataclass(frozen=True)
+class ProducedRequest:
+ """...."""
+ url: str
+ method: str
+ headers: t.Iterable[tuple[bytes, bytes]]
+ body: bytes
+
+@dc.dataclass(frozen=True)
+class RequestInfo:
+ """...."""
+ url: ParsedUrl
+ method: str
+ headers: IHeaders
+ body: bytes
+
+ def make_produced_request(self) -> ProducedRequest:
+ """...."""
+ return ProducedRequest(
+ url = self.url.orig_url,
+ method = self.method,
+ headers = encode_headers_items(self.headers.items()),
+ body = self.body
+ )
+
+@dc.dataclass(frozen=True)
+class ProducedResponse:
+ """...."""
+ status_code: int
+ headers: t.Iterable[tuple[bytes, bytes]]
+ body: bytes
+
+@dc.dataclass(frozen=True)
+class ResponseInfo:
+ """...."""
+ url: ParsedUrl
+ status_code: int
+ headers: IHeaders
+ body: bytes
+
+ def make_produced_response(self) -> ProducedResponse:
+ """...."""
+ return ProducedResponse(
+ status_code = self.status_code,
+ headers = encode_headers_items(self.headers.items()),
+ body = self.body
+ )
+
+class Policy(ABC):
+ """...."""
+ process_request: t.ClassVar[bool] = False
+ process_response: t.ClassVar[bool] = False
+
+ priority: t.ClassVar[PolicyPriority]
+
+ @property
+ def anticache(self) -> bool:
+ return self.process_request or self.process_response
+
+ def consume_request(self, request_info: RequestInfo) \
+ -> t.Optional[t.Union[ProducedRequest, ProducedResponse]]:
+ """...."""
+ return None
+
+ def consume_response(self, response_info: ResponseInfo) \
+ -> t.Optional[ProducedResponse]:
+ """...."""
+ return None
+
+
+# mypy needs to be corrected:
+# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class PolicyFactory(ABC):
+ """...."""
+ builtin: bool
+
+ @abstractmethod
+ def make_policy(self, haketilo_state: state.HaketiloState) \
+ -> t.Optional[Policy]:
+ """...."""
+ ...
+
+ def __lt__(self, other: 'PolicyFactory'):
+ """...."""
+ return sorting_keys.get(self.__class__.__name__, 999) < \
+ sorting_keys.get(other.__class__.__name__, 999)
+
+sorting_order = (
+ 'PayloadResourcePolicyFactory',
+
+ 'PayloadPolicyFactory',
+
+ 'RuleBlockPolicyFactory',
+ 'RuleAllowPolicyFactory',
+
+ 'FallbackPolicyFactory'
+)
+
+sorting_keys = Map((cls, name) for name, cls in enumerate(sorting_order))
diff --git a/src/hydrilla/proxy/policies.py b/src/hydrilla/proxy/policies/fallback.py
index 5e9451b..75da61c 100644
--- a/src/hydrilla/proxy/policies.py
+++ b/src/hydrilla/proxy/policies/fallback.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: GPL-3.0-or-later
-# Various policies for altering HTTP requests.
+# Policies for blocking and allowing JS when no other policies match.
#
# This file is part of Hydrilla&Haketilo.
#
@@ -24,53 +24,37 @@
# file's license. Although I request that you do not make use this code
# in a proprietary program, I am not going to enforce this in court.
-import dataclasses as dc
-import typing as t
+"""
+.....
+"""
-from abc import ABC
+# Enable using with Python 3.7.
+from __future__ import annotations
-class Policy(ABC):
- pass
-
-class PayloadPolicy(Policy):
- """...."""
- def assets_base_url(self) -> str:
- """...."""
- return 'https://example.com/static/'
-
- def script_urls(self) -> t.Sequence[str]:
- """...."""
- # TODO: implement
- return ('https://example.com/static/somescript.js',)
-
- def is_eval_allowed(self) -> bool:
- """...."""
- # TODO: implement
- return True
+import dataclasses as dc
+import typing as t
+import enum
-class MetaResourcePolicy(Policy):
- pass
+from abc import ABC, abstractmethod
-class AllowPolicy(Policy):
- pass
+from .. import state
+from . import base
+from .rule import AllowPolicy, BlockPolicy
-@dc.dataclass
-class RuleAllowPolicy(AllowPolicy):
- pattern: str
class FallbackAllowPolicy(AllowPolicy):
- pass
-
-class BlockPolicy(Policy):
- pass
+ """....."""
+ priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE
-@dc.dataclass
-class RuleBlockPolicy(BlockPolicy):
- pattern: str
class FallbackBlockPolicy(BlockPolicy):
- pass
+ """...."""
+ priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE
-@dc.dataclass
+
+@dc.dataclass(frozen=True)
class ErrorBlockPolicy(BlockPolicy):
+ """...."""
error: Exception
+
+ builtin: bool = True
diff --git a/src/hydrilla/proxy/policies/payload.py b/src/hydrilla/proxy/policies/payload.py
new file mode 100644
index 0000000..d616f1b
--- /dev/null
+++ b/src/hydrilla/proxy/policies/payload.py
@@ -0,0 +1,315 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Policies for applying payload injections to HTTP requests.
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+.....
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import dataclasses as dc
+import typing as t
+import re
+
+import bs4 # type: ignore
+
+from ...url_patterns import ParsedUrl
+from .. import state
+from .. import csp
+from . import base
+
+@dc.dataclass(frozen=True) # type: ignore[misc]
+class PayloadAwarePolicy(base.Policy):
+ """...."""
+ haketilo_state: state.HaketiloState
+ payload_data: state.PayloadData
+
+ def assets_base_url(self, request_url: ParsedUrl):
+ """...."""
+ token = self.payload_data.unique_token
+
+ base_path_segments = (*self.payload_data.pattern.path_segments, token)
+
+ return f'{request_url.url_without_path}/{"/".join(base_path_segments)}/'
+
+
+@dc.dataclass(frozen=True) # type: ignore[misc]
+class PayloadAwarePolicyFactory(base.PolicyFactory):
+ """...."""
+ payload_key: state.PayloadKey
+
+ @property
+ def payload_ref(self) -> state.PayloadRef:
+ """...."""
+ return self.payload_key.payload_ref
+
+ def __lt__(self, other: base.PolicyFactory) -> bool:
+ """...."""
+ if isinstance(other, type(self)):
+ return self.payload_key < other.payload_key
+
+ return super().__lt__(other)
+
+
+# For details of 'Content-Type' header's structure, see:
+# https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.1
+content_type_reg = re.compile(r'''
+^
+(?P<mime>[\w-]+/[\w-]+)
+\s*
+(?:
+ ;
+ (?:[^;]*;)* # match possible parameter other than "charset"
+)
+\s*
+charset= # no whitespace allowed in parameter as per RFC
+(?P<encoding>
+ [\w-]+
+ |
+ "[\w-]+" # quotes are optional per RFC
+)
+(?:;[^;]+)* # match possible parameter other than "charset"
+$ # forbid possible dangling characters after closing '"'
+''', re.VERBOSE | re.IGNORECASE)
+
+def deduce_content_type(headers: base.IHeaders) \
+ -> tuple[t.Optional[str], t.Optional[str]]:
+ """...."""
+ content_type = headers.get('content-type')
+ if content_type is None:
+ return (None, None)
+
+ match = content_type_reg.match(content_type)
+ if match is None:
+ return (None, None)
+
+ mime, encoding = match.group('mime'), match.group('encoding')
+
+ if encoding is not None:
+ encoding = encoding.lower()
+
+ return mime, encoding
+
+UTF8_BOM = b'\xEF\xBB\xBF'
+BOMs = (
+ (UTF8_BOM, 'utf-8'),
+ (b'\xFE\xFF', 'utf-16be'),
+ (b'\xFF\xFE', 'utf-16le')
+)
+
+def block_attr(element: bs4.PageElement, attr_name: str) -> None:
+ """
+ Disable HTML node attributes by prepending `blocked-'. This allows them to
+ still be relatively easily accessed in case they contain some useful data.
+ """
+ blocked_value = element.attrs.pop(attr_name, None)
+
+ while blocked_value is not None:
+ attr_name = f'blocked-{attr_name}'
+ next_blocked_value = element.attrs.pop(attr_name, None)
+ element.attrs[attr_name] = blocked_value
+
+ blocked_value = next_blocked_value
+
+@dc.dataclass(frozen=True)
+class PayloadInjectPolicy(PayloadAwarePolicy):
+ """...."""
+ process_response: t.ClassVar[bool] = True
+
+ priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._TWO
+
+ def _new_csp(self, request_url: ParsedUrl) -> str:
+ """...."""
+ assets_base = self.assets_base_url(request_url)
+
+ script_src = f"script-src {assets_base}"
+
+ if self.payload_data.eval_allowed:
+ script_src = f"{script_src} 'unsafe-eval'"
+
+ return '; '.join((
+ script_src,
+ "script-src-elem 'none'",
+ "script-src-attr 'none'"
+ ))
+
+ def _modify_headers(self, response_info: base.ResponseInfo) \
+ -> t.Iterable[tuple[bytes, bytes]]:
+ """...."""
+ for header_name, header_value in response_info.headers.items():
+ if header_name.lower() not in csp.header_names_and_dispositions:
+ yield header_name.encode(), header_value.encode()
+
+ new_csp = self._new_csp(response_info.url)
+
+ yield b'Content-Security-Policy', new_csp.encode()
+
+ def _script_urls(self, url: ParsedUrl) -> t.Iterable[str]:
+ """...."""
+ base_url = self.assets_base_url(url)
+ payload_ref = self.payload_data.payload_ref
+
+ for path in payload_ref.get_script_paths(self.haketilo_state):
+ yield base_url + '/'.join(('static', *path))
+
+ def _modify_body(
+ self,
+ url: ParsedUrl,
+ body: bytes,
+ encoding: t.Optional[str]
+ ) -> bytes:
+ """...."""
+ soup = bs4.BeautifulSoup(
+ markup = body,
+ from_encoding = encoding,
+ features = 'html5lib'
+ )
+
+ # Inject scripts.
+ script_parent = soup.find('body') or soup.find('html')
+ if script_parent is None:
+ return body
+
+ for script_url in self._script_urls(url):
+ tag = bs4.Tag(name='script', attrs={'src': script_url})
+ script_parent.append(tag)
+
+ # Remove Content Security Policy that could possibly block injected
+ # scripts.
+ for meta in soup.select('head meta[http-equiv]'):
+ header_name = meta.attrs.get('http-equiv', '').lower().strip()
+ if header_name in csp.enforce_header_names_set:
+ block_attr(meta, 'http-equiv')
+ block_attr(meta, 'content')
+
+ # Appending a three-byte Byte Order Mark (BOM) will force the browser to
+ # decode this as UTF-8 regardless of the 'Content-Type' header. See:
+ # https://www.w3.org/International/tests/repository/html5/the-input-byte-stream/results-basics#precedence
+ return UTF8_BOM + soup.encode()
+
+ def _consume_response_unsafe(self, response_info: base.ResponseInfo) \
+ -> base.ProducedResponse:
+ """...."""
+ new_response = response_info.make_produced_response()
+
+ new_headers = self._modify_headers(response_info)
+
+ new_response = dc.replace(new_response, headers=new_headers)
+
+ mime, encoding = deduce_content_type(response_info.headers)
+ if mime is None or 'html' not in mime.lower():
+ return new_response
+
+ data = response_info.body
+ if data is None:
+ data = b''
+
+ # A UTF BOM overrides encoding specified by the header.
+ for bom, encoding_name in BOMs:
+ if data.startswith(bom):
+ encoding = encoding_name
+
+ new_data = self._modify_body(response_info.url, data, encoding)
+
+ return dc.replace(new_response, body=new_data)
+
+ def consume_response(self, response_info: base.ResponseInfo) \
+ -> base.ProducedResponse:
+ """...."""
+ try:
+ return self._consume_response_unsafe(response_info)
+ except Exception as e:
+ # TODO: actually describe the errors
+ import traceback
+
+ error_info_list = traceback.format_exception(
+ type(e),
+ e,
+ e.__traceback__
+ )
+
+ return base.ProducedResponse(
+ 500,
+ ((b'Content-Type', b'text/plain; charset=utf-8'),),
+ '\n'.join(error_info_list).encode()
+ )
+
+
+class AutoPayloadInjectPolicy(PayloadInjectPolicy):
+ """...."""
+ priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE
+
+ def _modify_body(
+ self,
+ url: ParsedUrl,
+ body: bytes,
+ encoding: t.Optional[str]
+ ) -> bytes:
+ """...."""
+ payload_ref = self.payload_data.payload_ref
+ mapping_ref = payload_ref.get_mapping(self.haketilo_state)
+ mapping_ref.enable(self.haketilo_state)
+
+ return super()._modify_body(url, body, encoding)
+
+
+@dc.dataclass(frozen=True)
+class PayloadSuggestPolicy(PayloadAwarePolicy):
+ """...."""
+ priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE
+
+ def make_response(self, request_info: base.RequestInfo) \
+ -> base.ProducedResponse:
+ """...."""
+ # TODO: implement
+ return base.ProducedResponse(200, ((b'a', b'b'),), b'')
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class PayloadPolicyFactory(PayloadAwarePolicyFactory):
+ """...."""
+ def make_policy(self, haketilo_state: state.HaketiloState) \
+ -> t.Optional[base.Policy]:
+ """...."""
+ try:
+ payload_data = self.payload_ref.get_data(haketilo_state)
+ except:
+ return None
+
+ if payload_data.explicitly_enabled:
+ return PayloadInjectPolicy(haketilo_state, payload_data)
+
+ mode = haketilo_state.get_settings().mapping_use_mode
+
+ if mode == state.MappingUseMode.QUESTION:
+ return PayloadSuggestPolicy(haketilo_state, payload_data)
+
+ if mode == state.MappingUseMode.WHEN_ENABLED:
+ return None
+
+ # mode == state.MappingUseMode.AUTO
+ return AutoPayloadInjectPolicy(haketilo_state, payload_data)
diff --git a/src/hydrilla/proxy/policies/payload_resource.py b/src/hydrilla/proxy/policies/payload_resource.py
new file mode 100644
index 0000000..84d0919
--- /dev/null
+++ b/src/hydrilla/proxy/policies/payload_resource.py
@@ -0,0 +1,160 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Policies for resolving HTTP requests with local resources.
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+.....
+
+We make file resources available to HTTP clients by mapping them
+at:
+ http(s)://<pattern-matching_origin>/<pattern_path>/<token>/
+where <token> is a per-session secret unique for every mapping.
+For example, a payload with pattern like the following:
+ http*://***.example.com/a/b/**
+Could cause resources to be mapped (among others) at each of:
+ https://example.com/a/b/**/Da2uiF2UGfg/
+ https://www.example.com/a/b/**/Da2uiF2UGfg/
+ http://gnome.vs.kde.example.com/a/b/**/Da2uiF2UGfg/
+
+Unauthorized web pages running in the user's browser are exected to be
+unable to guess the secret. This way we stop them from spying on the
+user and from interfering with Haketilo's normal operation.
+
+This is only a soft prevention method. With some mechanisms
+(e.g. service workers), under certain scenarios, it might be possible
+to bypass it. Thus, to make the risk slightly smaller, we also block
+the unauthorized accesses that we can detect.
+
+Since a web page authorized to access the resources may only be served
+when the corresponding mapping is enabled (or AUTO mode is on), we
+consider accesses to non-enabled mappings' resources a security breach
+and block them by responding with 403 Not Found.
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import dataclasses as dc
+import typing as t
+
+from ...translations import smart_gettext as _
+from .. import state
+from . import base
+from .payload import PayloadAwarePolicy, PayloadAwarePolicyFactory
+
+
+@dc.dataclass(frozen=True)
+class PayloadResourcePolicy(PayloadAwarePolicy):
+ """...."""
+ process_request: t.ClassVar[bool] = True
+
+ priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._THREE
+
+ def _make_file_resource_response(self, path: tuple[str, ...]) \
+ -> base.ProducedResponse:
+ """...."""
+ try:
+ file_data = self.payload_data.payload_ref.get_file_data(
+ self.haketilo_state,
+ path
+ )
+ except state.MissingItemError:
+ return resource_blocked_response
+
+ if file_data is None:
+ return base.ProducedResponse(
+ 404,
+ [(b'Content-Type', b'text/plain; charset=utf-8')],
+ _('api.file_not_found').encode()
+ )
+
+ return base.ProducedResponse(
+ 200,
+ ((b'Content-Type', file_data.type.encode()),),
+ file_data.contents
+ )
+
+ def consume_request(self, request_info: base.RequestInfo) \
+ -> base.ProducedResponse:
+ """...."""
+ # Payload resource pattern has path of the form:
+ # "/some/arbitrary/segments/<per-session_token>/***"
+ #
+ # Corresponding requests shall have path of the form:
+ # "/some/arbitrary/segments/<per-session_token>/actual/resource/path"
+ #
+ # Here we need to extract the "/actual/resource/path" part.
+ segments_to_drop = len(self.payload_data.pattern.path_segments) + 1
+ resource_path = request_info.url.path_segments[segments_to_drop:]
+
+ if resource_path == ():
+ return resource_blocked_response
+ elif resource_path[0] == 'static':
+ return self._make_file_resource_response(resource_path[1:])
+ elif resource_path[0] == 'api':
+ # TODO: implement Haketilo APIs
+ return resource_blocked_response
+ else:
+ return resource_blocked_response
+
+
+resource_blocked_response = base.ProducedResponse(
+ 403,
+ [(b'Content-Type', b'text/plain; charset=utf-8')],
+ _('api.resource_not_enabled_for_access').encode()
+)
+
+@dc.dataclass(frozen=True)
+class BlockedResponsePolicy(base.Policy):
+ """...."""
+ process_request: t.ClassVar[bool] = True
+
+ priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._THREE
+
+ def consume_request(self, request_info: base.RequestInfo) \
+ -> base.ProducedResponse:
+ """...."""
+ return resource_blocked_response
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class PayloadResourcePolicyFactory(PayloadAwarePolicyFactory):
+ """...."""
+ def make_policy(self, haketilo_state: state.HaketiloState) \
+ -> t.Union[PayloadResourcePolicy, BlockedResponsePolicy]:
+ """...."""
+ try:
+ payload_data = self.payload_ref.get_data(haketilo_state)
+ except state.MissingItemError:
+ return BlockedResponsePolicy()
+
+ if not payload_data.explicitly_enabled and \
+ haketilo_state.get_settings().mapping_use_mode != \
+ state.MappingUseMode.AUTO:
+ return BlockedResponsePolicy()
+
+ return PayloadResourcePolicy(haketilo_state, payload_data)
+
+
diff --git a/src/hydrilla/proxy/policies/rule.py b/src/hydrilla/proxy/policies/rule.py
new file mode 100644
index 0000000..eb70147
--- /dev/null
+++ b/src/hydrilla/proxy/policies/rule.py
@@ -0,0 +1,134 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Policies for blocking and allowing JS in pages fetched with HTTP.
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+.....
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import dataclasses as dc
+import typing as t
+
+from ...url_patterns import ParsedPattern
+from .. import csp
+from .. import state
+from . import base
+
+
+class AllowPolicy(base.Policy):
+ """...."""
+ priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._TWO
+
+class BlockPolicy(base.Policy):
+ """...."""
+ process_response: t.ClassVar[bool] = True
+
+ priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._TWO
+
+ def _modify_headers(self, response_info: base.ResponseInfo) \
+ -> t.Iterable[tuple[bytes, bytes]]:
+ """...."""
+ csp_policies = csp.extract(response_info.headers)
+
+ for header_name, header_value in response_info.headers.items():
+ if header_name.lower() not in csp.header_names_and_dispositions:
+ yield header_name.encode(), header_value.encode()
+
+ for policy in csp_policies:
+ if policy.disposition != 'enforce':
+ continue
+
+ directives = policy.directives.mutate()
+ directives.pop('report-to', None)
+ directives.pop('report-uri', None)
+
+ policy = dc.replace(policy, directives=directives.finish())
+
+ yield policy.header_name.encode(), policy.serialize().encode()
+
+ extra_csp = ';'.join((
+ "script-src 'none'",
+ "script-src-elem 'none'",
+ "script-src-attr 'none'"
+ ))
+
+ yield b'Content-Security-Policy', extra_csp.encode()
+
+
+ def consume_response(self, response_info: base.ResponseInfo) \
+ -> base.ProducedResponse:
+ """...."""
+ new_response = response_info.make_produced_response()
+
+ new_headers = self._modify_headers(response_info)
+
+ return dc.replace(new_response, headers=new_headers)
+
+@dc.dataclass(frozen=True)
+class RuleAllowPolicy(AllowPolicy):
+ """...."""
+ pattern: ParsedPattern
+
+
+@dc.dataclass(frozen=True)
+class RuleBlockPolicy(BlockPolicy):
+ """...."""
+ pattern: ParsedPattern
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class RulePolicyFactory(base.PolicyFactory):
+ """...."""
+ pattern: ParsedPattern
+
+ def __lt__(self, other: base.PolicyFactory) -> bool:
+ """...."""
+ if type(other) is type(self):
+ return super().__lt__(other)
+
+ assert isinstance(other, RulePolicyFactory)
+
+ return self.pattern < other.pattern
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class RuleBlockPolicyFactory(RulePolicyFactory):
+ """...."""
+ def make_policy(self, haketilo_state: state.HaketiloState) \
+ -> RuleBlockPolicy:
+ """...."""
+ return RuleBlockPolicy(self.pattern)
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class RuleAllowPolicyFactory(RulePolicyFactory):
+ """...."""
+ def make_policy(self, haketilo_state: state.HaketiloState) \
+ -> RuleAllowPolicy:
+ """...."""
+ return RuleAllowPolicy(self.pattern)
diff --git a/src/hydrilla/proxy/simple_dependency_satisfying.py b/src/hydrilla/proxy/simple_dependency_satisfying.py
new file mode 100644
index 0000000..9716fe5
--- /dev/null
+++ b/src/hydrilla/proxy/simple_dependency_satisfying.py
@@ -0,0 +1,189 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy payloads dependency resolution.
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+.....
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import dataclasses as dc
+import typing as t
+
+from .. import item_infos
+from .. import url_patterns
+
+@dc.dataclass
+class ComputedPayload:
+ resources: list[item_infos.ResourceInfo] = dc.field(default_factory=list)
+
+ allows_eval: bool = False
+ allows_cors_bypass: bool = False
+
+SingleMappingPayloads = t.Mapping[
+ url_patterns.ParsedPattern,
+ ComputedPayload
+]
+
+ComputedPayloadsDict = dict[
+ item_infos.MappingInfo,
+ SingleMappingPayloads
+]
+
+empty_identifiers_set: set[str] = set()
+
+@dc.dataclass(frozen=True)
+class ItemsCollection:
+ resources: t.Mapping[str, item_infos.ResourceInfo]
+ mappings: t.Mapping[str, item_infos.MappingInfo]
+
+ def _satisfy_payload_resource_rec(
+ self,
+ resource_identifier: str,
+ processed_resources: set[str],
+ computed_payload: ComputedPayload
+ ) -> t.Optional[ComputedPayload]:
+ if resource_identifier in processed_resources:
+ # We forbid circular dependencies.
+ return None
+
+ resource_info = self.resources.get(resource_identifier)
+ if resource_info is None:
+ return None
+
+ if resource_info in computed_payload.resources:
+ return computed_payload
+
+ processed_resources.add(resource_identifier)
+
+ if resource_info.allows_eval:
+ computed_payload.allows_eval = True
+
+ if resource_info.allows_cors_bypass:
+ computed_payload.allows_cors_bypass = True
+
+ for dependency_spec in resource_info.dependencies:
+ if self._satisfy_payload_resource_rec(
+ dependency_spec.identifier,
+ processed_resources,
+ computed_payload
+ ) is None:
+ return None
+
+ processed_resources.remove(resource_identifier)
+
+ computed_payload.resources.append(resource_info)
+
+ return computed_payload
+
+ def _satisfy_payload_resource(self, resource_identifier: str) \
+ -> t.Optional[ComputedPayload]:
+ return self._satisfy_payload_resource_rec(
+ resource_identifier,
+ set(),
+ ComputedPayload()
+ )
+
+ def _compute_payloads_no_mapping_requirements(self) -> ComputedPayloadsDict:
+ computed_result: ComputedPayloadsDict = ComputedPayloadsDict()
+
+ for mapping_info in self.mappings.values():
+ by_pattern: dict[url_patterns.ParsedPattern, ComputedPayload] = {}
+
+ failure = False
+
+ for pattern, resource_spec in mapping_info.payloads.items():
+ computed_payload = self._satisfy_payload_resource(
+ resource_spec.identifier
+ )
+ if computed_payload is None:
+ failure = True
+ break
+
+ if mapping_info.allows_eval:
+ computed_payload.allows_eval = True
+
+ if mapping_info.allows_cors_bypass:
+ computed_payload.allows_cors_bypass = True
+
+ by_pattern[pattern] = computed_payload
+
+ if not failure:
+ computed_result[mapping_info] = by_pattern
+
+ return computed_result
+
+ def _mark_mappings_bad(
+ self,
+ identifier: str,
+ reverse_mapping_deps: t.Mapping[str, set[str]],
+ bad_mappings: set[str]
+ ) -> None:
+ if identifier in bad_mappings:
+ return
+
+ bad_mappings.add(identifier)
+
+ for requiring in reverse_mapping_deps.get(identifier, ()):
+ self._mark_mappings_bad(
+ requiring,
+ reverse_mapping_deps,
+ bad_mappings
+ )
+
+ def compute_payloads(self) -> ComputedPayloadsDict:
+ computed_result = self._compute_payloads_no_mapping_requirements()
+
+ reverse_mapping_deps: dict[str, set[str]] = {}
+
+ for mapping_info, by_pattern in computed_result.items():
+ specs_to_resolve = [*mapping_info.required_mappings]
+
+ for computed_payload in by_pattern.values():
+ for resource_info in computed_payload.resources:
+ specs_to_resolve.extend(resource_info.required_mappings)
+
+ for required_mapping_spec in specs_to_resolve:
+ identifier = required_mapping_spec.identifier
+ requiring = reverse_mapping_deps.setdefault(identifier, set())
+ requiring.add(mapping_info.identifier)
+
+ bad_mappings: set[str] = set()
+
+ for required_identifier in reverse_mapping_deps.keys():
+ if self.mappings.get(required_identifier) not in computed_result:
+ self._mark_mappings_bad(
+ required_identifier,
+ reverse_mapping_deps,
+ bad_mappings
+ )
+
+ for identifier in bad_mappings:
+ if identifier in self.mappings:
+ computed_result.pop(self.mappings[identifier], None)
+
+ return computed_result
diff --git a/src/hydrilla/proxy/state.py b/src/hydrilla/proxy/state.py
index fc01536..e22c9fe 100644
--- a/src/hydrilla/proxy/state.py
+++ b/src/hydrilla/proxy/state.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
-# Haketilo proxy data and configuration.
+# Haketilo proxy data and configuration (interface definition through abstract
+# class).
#
# This file is part of Hydrilla&Haketilo.
#
@@ -25,49 +26,271 @@
# in a proprietary program, I am not going to enforce this in court.
"""
-This module contains logic for keeping track of all settings, rules, mappings
-and resources.
+This module defines API for keeping track of all settings, rules, mappings and
+resources.
"""
# Enable using with Python 3.7.
from __future__ import annotations
-import typing as t
import dataclasses as dc
+import typing as t
-from threading import Lock
from pathlib import Path
+from abc import ABC, abstractmethod
+from enum import Enum
+
+from immutables import Map
+
+from ..versions import VerTuple
+from ..url_patterns import ParsedPattern
+
+
+class EnabledStatus(Enum):
+ """
+ ENABLED - User wished to always apply given mapping when it matched.
+
+ DISABLED - User wished to never apply given mapping.
+
+ AUTO_ENABLED - User has not configured given mapping but it will still be
+ used when automatic application of mappings is turned on.
+
+ NO_MARK - User has not configured given mapping and it won't be used.
+ """
+ ENABLED = 'E'
+ DISABLED = 'D'
+ AUTO_ENABLED = 'A'
+ NO_MARK = 'N'
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class Ref:
+ """...."""
+ id: str
+
+
+# mypy needs to be corrected:
+# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class RepoRef(Ref):
+ """...."""
+ @abstractmethod
+ def remove(self, state: 'HaketiloState') -> None:
+ """...."""
+ ...
+
+ @abstractmethod
+ def update(
+ self,
+ state: 'HaketiloState',
+ *,
+ name: t.Optional[str] = None,
+ url: t.Optional[str] = None
+ ) -> 'RepoRef':
+ """...."""
+ ...
+
+ @abstractmethod
+ def refresh(self, state: 'HaketiloState') -> 'RepoIterationRef':
+ """...."""
+ ...
-from ..pattern_tree import PatternTree
-from .store import HaketiloStore
-from . import policies
-def make_pattern_tree_with_builtin_policies() -> PatternTree[policies.Policy]:
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class RepoIterationRef(Ref):
"""...."""
- # TODO: implement
- return PatternTree()
+ pass
-tree_field = dc.field(default_factory=make_pattern_tree_with_builtin_policies)
-@dc.dataclass
-class HaketiloState(HaketiloStore):
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class MappingRef(Ref):
"""...."""
- pattern_tree: PatternTree[policies.Policy] = tree_field
- default_allow: bool = False
+ @abstractmethod
+ def disable(self, state: 'HaketiloState') -> None:
+ """...."""
+ ...
+
+ @abstractmethod
+ def forget_enabled(self, state: 'HaketiloState') -> None:
+ """...."""
+ ...
- state_lock: Lock = dc.field(default_factory=Lock)
- def select_policy(self, url: str, allow_disabled=False) -> policies.Policy:
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class MappingVersionRef(Ref):
+ """...."""
+ @abstractmethod
+ def enable(self, state: 'HaketiloState') -> None:
"""...."""
- with self.state_lock:
- pattern_tree = self.pattern_tree
+ ...
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class ResourceRef(Ref):
+ """...."""
+ pass
+
- try:
- for policy_set in pattern_tree.search(url):
- # if policy.enabled or allow_disabled:
- # return policy
- pass
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class ResourceVersionRef(Ref):
+ """...."""
+ pass
- return policies.FallbackBlockPolicy()
- except Exception as e:
- return policies.ErrorBlockPolicy(e)
+
+@dc.dataclass(frozen=True)
+class PayloadKey:
+ """...."""
+ payload_ref: 'PayloadRef'
+
+ mapping_identifier: str
+ # mapping_version: VerTuple
+ # mapping_repo: str
+ # mapping_repo_iteration: int
+ pattern: ParsedPattern
+
+ def __lt__(self, other: 'PayloadKey') -> bool:
+ """...."""
+ return (
+ self.mapping_identifier,
+ # other.mapping_version,
+ # self.mapping_repo,
+ # other.mapping_repo_iteration,
+ self.pattern
+ ) < (
+ other.mapping_identifier,
+ # self.mapping_version,
+ # other.mapping_repo,
+ # self.mapping_repo_iteration,
+ other.pattern
+ )
+
+@dc.dataclass(frozen=True)
+class PayloadData:
+ """...."""
+ payload_ref: 'PayloadRef'
+
+ mapping_installed: bool
+ explicitly_enabled: bool
+ unique_token: str
+ pattern: ParsedPattern
+ eval_allowed: bool
+ cors_bypass_allowed: bool
+
+@dc.dataclass(frozen=True)
+class FileData:
+ type: str
+ name: str
+ contents: bytes
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class PayloadRef(Ref):
+ """...."""
+ @abstractmethod
+ def get_data(self, state: 'HaketiloState') -> PayloadData:
+ """...."""
+ ...
+
+ @abstractmethod
+ def get_mapping(self, state: 'HaketiloState') -> MappingVersionRef:
+ """...."""
+ ...
+
+ @abstractmethod
+ def get_script_paths(self, state: 'HaketiloState') \
+ -> t.Iterable[t.Sequence[str]]:
+ """...."""
+ ...
+
+ @abstractmethod
+ def get_file_data(self, state: 'HaketiloState', path: t.Sequence[str]) \
+ -> t.Optional[FileData]:
+ """...."""
+ ...
+
+
+class MappingUseMode(Enum):
+ """
+ AUTO - Apply mappings except for those explicitly disabled.
+
+ WHEN_ENABLED - Only apply mappings explicitly marked as enabled. Don't apply
+ unmarked nor explicitly disabled mappings.
+
+ QUESTION - Automatically apply mappings that are explicitly enabled. Ask
+ whether to enable unmarked mappings. Don't apply explicitly disabled
+ ones.
+ """
+ AUTO = 'A'
+ WHEN_ENABLED = 'W'
+ QUESTION = 'Q'
+
+@dc.dataclass(frozen=True)
+class HaketiloGlobalSettings:
+ """...."""
+ mapping_use_mode: MappingUseMode
+ default_allow_scripts: bool
+ repo_refresh_seconds: int
+
+
+class MissingItemError(ValueError):
+ """...."""
+ pass
+
+
+class HaketiloState(ABC):
+ """...."""
+ @abstractmethod
+ def get_repo(self, repo_id: str) -> RepoRef:
+ """...."""
+ ...
+
+ @abstractmethod
+ def get_repo_iteration(self, repo_iteration_id: str) -> RepoIterationRef:
+ """...."""
+ ...
+
+ @abstractmethod
+ def get_mapping(self, mapping_id: str) -> MappingRef:
+ """...."""
+ ...
+
+ @abstractmethod
+ def get_mapping_version(self, mapping_version_id: str) \
+ -> MappingVersionRef:
+ """...."""
+ ...
+
+ @abstractmethod
+ def get_resource(self, resource_id: str) -> ResourceRef:
+ """...."""
+ ...
+
+ @abstractmethod
+ def get_resource_version(self, resource_version_id: str) \
+ -> ResourceVersionRef:
+ """...."""
+ ...
+
+ @abstractmethod
+ def get_payload(self, payload_id: str) -> PayloadRef:
+ """...."""
+ ...
+
+ @abstractmethod
+ def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \
+ -> RepoRef:
+ """...."""
+ ...
+
+ @abstractmethod
+ def get_settings(self) -> HaketiloGlobalSettings:
+ """...."""
+ ...
+
+ @abstractmethod
+ def update_settings(
+ self,
+ *,
+ mapping_use_mode: t.Optional[MappingUseMode] = None,
+ default_allow_scripts: t.Optional[bool] = None,
+ repo_refresh_seconds: t.Optional[int] = None
+ ) -> None:
+ """...."""
+ ...
diff --git a/src/hydrilla/proxy/state_impl.py b/src/hydrilla/proxy/state_impl.py
new file mode 100644
index 0000000..059fee9
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl.py
@@ -0,0 +1,288 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy data and configuration.
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+This module contains logic for keeping track of all settings, rules, mappings
+and resources.
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import secrets
+import threading
+import typing as t
+import dataclasses as dc
+
+from pathlib import Path
+
+from immutables import Map
+
+from ..url_patterns import ParsedUrl
+from ..pattern_tree import PatternTree
+from .store import HaketiloStore
+from . import state
+from . import policies
+
+
+PolicyTree = PatternTree[policies.PolicyFactory]
+
+
+def register_builtin_policies(policy_tree: PolicyTree) -> PolicyTree:
+ """...."""
+ # TODO: implement
+ pass
+
+
+def register_payload(
+ policy_tree: PolicyTree,
+ payload_ref: state.PayloadRef,
+ token: str
+) -> tuple[PolicyTree, t.Iterable[policies.PolicyFactory]]:
+ """...."""
+ payload_policy_factory = policies.PayloadPolicyFactory(
+ builtin = False,
+ payload_ref = payload_ref
+ )
+
+ policy_tree = policy_tree.register(
+ payload_ref.pattern,
+ payload_policy_factory
+ )
+
+ resource_policy_factory = policies.PayloadResourcePolicyFactory(
+ builtin = False,
+ payload_ref = payload_ref
+ )
+
+ policy_tree = policy_tree.register(
+ payload_ref.pattern.path_append(token, '***'),
+ resource_policy_factory
+ )
+
+ return policy_tree, (payload_policy_factory, resource_policy_factory)
+
+
+def register_mapping(
+ policy_tree: PolicyTree,
+ payload_refs: t.Iterable[state.PayloadRef],
+ token: str
+) -> tuple[PolicyTree, t.Iterable[policies.PolicyFactory]]:
+ """...."""
+ policy_factories: list[policies.PolicyFactory] = []
+
+ for payload_ref in payload_refs:
+ policy_tree, factories = register_payload(
+ policy_tree,
+ payload_ref,
+ token
+ )
+
+ policy_factories.extend(factories)
+
+ return policy_tree, policy_factories
+
+
+@dc.dataclass(frozen=True)
+class RichMappingData(state.MappingData):
+ """...."""
+ policy_factories: t.Iterable[policies.PolicyFactory]
+
+
+# @dc.dataclass(frozen=True)
+# class HaketiloData:
+# settings: state.HaketiloGlobalSettings
+# policy_tree: PolicyTree
+# mappings_data: Map[state.MappingRef, RichMappingData] = Map()
+
+
+MonitoredMethodType = t.TypeVar('MonitoredMethodType', bound=t.Callable)
+
+def with_lock(wrapped_fun: MonitoredMethodType) -> MonitoredMethodType:
+ """...."""
+ def wrapper(self: 'ConcreteHaketiloState', *args, **kwargs):
+ """...."""
+ with self.lock:
+ return wrapped_fun(self, *args, **kwargs)
+
+ return t.cast(MonitoredMethodType, wrapper)
+
+@dc.dataclass
+class ConcreteHaketiloState(state.HaketiloState):
+ """...."""
+ store: HaketiloStore
+ settings: state.HaketiloGlobalSettings
+
+ policy_tree: PolicyTree = PatternTree()
+ mappings_data: Map[state.MappingRef, RichMappingData] = Map()
+
+ lock: threading.RLock = dc.field(default_factory=threading.RLock)
+
+ def __post_init__(self) -> None:
+ """...."""
+ self.policy_tree = register_builtin_policies(self.policy_tree)
+
+ self._init_mappings()
+
+ def _init_mappings(self) -> None:
+ """...."""
+ store_mappings_data = self.store.load_installed_mappings_data()
+
+ payload_items = self.store.load_payloads_data().items()
+ for mapping_ref, payload_refs in payload_items:
+ installed = True
+ enabled_status = store_mappings_data.get(mapping_ref)
+
+ if enabled_status is None:
+ installed = False
+ enabled_status = state.EnabledStatus.NO_MARK
+
+ self._register_mapping(
+ mapping_ref,
+ payload_refs,
+ enabled = enabled_status,
+ installed = installed
+ )
+
+ @with_lock
+ def _register_mapping(
+ self,
+ mapping_ref: state.MappingRef,
+ payload_refs: t.Iterable[state.PayloadRef],
+ enabled: state.EnabledStatus,
+ installed: bool
+ ) -> None:
+ """...."""
+ token = secrets.token_urlsafe(8)
+
+ self.policy_tree, factories = register_mapping(
+ self.policy_tree,
+ payload_refs,
+ token
+ )
+
+ runtime_data = RichMappingData(
+ mapping_ref = mapping_ref,
+ enabled_status = enabled,
+ unique_token = token,
+ policy_factories = factories
+ )
+
+ self.mappings_data = self.mappings_data.set(mapping_ref, runtime_data)
+
+ @with_lock
+ def get_mapping_data(self, mapping_ref: state.MappingRef) \
+ -> state.MappingData:
+ try:
+ return self.mappings_data[mapping_ref]
+ except KeyError:
+ raise state.MissingItemError('no such mapping')
+
+ @with_lock
+ def get_payload_data(self, payload_ref: state.PayloadRef) \
+ -> state.PayloadData:
+ # TODO!!!
+ try:
+ return t.cast(state.PayloadData, None)
+ except:
+ raise state.MissingItemError('no such payload')
+
+ @with_lock
+ def get_file_paths(self, payload_ref: state.PayloadRef) \
+ -> t.Iterable[t.Sequence[str]]:
+ # TODO!!!
+ return []
+
+ @with_lock
+ def get_file_data(
+ self,
+ payload_ref: state.PayloadRef,
+ file_path: t.Sequence[str]
+ ) -> t.Optional[state.FileData]:
+ if len(file_path) == 0:
+ raise state.MissingItemError('empty file path')
+
+ path_str = '/'.join(file_path[1:])
+
+ return self.store.load_file_data(payload_ref, file_path[0], path_str)
+
+ @with_lock
+ def ensure_installed(self, mapping: state.MappingRef) -> None:
+ # TODO!!!
+ pass
+
+ @with_lock
+ def get_settings(self) -> state.HaketiloGlobalSettings:
+ return self.settings
+
+ @with_lock
+ def update_settings(self, updater: state.SettingsUpdater) -> None:
+ new_settings = updater(self.settings)
+
+ self.store.write_global_settings(new_settings)
+
+ self.settings = new_settings
+
+ def select_policy(self, url: str) -> policies.Policy:
+ """...."""
+ with self.lock:
+ policy_tree = self.policy_tree
+
+ try:
+ best_priority: int = 0
+ best_policy: t.Optional[policies.Policy] = None
+
+ for factories_set in policy_tree.search(url):
+ for stored_factory in sorted(factories_set):
+ factory = stored_factory.item
+
+ policy = factory.make_policy(self)
+ policy_priority = policy.priority()
+
+ if policy_priority > best_priority:
+ best_priority = policy_priority
+ best_policy = policy
+ except Exception as e:
+ return policies.ErrorBlockPolicy(
+ builtin = True,
+ error = e
+ )
+
+ if best_policy is not None:
+ return best_policy
+
+ if self.get_settings().default_allow_scripts:
+ return policies.FallbackAllowPolicy()
+ else:
+ return policies.FallbackBlockPolicy()
+
+ @staticmethod
+ def make(store_dir: Path) -> 'ConcreteHaketiloState':
+ """...."""
+ store = HaketiloStore(store_dir)
+ settings = store.load_global_settings()
+
+ return ConcreteHaketiloState(store=store, settings=settings)
diff --git a/src/hydrilla/proxy/state_impl/__init__.py b/src/hydrilla/proxy/state_impl/__init__.py
new file mode 100644
index 0000000..5398cdd
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: CC0-1.0
+
+# Copyright (C) 2022 Wojtek Kosior <koszko@koszko.org>
+#
+# Available under the terms of Creative Commons Zero v1.0 Universal.
+
+from .concrete_state import ConcreteHaketiloState
diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py
new file mode 100644
index 0000000..78a50c0
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl/base.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy data and configuration (definition of fields of a class that
+# will implement HaketiloState).
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+This module defines fields that will later be part of a concrete HaketiloState
+subtype.
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import sqlite3
+import threading
+import dataclasses as dc
+import typing as t
+
+from pathlib import Path
+from contextlib import contextmanager
+
+import sqlite3
+
+from immutables import Map
+
+from ... import pattern_tree
+from .. import state
+from .. import policies
+
+
+PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
+
+#PayloadsDataMap = Map[state.PayloadRef, state.PayloadData]
+DataById = t.Mapping[str, state.PayloadData]
+
+# mypy needs to be corrected:
+# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
+@dc.dataclass # type: ignore[misc]
+class HaketiloStateWithFields(state.HaketiloState):
+ """...."""
+ store_dir: Path
+ connection: sqlite3.Connection
+ current_cursor: t.Optional[sqlite3.Cursor] = None
+ #settings: state.HaketiloGlobalSettings
+
+ policy_tree: PolicyTree = PolicyTree()
+ #payloads_data: PayloadsDataMap = Map()
+ payloads_data: DataById = dc.field(default_factory=dict)
+
+ lock: threading.RLock = dc.field(default_factory=threading.RLock)
+
+ @contextmanager
+ def cursor(self, lock: bool = False, transaction: bool = False) \
+ -> t.Iterator[sqlite3.Cursor]:
+ """...."""
+ start_transaction = transaction and not self.connection.in_transaction
+
+ assert lock or not start_transaction
+
+ try:
+ if lock:
+ self.lock.acquire()
+
+ if self.current_cursor is not None:
+ yield self.current_cursor
+ return
+
+ self.current_cursor = self.connection.cursor()
+
+ if start_transaction:
+ self.current_cursor.execute('BEGIN TRANSACTION;')
+
+ try:
+ yield self.current_cursor
+
+ if start_transaction:
+ self.current_cursor.execute('COMMIT TRANSACTION;')
+ except:
+ if start_transaction:
+ self.current_cursor.execute('ROLLBACK TRANSACTION;')
+ raise
+ finally:
+ self.current_cursor = None
+
+ if lock:
+ self.lock.release()
+
+ def rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
+ """...."""
+ ...
diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py
new file mode 100644
index 0000000..cd6698c
--- /dev/null
+++ b/src/hydrilla/proxy/state_impl/concrete_state.py
@@ -0,0 +1,704 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# Haketilo proxy data and configuration (instantiatable HaketiloState subtype).
+#
+# This file is part of Hydrilla&Haketilo.
+#
+# Copyright (C) 2022 Wojtek Kosior
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+# I, Wojtek Kosior, thereby promise not to sue for violation of this
+# file's license. Although I request that you do not make use this code
+# in a proprietary program, I am not going to enforce this in court.
+
+"""
+This module contains logic for keeping track of all settings, rules, mappings
+and resources.
+"""
+
+# Enable using with Python 3.7.
+from __future__ import annotations
+
+import secrets
+import io
+import hashlib
+import typing as t
+import dataclasses as dc
+
+from pathlib import Path
+
+import sqlite3
+
+from ...exceptions import HaketiloException
+from ...translations import smart_gettext as _
+from ... import pattern_tree
+from ... import url_patterns
+from ... import versions
+from ... import item_infos
+from ..simple_dependency_satisfying import ItemsCollection, ComputedPayload
+from .. import state as st
+from .. import policies
+from . import base
+
+
+here = Path(__file__).resolve().parent
+
+AnyInfo = t.Union[item_infos.ResourceInfo, item_infos.MappingInfo]
+
+@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
+class ConcreteRepoRef(st.RepoRef):
+ def remove(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+ def update(
+ self,
+ state: st.HaketiloState,
+ *,
+ name: t.Optional[str] = None,
+ url: t.Optional[str] = None
+ ) -> ConcreteRepoRef:
+ raise NotImplementedError()
+
+ def refresh(self, state: st.HaketiloState) -> ConcreteRepoIterationRef:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteRepoIterationRef(st.RepoIterationRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteMappingRef(st.MappingRef):
+ def disable(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+ def forget_enabled(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteMappingVersionRef(st.MappingVersionRef):
+ def enable(self, state: st.HaketiloState) -> None:
+ raise NotImplementedError()
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteResourceRef(st.ResourceRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcreteResourceVersionRef(st.ResourceVersionRef):
+ pass
+
+
+@dc.dataclass(frozen=True, unsafe_hash=True)
+class ConcretePayloadRef(st.PayloadRef):
+ computed_payload: ComputedPayload = dc.field(hash=False, compare=False)
+
+ def get_data(self, state: st.HaketiloState) -> st.PayloadData:
+ return t.cast(ConcreteHaketiloState, state).payloads_data[self.id]
+
+ def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
+ return 'to implement'
+
+ def get_script_paths(self, state: st.HaketiloState) \
+ -> t.Iterator[t.Sequence[str]]:
+ for resource_info in self.computed_payload.resources:
+ for file_spec in resource_info.scripts:
+ yield (resource_info.identifier, *file_spec.name.split('/'))
+
+ def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
+ -> t.Optional[st.FileData]:
+ if len(path) == 0:
+ raise st.MissingItemError()
+
+ resource_identifier, *file_name_segments = path
+
+ file_name = '/'.join(file_name_segments)
+
+ script_sha256 = ''
+
+ matched_resource_info = False
+
+ for resource_info in self.computed_payload.resources:
+ if resource_info.identifier == resource_identifier:
+ matched_resource_info = True
+
+ for script_spec in resource_info.scripts:
+ if script_spec.name == file_name:
+ script_sha256 = script_spec.sha256
+
+ break
+
+ if not matched_resource_info:
+ raise st.MissingItemError(resource_identifier)
+
+ if script_sha256 == '':
+ return None
+
+ store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir
+ files_dir_path = store_dir_path / 'temporary_malcontent' / 'file'
+ file_path = files_dir_path / 'sha256' / script_sha256
+
+ return st.FileData(
+ type = 'application/javascript',
+ name = file_name,
+ contents = file_path.read_bytes()
+ )
+
+PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
+
+def register_payload(
+ policy_tree: PolicyTree,
+ payload_key: st.PayloadKey,
+ token: str
+) -> PolicyTree:
+ """...."""
+ payload_policy_factory = policies.PayloadPolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(
+ payload_key.pattern,
+ payload_policy_factory
+ )
+
+ resource_policy_factory = policies.PayloadResourcePolicyFactory(
+ builtin = False,
+ payload_key = payload_key
+ )
+
+ policy_tree = policy_tree.register(
+ payload_key.pattern.path_append(token, '***'),
+ resource_policy_factory
+ )
+
+ return policy_tree
+
+DataById = t.Mapping[str, st.PayloadData]
+
+AnyInfoVar = t.TypeVar(
+ 'AnyInfoVar',
+ item_infos.ResourceInfo,
+ item_infos.MappingInfo
+)
+
+# def newest_item_path(item_dir: Path) -> t.Optional[Path]:
+# available_versions = tuple(
+# versions.parse_normalize_version(ver_path.name)
+# for ver_path in item_dir.iterdir()
+# if ver_path.is_file()
+# )
+
+# if available_versions == ():
+# return None
+
+# newest_version = max(available_versions)
+
+# version_path = item_dir / versions.version_string(newest_version)
+
+# assert version_path.is_file()
+
+# return version_path
+
+def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
+ -> t.Iterator[tuple[AnyInfoVar, str]]:
+ item_type_path = malcontent_path / item_class.type_name
+ if not item_type_path.is_dir():
+ return
+
+ for item_path in item_type_path.iterdir():
+ if not item_path.is_dir():
+ continue
+
+ for item_version_path in item_path.iterdir():
+ definition = item_version_path.read_text()
+ item_info = item_class.load(io.StringIO(definition))
+
+ assert item_info.identifier == item_path.name
+ assert versions.version_string(item_info.version) == \
+ item_version_path.name
+
+ yield item_info, definition
+
+def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int:
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO repos(name, url, deleted, next_iteration)
+ VALUES(?, '<dummy_url>', TRUE, 2);
+ ''',
+ (repo_name,)
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ repo_id, next_iteration - 1
+ FROM
+ repos
+ WHERE
+ name = ?;
+ ''',
+ (repo_name,)
+ )
+
+ (repo_id, last_iteration), = cursor.fetchall()
+
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO repo_iterations(repo_id, iteration)
+ VALUES(?, ?);
+ ''',
+ (repo_id, last_iteration)
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ repo_iteration_id
+ FROM
+ repo_iterations
+ WHERE
+ repo_id = ? AND iteration = ?;
+ ''',
+ (repo_id, last_iteration)
+ )
+
+ (repo_iteration_id,), = cursor.fetchall()
+
+ return repo_iteration_id
+
+def get_or_make_item(cursor: sqlite3.Cursor, type: str, identifier: str) -> int:
+ type_letter = {'resource': 'R', 'mapping': 'M'}[type]
+
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO items(type, identifier)
+ VALUES(?, ?);
+ ''',
+ (type_letter, identifier)
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ item_id
+ FROM
+ items
+ WHERE
+ type = ? AND identifier = ?;
+ ''',
+ (type_letter, identifier)
+ )
+
+ (item_id,), = cursor.fetchall()
+
+ return item_id
+
+def get_or_make_item_version(
+ cursor: sqlite3.Cursor,
+ item_id: int,
+ repo_iteration_id: int,
+ definition: str,
+ info: AnyInfo
+) -> int:
+ ver_str = versions.version_string(info.version)
+
+ values = (
+ item_id,
+ ver_str,
+ repo_iteration_id,
+ definition,
+ info.allows_eval,
+ info.allows_cors_bypass
+ )
+
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO item_versions(
+ item_id,
+ version,
+ repo_iteration_id,
+ definition,
+ eval_allowed,
+ cors_bypass_allowed
+ )
+ VALUES(?, ?, ?, ?, ?, ?);
+ ''',
+ values
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ item_version_id
+ FROM
+ item_versions
+ WHERE
+ item_id = ? AND version = ? AND repo_iteration_id = ?;
+ ''',
+ (item_id, ver_str, repo_iteration_id)
+ )
+
+ (item_version_id,), = cursor.fetchall()
+
+ return item_version_id
+
+def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None:
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO mapping_statuses(item_id, enabled)
+ VALUES(?, 'N');
+ ''',
+ (item_id,)
+ )
+
+def get_or_make_file(cursor: sqlite3.Cursor, sha256: str, file_bytes: bytes) \
+ -> int:
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO files(sha256, data)
+ VALUES(?, ?)
+ ''',
+ (sha256, file_bytes)
+ )
+
+ cursor.execute(
+ '''
+ SELECT
+ file_id
+ FROM
+ files
+ WHERE
+ sha256 = ?;
+ ''',
+ (sha256,)
+ )
+
+ (file_id,), = cursor.fetchall()
+
+ return file_id
+
+def make_file_use(
+ cursor: sqlite3.Cursor,
+ item_version_id: int,
+ file_id: int,
+ name: str,
+ type: str,
+ mime_type: str,
+ idx: int
+) -> None:
+ cursor.execute(
+ '''
+ INSERT OR IGNORE INTO file_uses(
+ item_version_id,
+ file_id,
+ name,
+ type,
+ mime_type,
+ idx
+ )
+ VALUES(?, ?, ?, ?, ?, ?);
+ ''',
+ (item_version_id, file_id, name, type, mime_type, idx)
+ )
+
+@dc.dataclass
+class ConcreteHaketiloState(base.HaketiloStateWithFields):
+ def __post_init__(self) -> None:
+ self._prepare_database()
+
+ self._populate_database_with_stuff_from_temporary_malcontent_dir()
+
+ with self.cursor() as cursor:
+ self.rebuild_structures(cursor)
+
+ def _prepare_database(self) -> None:
+ """...."""
+ cursor = self.connection.cursor()
+
+ try:
+ cursor.execute(
+ '''
+ SELECT COUNT(name)
+ FROM sqlite_master
+ WHERE name = 'general' AND type = 'table';
+ '''
+ )
+
+ (db_initialized,), = cursor.fetchall()
+
+ if not db_initialized:
+ cursor.executescript((here.parent / 'tables.sql').read_text())
+
+ else:
+ cursor.execute(
+ '''
+ SELECT haketilo_version
+ FROM general;
+ '''
+ )
+
+ (db_haketilo_version,) = cursor.fetchone()
+ if db_haketilo_version != '3.0b1':
+ raise HaketiloException(_('err.unknown_db_schema'))
+
+ cursor.execute('PRAGMA FOREIGN_KEYS;')
+ if cursor.fetchall() == []:
+ raise HaketiloException(_('err.proxy.no_sqlite_foreign_keys'))
+
+ cursor.execute('PRAGMA FOREIGN_KEYS=ON;')
+ finally:
+ cursor.close()
+
+ def _populate_database_with_stuff_from_temporary_malcontent_dir(self) \
+ -> None:
+ malcontent_dir_path = self.store_dir / 'temporary_malcontent'
+ files_by_sha256_path = malcontent_dir_path / 'file' / 'sha256'
+
+ with self.cursor(lock=True, transaction=True) as cursor:
+ for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]:
+ info: AnyInfo
+ for info, definition in read_items(
+ malcontent_dir_path,
+ info_type # type: ignore
+ ):
+ repo_iteration_id = get_or_make_repo_iteration(
+ cursor,
+ info.repo
+ )
+
+ item_id = get_or_make_item(
+ cursor,
+ info.type_name,
+ info.identifier
+ )
+
+ item_version_id = get_or_make_item_version(
+ cursor,
+ item_id,
+ repo_iteration_id,
+ definition,
+ info
+ )
+
+ if info_type is item_infos.MappingInfo:
+ make_mapping_status(cursor, item_id)
+
+ file_ids_bytes = {}
+
+ file_specifiers = [*info.source_copyright]
+ if isinstance(info, item_infos.ResourceInfo):
+ file_specifiers.extend(info.scripts)
+
+ for file_spec in file_specifiers:
+ file_path = files_by_sha256_path / file_spec.sha256
+ file_bytes = file_path.read_bytes()
+
+ sha256 = hashlib.sha256(file_bytes).digest().hex()
+ assert sha256 == file_spec.sha256
+
+ file_id = get_or_make_file(cursor, sha256, file_bytes)
+
+ file_ids_bytes[sha256] = (file_id, file_bytes)
+
+ for idx, file_spec in enumerate(info.source_copyright):
+ file_id, file_bytes = file_ids_bytes[file_spec.sha256]
+ if file_bytes.isascii():
+ mime = 'text/plain'
+ else:
+ mime = 'application/octet-stream'
+
+ make_file_use(
+ cursor,
+ item_version_id = item_version_id,
+ file_id = file_id,
+ name = file_spec.name,
+ type = 'L',
+ mime_type = mime,
+ idx = idx
+ )
+
+ if isinstance(info, item_infos.MappingInfo):
+ continue
+
+ for idx, file_spec in enumerate(info.scripts):
+ file_id, _ = file_ids_bytes[file_spec.sha256]
+ make_file_use(
+ cursor,
+ item_version_id = item_version_id,
+ file_id = file_id,
+ name = file_spec.name,
+ type = 'W',
+ mime_type = 'application/javascript',
+ idx = idx
+ )
+
+ def rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
+ cursor.execute(
+ '''
+ SELECT
+ item_id, type, version, definition
+ FROM
+ item_versions JOIN items USING (item_id);
+ '''
+ )
+
+ best_versions: dict[int, versions.VerTuple] = {}
+ definitions = {}
+ types = {}
+
+ for item_id, item_type, ver_str, definition in cursor.fetchall():
+ # TODO: what we're doing in this loop does not yet take different
+ # repos and different repo iterations into account.
+ ver = versions.parse_normalize_version(ver_str)
+ if best_versions.get(item_id, (0,)) < ver:
+ best_versions[item_id] = ver
+ definitions[item_id] = definition
+ types[item_id] = item_type
+
+ resources = {}
+ mappings = {}
+
+ for item_id, definition in definitions.items():
+ if types[item_id] == 'R':
+ r_info = item_infos.ResourceInfo.load(io.StringIO(definition))
+ resources[r_info.identifier] = r_info
+ else:
+ m_info = item_infos.MappingInfo.load(io.StringIO(definition))
+ mappings[m_info.identifier] = m_info
+
+ items_collection = ItemsCollection(resources, mappings)
+ computed_payloads = items_collection.compute_payloads()
+
+ payloads_data = {}
+
+ for mapping_info, by_pattern in computed_payloads.items():
+ for num, (pattern, payload) in enumerate(by_pattern.items()):
+ payload_id = f'{num}@{mapping_info.identifier}'
+
+ ref = ConcretePayloadRef(payload_id, payload)
+
+ data = st.PayloadData(
+ payload_ref = ref,
+ mapping_installed = True,
+ explicitly_enabled = True,
+ unique_token = secrets.token_urlsafe(16),
+ pattern = pattern,
+ eval_allowed = payload.allows_eval,
+ cors_bypass_allowed = payload.allows_cors_bypass
+ )
+
+ payloads_data[payload_id] = data
+
+ key = st.PayloadKey(
+ payload_ref = ref,
+ mapping_identifier = mapping_info.identifier,
+ pattern = pattern
+ )
+
+ self.policy_tree = register_payload(
+ self.policy_tree,
+ key,
+ data.unique_token
+ )
+
+ self.payloads_data = payloads_data
+
+ def get_repo(self, repo_id: str) -> st.RepoRef:
+ return ConcreteRepoRef(repo_id)
+
+ def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef:
+ return ConcreteRepoIterationRef(repo_iteration_id)
+
+ def get_mapping(self, mapping_id: str) -> st.MappingRef:
+ return ConcreteMappingRef(mapping_id)
+
+ def get_mapping_version(self, mapping_version_id: str) \
+ -> st.MappingVersionRef:
+ return ConcreteMappingVersionRef(mapping_version_id)
+
+ def get_resource(self, resource_id: str) -> st.ResourceRef:
+ return ConcreteResourceRef(resource_id)
+
+ def get_resource_version(self, resource_version_id: str) \
+ -> st.ResourceVersionRef:
+ return ConcreteResourceVersionRef(resource_version_id)
+
+ def get_payload(self, payload_id: str) -> st.PayloadRef:
+ return 'not implemented'
+
+ def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \
+ -> st.RepoRef:
+ raise NotImplementedError()
+
+ def get_settings(self) -> st.HaketiloGlobalSettings:
+ return st.HaketiloGlobalSettings(
+ mapping_use_mode = st.MappingUseMode.AUTO,
+ default_allow_scripts = True,
+ repo_refresh_seconds = 0
+ )
+
+ def update_settings(
+ self,
+ *,
+ mapping_use_mode: t.Optional[st.MappingUseMode] = None,
+ default_allow_scripts: t.Optional[bool] = None,
+ repo_refresh_seconds: t.Optional[int] = None
+ ) -> None:
+ raise NotImplementedError()
+
+ def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy:
+ """...."""
+ with self.lock:
+ policy_tree = self.policy_tree
+
+ try:
+ best_priority: int = 0
+ best_policy: t.Optional[policies.Policy] = None
+
+ for factories_set in policy_tree.search(url):
+ for stored_factory in sorted(factories_set):
+ factory = stored_factory.item
+
+ policy = factory.make_policy(self)
+
+ if policy.priority > best_priority:
+ best_priority = policy.priority
+ best_policy = policy
+ except Exception as e:
+ return policies.ErrorBlockPolicy(
+ builtin = True,
+ error = e
+ )
+
+ if best_policy is not None:
+ return best_policy
+
+ if self.get_settings().default_allow_scripts:
+ return policies.FallbackAllowPolicy()
+ else:
+ return policies.FallbackBlockPolicy()
+
+ @staticmethod
+ def make(store_dir: Path) -> 'ConcreteHaketiloState':
+ return ConcreteHaketiloState(
+ store_dir = store_dir,
+ connection = sqlite3.connect(str(store_dir / 'sqlite3.db'))
+ )
diff --git a/src/hydrilla/proxy/store.py b/src/hydrilla/proxy/store.py
index 72852d8..4978b65 100644
--- a/src/hydrilla/proxy/store.py
+++ b/src/hydrilla/proxy/store.py
@@ -29,12 +29,153 @@
# Enable using with Python 3.7.
from __future__ import annotations
+import threading
import dataclasses as dc
+import typing as t
from pathlib import Path
+from enum import Enum
+
+from immutables import Map
+
+from .. url_patterns import parse_pattern
+from .. import versions
+from . import state
+
+
+@dc.dataclass(frozen=True, eq=False)
+class StoredItemRef(state.ItemRef):
+ item_id: int
+
+ def __eq__(self, other: object) -> bool:
+ return isinstance(other, StoredItemRef) and \
+ self.item_id == other.item_id
+
+ def __hash__(self) -> int:
+ return hash(self.item_id)
+
+ def _id(self) -> str:
+ return str(self.item_id)
+
+
+@dc.dataclass(frozen=True, eq=False)
+class StoredPayloadRef(state.PayloadRef):
+ payload_id: int
+
+ def __eq__(self, other: object) -> bool:
+ return isinstance(other, StoredPayloadRef) and \
+ self.payload_id == other.payload_id
+
+ def __hash__(self) -> int:
+ return hash(self.payload_id)
+
+ def _id(self) -> str:
+ return str(self.payload_id)
+
+
+# class ItemStoredData:
+# """...."""
+# def __init__(
+# self,
+# item_id: int
+# ty#pe: ItemType
+# repository_id: int
+# version: str
+# identifier: str
+# orphan: bool
+# installed: bool
+# enabled: EnabledStatus
+# ) -> None:
+# """...."""
+# self.item_id = item_id
+# self.type = ItemType(type)
+# self.repository_id = repository_id
+# self.version = parse
+# identifier: str
+# orphan: bool
+# installed: bool
+# enabled: EnabledStatus
+
@dc.dataclass
class HaketiloStore:
"""...."""
store_dir: Path
- # TODO: implement
+
+ lock: threading.RLock = dc.field(default_factory=threading.RLock)
+
+ # def load_all_resources(self) -> t.Sequence[item_infos.ResourceInfo]:
+ # """...."""
+ # # TODO: implement
+ # with self.lock:
+ # return []
+
+ def load_installed_mappings_data(self) \
+ -> t.Mapping[state.MappingRef, state.EnabledStatus]:
+ """...."""
+ # TODO: implement
+ with self.lock:
+ dummy_item_ref = StoredItemRef(
+ item_id = 47,
+ identifier = 'somemapping',
+ version = versions.parse_normalize_version('1.2.3'),
+ repository = 'somerepo',
+ orphan = False
+ )
+
+ return Map({
+ state.MappingRef(dummy_item_ref): state.EnabledStatus.ENABLED
+ })
+
+ def load_payloads_data(self) \
+ -> t.Mapping[state.MappingRef, t.Iterable[state.PayloadRef]]:
+ """...."""
+ # TODO: implement
+ with self.lock:
+ dummy_item_ref = StoredItemRef(
+ item_id = 47,
+ identifier = 'somemapping',
+ version = versions.parse_normalize_version('1.2.3'),
+ repository = 'somerepo',
+ orphan = False
+ )
+
+ dummy_mapping_ref = state.MappingRef(dummy_item_ref)
+
+ payload_refs = []
+ for parsed_pattern in parse_pattern('http*://example.com/a/***'):
+ dummy_payload_ref = StoredPayloadRef(
+ payload_id = 22,
+ mapping_ref = dummy_mapping_ref,
+ pattern = parsed_pattern
+ )
+
+ payload_refs.append(dummy_payload_ref)
+
+ return Map({dummy_mapping_ref: payload_refs})
+
+ def load_file_data(
+ self,
+ payload_ref: state.PayloadRef,
+ resource_identifier: str,
+ file_path: t.Sequence[str]
+ ) -> t.Optional[state.FileData]:
+ # TODO: implement
+ with self.lock:
+ return None
+
+ def load_global_settings(self) -> state.HaketiloGlobalSettings:
+ """...."""
+ # TODO: implement
+ with self.lock:
+ return state.HaketiloGlobalSettings(
+ state.MappingApplicationMode.WHEN_ENABLED,
+ False
+ )
+
+ def write_global_settings(self, settings: state.HaketiloGlobalSettings) \
+ -> None:
+ """...."""
+ # TODO: implement
+ with self.lock:
+ pass
diff --git a/src/hydrilla/proxy/tables.sql b/src/hydrilla/proxy/tables.sql
new file mode 100644
index 0000000..53539a7
--- /dev/null
+++ b/src/hydrilla/proxy/tables.sql
@@ -0,0 +1,235 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+
+-- SQLite tables definitions for Haketilo proxy.
+--
+-- This file is part of Hydrilla&Haketilo.
+--
+-- Copyright (C) 2022 Wojtek Kosior
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU General Public License as published by
+-- the Free Software Foundation, either version 3 of the License, or
+-- (at your option) any later version.
+--
+-- This program is distributed in the hope that it will be useful,
+-- but WITHOUT ANY WARRANTY; without even the implied warranty of
+-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+-- GNU General Public License for more details.
+--
+-- You should have received a copy of the GNU General Public License
+-- along with this program. If not, see <https://www.gnu.org/licenses/>.
+--
+--
+-- I, Wojtek Kosior, thereby promise not to sue for violation of this
+-- file's license. Although I request that you do not make use this code
+-- in a proprietary program, I am not going to enforce this in court.
+
+BEGIN TRANSACTION;
+
+CREATE TABLE general(
+ haketilo_version VARCHAR NOT NULL,
+ default_allow_scripts BOOLEAN NOT NULL,
+ repo_refresh_seconds INTEGER NOT NULL,
+ -- "mapping_use_mode" determines whether current mode is AUTO,
+ -- WHEN_ENABLED or QUESTION.
+ mapping_use_mode CHAR(1) NOT NULL,
+
+ CHECK (rowid = 1),
+ CHECK (mapping_use_mode IN ('A', 'W', 'Q')),
+ CHECK (haketilo_version = '3.0b1')
+);
+
+INSERT INTO general(
+ rowid,
+ haketilo_version,
+ default_allow_scripts,
+ repo_refresh_seconds,
+ mapping_use_mode
+) VALUES(
+ 1,
+ '3.0b1',
+ FALSE,
+ 24 * 60 * 60,
+ 'Q'
+);
+
+CREATE TABLE rules(
+ rule_id INTEGER PRIMARY KEY,
+
+ pattern VARCHAR NOT NULL,
+ allow_scripts BOOLEAN NOT NULL,
+
+ UNIQUE (pattern)
+);
+
+CREATE TABLE repos(
+ repo_id INTEGER PRIMARY KEY,
+
+ name VARCHAR NOT NULL,
+ url VARCHAR NOT NULL,
+ deleted BOOLEAN NOT NULL,
+ next_iteration INTEGER NOT NULL,
+ active_iteration_id INTEGER NULL,
+ last_refreshed INTEGER NULL,
+
+ UNIQUE (name),
+
+ FOREIGN KEY (active_iteration_id)
+ REFERENCES repo_iterations(repo_iteration_id)
+ ON DELETE SET NULL
+);
+
+CREATE TABLE repo_iterations(
+ repo_iteration_id INTEGER PRIMARY KEY,
+
+ repo_id INTEGER NOT NULL,
+ iteration INTEGER NOT NULL,
+
+ UNIQUE (repo_id, iteration),
+
+ FOREIGN KEY (repo_id)
+ REFERENCES repos (repo_id)
+);
+
+CREATE VIEW orphan_iterations
+AS
+SELECT
+ ri.repo_iteration_id,
+ ri.repo_id,
+ ri.iteration
+FROM
+ repo_iterations AS ri
+ JOIN repos AS r USING (repo_id)
+WHERE
+ COALESCE(r.active_iteration_id != ri.repo_iteration_id, TRUE);
+
+CREATE TABLE items(
+ item_id INTEGER PRIMARY KEY,
+
+ -- "type" determines whether it's resource or mapping.
+ type CHAR(1) NOT NULL,
+ identifier VARCHAR NOT NULL,
+
+ UNIQUE (type, identifier),
+ CHECK (type IN ('R', 'M'))
+);
+
+CREATE TABLE mapping_statuses(
+ -- The item with this id shall be a mapping ("type" = 'M').
+ item_id INTEGER PRIMARY KEY,
+
+ -- "enabled" determines whether mapping's status is ENABLED,
+ -- DISABLED or NO_MARK.
+ enabled CHAR(1) NOT NULL,
+ enabled_version_id INTEGER NULL,
+ -- "frozen" determines whether an enabled mapping is to be kept in its
+ -- EXACT_VERSION, is to be updated only with versions from the same
+ -- REPOSITORY or is NOT_FROZEN at all.
+ frozen CHAR(1) NULL,
+
+ CHECK (NOT (enabled = 'D' AND enabled_version_id IS NOT NULL)),
+ CHECK (NOT (enabled = 'E' AND enabled_version_id IS NULL)),
+
+ CHECK ((frozen IS NULL) = (enabled != 'E')),
+ CHECK (frozen IS NULL OR frozen in ('E', 'R', 'N'))
+);
+
+CREATE TABLE item_versions(
+ item_version_id INTEGER PRIMARY KEY,
+
+ item_id INTEGER NOT NULL,
+ version VARCHAR NOT NULL,
+ repo_iteration_id INTEGER NOT NULL,
+ definition TEXT NOT NULL,
+ -- What privileges should be granted on pages where this
+ -- resource/mapping is used.
+ eval_allowed BOOLEAN NOT NULL,
+ cors_bypass_allowed BOOLEAN NOT NULL,
+
+ UNIQUE (item_id, version, repo_iteration_id),
+ -- Allow foreign key from "mapping_statuses".
+ UNIQUE (item_version_id, item_id),
+
+ FOREIGN KEY (item_id)
+ REFERENCES items (item_id),
+ FOREIGN KEY (repo_iteration_id)
+ REFERENCES repo_iterations (repo_iteration_id)
+);
+
+CREATE TABLE payloads(
+ payload_id INTEGER PRIMARY KEY,
+
+ mapping_item_id INTEGER NOT NULL,
+ pattern VARCHAR NOT NULL,
+
+ UNIQUE (mapping_item_id, pattern),
+
+ FOREIGN KEY (mapping_item_id)
+ REFERENCES item_versions (versioned_item_id)
+ ON DELETE CASCADE
+);
+
+CREATE TABLE resolved_depended_resources(
+ payload_id INTEGER,
+ resource_item_id INTEGER,
+
+ -- "idx" determines the ordering of resources.
+ idx INTEGER,
+
+ PRIMARY KEY (payload_id, resource_item_id),
+
+ FOREIGN KEY (payload_id)
+ REFERENCES payloads (payload_id),
+ FOREIGN KEY (resource_item_id)
+ REFERENCES item_versions (item_version_id)
+) WITHOUT ROWID;
+
+-- CREATE TABLE resolved_required_mappings(
+-- requiring_item_id INTEGER,
+-- required_mapping_item_id INTEGER,
+
+-- PRIMARY KEY (requiring_item_id, required_mapping_item_id),
+
+-- CHECK (requiring_item_id != required_mapping_item_id),
+
+-- FOREIGN KEY (requiring_item_id)
+-- REFERENCES items (item_id),
+-- -- Note: the referenced mapping shall have installed=TRUE.
+-- FOREIGN KEY (required_mapping_item_id)
+-- REFERENCES mappings (item_id),
+-- FOREIGN KEY (required_mapping_item_id)
+-- REFERENCES items (item_id)
+-- );
+
+CREATE TABLE files(
+ file_id INTEGER PRIMARY KEY,
+
+ sha256 CHAR(64) NOT NULL,
+ data BLOB NULL,
+
+ UNIQUE (sha256)
+);
+
+CREATE TABLE file_uses(
+ file_use_id INTEGER PRIMARY KEY,
+
+ item_version_id INTEGER NOT NULL,
+ file_id INTEGER NOT NULL,
+ name VARCHAR NOT NULL,
+ -- "type" determines whether it's license file or web resource.
+ type CHAR(1) NOT NULL,
+ mime_type VARCHAR NOT NULL,
+ -- "idx" determines the ordering of item's files of given type.
+ idx INTEGER NOT NULL,
+
+ CHECK (type IN ('L', 'W')),
+ UNIQUE(item_version_id, type, idx),
+
+ FOREIGN KEY (item_version_id)
+ REFERENCES item_versions(item_version_id)
+ ON DELETE CASCADE,
+ FOREIGN KEY (file_id)
+ REFERENCES files(file_id)
+);
+
+COMMIT TRANSACTION;
diff --git a/src/hydrilla/url_patterns.py b/src/hydrilla/url_patterns.py
index 8e80379..0a242e3 100644
--- a/src/hydrilla/url_patterns.py
+++ b/src/hydrilla/url_patterns.py
@@ -41,36 +41,73 @@ import dataclasses as dc
from immutables import Map
-from hydrilla.translations import smart_gettext as _
-from hydrilla.exceptions import HaketiloException
+from .translations import smart_gettext as _
+from .exceptions import HaketiloException
default_ports: t.Mapping[str, int] = Map(http=80, https=443, ftp=21)
-@dc.dataclass(frozen=True, unsafe_hash=True)
+ParsedUrlType = t.TypeVar('ParsedUrlType', bound='ParsedUrl')
+
+@dc.dataclass(frozen=True, unsafe_hash=True, order=True)
class ParsedUrl:
"""...."""
- orig_url: str # orig_url used in __hash__()
- scheme: str = dc.field(hash=False)
- domain_labels: tuple[str, ...] = dc.field(hash=False)
- path_segments: tuple[str, ...] = dc.field(hash=False)
- has_trailing_slash: bool = dc.field(hash=False)
- port: int = dc.field(hash=False)
-
- # def reconstruct_url(self) -> str:
- # """...."""
- # scheme = self.orig_scheme
-
- # netloc = '.'.join(reversed(self.domain_labels))
- # if scheme == self.scheme and \
- # self.port is not None and \
- # default_ports[scheme] != self.port:
- # netloc += f':{self.port}'
-
- # path = '/'.join(('', *self.path_segments))
- # if self.has_trailing_slash:
- # path += '/'
+ orig_url: str # used in __hash__() and __lt__()
+ scheme: str = dc.field(hash=False, compare=False)
+ domain_labels: tuple[str, ...] = dc.field(hash=False, compare=False)
+ path_segments: tuple[str, ...] = dc.field(hash=False, compare=False)
+ has_trailing_slash: bool = dc.field(hash=False, compare=False)
+ port: int = dc.field(hash=False, compare=False)
+
+ @property
+ def url_without_path(self) -> str:
+ """...."""
+ scheme = self.scheme
+
+ netloc = '.'.join(reversed(self.domain_labels))
+
+ if self.port is not None and \
+ default_ports[scheme] != self.port:
+ netloc += f':{self.port}'
+
+ return f'{scheme}://{netloc}'
+
+ def _reconstruct_url(self) -> str:
+ """...."""
+ path = '/'.join(('', *self.path_segments))
+ if self.has_trailing_slash:
+ path += '/'
+
+ return self.url_without_path + path
+
+ def path_append(self: ParsedUrlType, *new_segments: str) -> ParsedUrlType:
+ """...."""
+ new_url = self._reconstruct_url()
+ if not self.has_trailing_slash:
+ new_url += '/'
+
+ new_url += '/'.join(new_segments)
+
+ return dc.replace(
+ self,
+ orig_url = new_url,
+ path_segments = tuple((*self.path_segments, *new_segments)),
+ has_trailing_slash = False
+ )
+
+ParsedPattern = t.NewType('ParsedPattern', ParsedUrl)
+
+# # We sometimes need a dummy pattern that means "match everything".
+# catchall_pattern = ParsedPattern(
+# ParsedUrl(
+# orig_url = '<dummy_catchall_url_pattern>'
+# scheme = '<dummy_all-scheme>'
+# domain_labels = ('***',)
+# path_segments = ('***',)
+# has_trailing_slash = False
+# port = 0
+# )
+# )
- # return f'{scheme}://{netloc}{path}'
# URLs with those schemes will be recognized but not all of them have to be
# actually supported by Hydrilla server and Haketilo proxy.
@@ -163,7 +200,7 @@ def _parse_pattern_or_url(url: str, orig_url: str, is_pattern: bool = False) \
replace_scheme_regex = re.compile(r'^[^:]*')
-def parse_pattern(url_pattern: str) -> t.Sequence[ParsedUrl]:
+def parse_pattern(url_pattern: str) -> t.Iterator[ParsedPattern]:
"""...."""
if url_pattern.startswith('http*:'):
patterns = [
@@ -173,8 +210,8 @@ def parse_pattern(url_pattern: str) -> t.Sequence[ParsedUrl]:
else:
patterns = [url_pattern]
- return tuple(_parse_pattern_or_url(pat, url_pattern, True)
- for pat in patterns)
+ for pat in patterns:
+ yield ParsedPattern(_parse_pattern_or_url(pat, url_pattern, True))
def parse_url(url: str) -> ParsedUrl:
"""...."""
diff --git a/src/hydrilla/versions.py b/src/hydrilla/versions.py
index a7a9f29..7474d98 100644
--- a/src/hydrilla/versions.py
+++ b/src/hydrilla/versions.py
@@ -34,14 +34,16 @@ from __future__ import annotations
import typing as t
-def normalize_version(ver: t.Sequence[int]) -> tuple[int, ...]:
- """Strip right-most zeroes from 'ver'. The original list is not modified."""
+VerTuple = t.NewType('VerTuple', 'tuple[int, ...]')
+
+def normalize_version(ver: t.Sequence[int]) -> VerTuple:
+ """Strip rightmost zeroes from 'ver'."""
new_len = 0
for i, num in enumerate(ver):
if num != 0:
new_len = i + 1
- return tuple(ver[:new_len])
+ return VerTuple(tuple(ver[:new_len]))
def parse_version(ver_str: str) -> tuple[int, ...]:
"""
@@ -50,10 +52,16 @@ def parse_version(ver_str: str) -> tuple[int, ...]:
"""
return tuple(int(num) for num in ver_str.split('.'))
-def version_string(ver: t.Sequence[int], rev: t.Optional[int] = None) -> str:
+def parse_normalize_version(ver_str: str) -> VerTuple:
+ """
+ Convert 'ver_str' into a VerTuple representation, e.g. for
+ ver_str="4.6.13.0" return (4, 6, 13).
+ """
+ return normalize_version(parse_version(ver_str))
+
+def version_string(ver: VerTuple, rev: t.Optional[int] = None) -> str:
"""
Produce version's string representation (optionally with revision), like:
1.2.3-5
- No version normalization is performed.
"""
return '.'.join(str(n) for n in ver) + ('' if rev is None else f'-{rev}')