diff options
Diffstat (limited to 'src/hydrilla/pattern_tree.py')
-rw-r--r-- | src/hydrilla/pattern_tree.py | 339 |
1 files changed, 339 insertions, 0 deletions
# SPDX-License-Identifier: GPL-3.0-or-later

# Data structure for querying URL patterns.
#
# This file is part of Hydrilla&Haketilo.
#
# Copyright (C) 2021, 2022 Wojtek Kosior
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
#
# I, Wojtek Kosior, thereby promise not to sue for violation of this
# file's license. Although I request that you do not make use this code
# in a proprietary program, I am not going to enforce this in court.

"""
This module defines data structures for querying data using URL patterns.
"""

# Enable using with Python 3.7.
from __future__ import annotations

import sys
import typing as t
import dataclasses as dc

from immutables import Map

from .url_patterns import ParsedUrl, parse_url
from .translations import smart_gettext as _


WrapperStoredType = t.TypeVar('WrapperStoredType', bound=t.Hashable)

@dc.dataclass(frozen=True, unsafe_hash=True)
class StoredTreeItem(t.Generic[WrapperStoredType]):
    """
    In the Pattern Tree, each item is stored together with the pattern used to
    register it.
    """
    # Parsed form of the pattern the item was registered under.
    pattern: ParsedUrl
    # The user-supplied (hashable) value.
    item: WrapperStoredType


NodeStoredType = t.TypeVar('NodeStoredType')

@dc.dataclass(frozen=True)
class PatternTreeNode(t.Generic[NodeStoredType]):
    """
    Immutable node of a pattern tree.

    A node maps string keys to child nodes and can additionally carry a single
    stored value of its own. The "modifying" methods never alter the node they
    are called on; each of them returns a new node instead.
    """
    # Child nodes, indexed by string key.
    children: 'NodeChildrenType' = Map()
    # Value stored at this very node; None means nothing is stored here.
    literal_match: t.Optional[NodeStoredType] = None

    def is_empty(self) -> bool:
        """
        Tell whether this node has neither children nor a stored value.
        """
        return len(self.children) == 0 and self.literal_match is None

    def update_literal_match(
            self,
            new_match_item: t.Optional[NodeStoredType]
    ) -> 'NodeSelfType':
        """
        Return a copy of this node with its stored value replaced by
        `new_match_item` (pass None to clear the value).
        """
        return dc.replace(self, literal_match=new_match_item)

    def get_child(self, child_key: str) -> t.Optional['NodeSelfType']:
        """
        Return the child registered under `child_key` or None if there is no
        such child.
        """
        return self.children.get(child_key)

    def remove_child(self, child_key: str) -> 'NodeSelfType':
        """
        Return a copy of this node without the child registered under
        `child_key`. Removing a child that does not exist is not an error; the
        returned node then has the same children as this one.
        """
        try:
            children = self.children.delete(child_key)
        except KeyError:
            # Map.delete() signals a missing key with KeyError. Only that case
            # is treated as "nothing to remove"; any other exception is a real
            # problem and must propagate.
            children = self.children

        return dc.replace(self, children=children)

    def set_child(self, child_key: str, child: 'NodeSelfType') \
        -> 'NodeSelfType':
        """
        Return a copy of this node in which `child` is registered under
        `child_key`, replacing any child previously registered there.
        """
        return dc.replace(self, children=self.children.set(child_key, child))

# Below we define 2 types used by recursively-typed PatternTreeNode.
NodeSelfType = PatternTreeNode[NodeStoredType]
NodeChildrenType = Map[str, NodeSelfType]


BranchStoredType = t.TypeVar('BranchStoredType')

ItemUpdater = t.Callable[
    [t.Optional[BranchStoredType]],
    t.Optional[BranchStoredType]
]

@dc.dataclass(frozen=True)
class PatternTreeBranch(t.Generic[BranchStoredType]):
    """
    Immutable tree of PatternTreeNode objects indexed by sequences of string
    segments. Segments '*', '**' and '***' stored in the tree act as wildcards
    during search().
    """
    root_node: PatternTreeNode[BranchStoredType] = PatternTreeNode()

    def is_empty(self) -> bool:
        """
        Tell whether nothing at all is stored in this branch.
        """
        return self.root_node.is_empty()

    def update(self, segments: t.Iterable[str], item_updater: ItemUpdater) \
        -> 'BranchSelfType':
        """
        Return a new branch in which the value stored under the path given by
        `segments` has been replaced with the result of `item_updater`.

        `item_updater` is called once with the value currently stored under
        that path (None if there is none) and returns the value to store (None
        meaning that nothing should be stored there). Nodes missing on the
        path are created and nodes left empty by the update are pruned.
        """
        node = self.root_node
        # (parent node, segment leading from it) pairs for the walked path.
        nodes_segments = []

        for segment in segments:
            next_node = node.get_child(segment)

            nodes_segments.append((node, segment))

            node = PatternTreeNode() if next_node is None else next_node

        node = node.update_literal_match(item_updater(node.literal_match))

        # Rebuild the path bottom-up; the nodes are immutable, so every
        # ancestor of the changed node has to be re-created.
        while nodes_segments:
            prev_node, segment = nodes_segments.pop()

            if node.is_empty():
                # prev_node may have never had this child (the path could have
                # been created just above); remove_child() tolerates that.
                node = prev_node.remove_child(segment)
            else:
                node = prev_node.set_child(segment, node)

        return dc.replace(self, root_node=node)

    def search(self, segments: t.Sequence[str]) -> t.Iterable[BranchStoredType]:
        """
        Yields all matches of this segments sequence against the tree. Results
        are produced in order from greatest to lowest pattern specificity.

        A stored '*' segment matches exactly one trailing segment, '**' matches
        two or more trailing segments and '***' matches any number of them
        (including none).
        """
        # Follow the literal path as deep as it goes; nodes[i] is the node
        # reached after consuming the first i segments.
        nodes = [self.root_node]

        for segment in segments:
            next_node = nodes[-1].get_child(segment)
            if next_node is None:
                break

            nodes.append(next_node)

        nsegments = len(segments)

        # Visit the deepest node first.
        while nodes:
            node = nodes.pop()
            # After the pop, len(nodes) equals the depth of `node`, so this is
            # the number of segments not consumed on the way to `node`.
            remaining = nsegments - len(nodes)

            # When the last queried segment is itself spelled like a wildcard,
            # the wildcard child is also the literal match of the final
            # segment. The extra comparisons keep it from being yielded twice.
            candidates = (
                (node, remaining == 0),
                (
                    node.get_child('*'),
                    remaining == 1 and segments[-1] != '*'
                ),
                (
                    node.get_child('**'),
                    remaining > 1
                ),
                (
                    node.get_child('***'),
                    remaining != 1 or segments[-1] != '***'
                ),
            )

            for match_node, applicable in candidates:
                if not applicable or match_node is None:
                    continue

                if match_node.literal_match is not None:
                    yield match_node.literal_match

# Below we define 1 type used by recursively-typed PatternTreeBranch.
BranchSelfType = PatternTreeBranch[BranchStoredType]


FilterStoredType = t.TypeVar('FilterStoredType', bound=t.Hashable)
FilterWrappedType = StoredTreeItem[FilterStoredType]

def filter_by_trailing_slash(
        items: t.Iterable[FilterWrappedType],
        with_slash: bool
) -> t.FrozenSet[FilterWrappedType]:
    """
    Return the set of those `items` whose pattern's `has_trailing_slash` flag
    equals `with_slash`.
    """
    return frozenset(wrapped for wrapped in items
                     if with_slash == wrapped.pattern.has_trailing_slash)

TreeStoredType = t.TypeVar('TreeStoredType', bound=t.Hashable)

# These aliases are evaluated when the module is imported (the `annotations`
# future import does not cover plain assignments), so the `typing` generics
# are used here; subscripting built-in `frozenset`/`tuple` fails before
# Python 3.9.
StoredSet = t.FrozenSet[StoredTreeItem[TreeStoredType]]
PathBranch = PatternTreeBranch[StoredSet]
DomainBranch = PatternTreeBranch[PathBranch]
TreeRoot = Map[t.Tuple[str, int], DomainBranch]

@dc.dataclass(frozen=True)
class PatternTree(t.Generic[TreeStoredType]):
    """
    "Pattern Tree" is how we refer to the data structure used for querying
    Haketilo patterns. Those look like 'https://*.example.com/ab/***'. The goal
    is to make it possible to quickly retrieve all known patterns that match
    a given URL.

    The tree is immutable; register() and deregister() return new trees.
    """
    # One domain branch per (scheme, port) pair. Each value stored in a domain
    # branch is a path branch, which in turn stores sets of StoredTreeItem.
    _by_scheme_and_port: TreeRoot = Map()

    def _register(
            self,
            parsed_pattern: ParsedUrl,
            item: TreeStoredType,
            register: bool = True
    ) -> 'TreeSelfType':
        """
        Make an item wrapped in StoredTreeItem object queryable through the
        Pattern Tree by the given parsed URL pattern.

        With `register` set to False the item is instead removed from under
        the pattern. Either way a new tree is returned and this one is left
        unchanged.
        """
        wrapped_item = StoredTreeItem(parsed_pattern, item)

        def item_updater(item_set: t.Optional[StoredSet]) \
            -> t.Optional[StoredSet]:
            """
            Add `wrapped_item` to (or remove it from) the set stored under
            the pattern's path; an empty result is reported as None.
            """
            if item_set is None:
                item_set = frozenset()

            if register:
                item_set = item_set.union((wrapped_item,))
            else:
                item_set = item_set.difference((wrapped_item,))

            return None if len(item_set) == 0 else item_set

        def path_branch_updater(path_branch: t.Optional[PathBranch]) \
            -> t.Optional[PathBranch]:
            """
            Apply item_updater() inside the path branch stored under the
            pattern's domain; an empty resulting branch is reported as None.
            """
            if path_branch is None:
                path_branch = PatternTreeBranch()

            path_branch = path_branch.update(
                parsed_pattern.path_segments,
                item_updater
            )

            return None if path_branch.is_empty() else path_branch

        key = (parsed_pattern.scheme, parsed_pattern.port)
        domain_tree = self._by_scheme_and_port.get(key, PatternTreeBranch())

        new_domain_tree = domain_tree.update(
            parsed_pattern.domain_labels,
            path_branch_updater
        )

        if new_domain_tree.is_empty():
            # Do not keep empty branches around. The key is absent if there
            # was nothing registered for this scheme and port to begin with.
            try:
                new_root = self._by_scheme_and_port.delete(key)
            except KeyError:
                new_root = self._by_scheme_and_port
        else:
            new_root = self._by_scheme_and_port.set(key, new_domain_tree)

        return dc.replace(self, _by_scheme_and_port=new_root)

    def register(self, parsed_pattern: ParsedUrl, item: TreeStoredType) \
        -> 'TreeSelfType':
        """
        Make item queryable through the Pattern Tree by the given URL pattern.
        """
        return self._register(parsed_pattern, item)

    def deregister(self, parsed_pattern: ParsedUrl, item: TreeStoredType) \
        -> 'TreeSelfType':
        """
        Make item no longer queryable through the Pattern Tree by the given URL
        pattern.
        """
        return self._register(parsed_pattern, item, register=False)

    def search(self, url: t.Union[ParsedUrl, str]) -> t.Iterable[StoredSet]:
        """
        Yield non-empty sets of StoredTreeItem objects registered under
        patterns that match `url`, which may be given as a string (it is then
        parsed with parse_url()) or as an already parsed URL.

        Sets are produced in the order in which the domain and path branches
        report their matches. If the URL has a trailing slash, items of each
        matched set that were registered with a trailing-slash pattern are
        yielded before those registered without one; otherwise only items
        registered without a trailing slash are yielded.
        """
        parsed_url = parse_url(url) if isinstance(url, str) else url

        key = (parsed_url.scheme, parsed_url.port)
        domain_tree = self._by_scheme_and_port.get(key)
        if domain_tree is None:
            return

        if parsed_url.has_trailing_slash:
            slash_options = [True, False]
        else:
            slash_options = [False]

        for path_tree in domain_tree.search(parsed_url.domain_labels):
            for item_set in path_tree.search(parsed_url.path_segments):
                for with_slash in slash_options:
                    items = filter_by_trailing_slash(item_set, with_slash)
                    if len(items) > 0:
                        yield items

# Below we define 1 type used by recursively-typed PatternTree.
TreeSelfType = PatternTree[TreeStoredType]