diff options
Diffstat (limited to 'src/hydrilla/proxy/state_impl')
-rw-r--r-- | src/hydrilla/proxy/state_impl/base.py | 3 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/concrete_state.py | 45 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/repos.py | 2 | ||||
-rw-r--r-- | src/hydrilla/proxy/state_impl/rules.py | 180 |
4 files changed, 218 insertions, 12 deletions
diff --git a/src/hydrilla/proxy/state_impl/base.py b/src/hydrilla/proxy/state_impl/base.py index f969b19..a8800cb 100644 --- a/src/hydrilla/proxy/state_impl/base.py +++ b/src/hydrilla/proxy/state_impl/base.py @@ -246,7 +246,8 @@ class HaketiloStateWithFields(st.HaketiloState): ... @abstractmethod - def rebuild_structures(self) -> None: + def rebuild_structures(self, *, payloads: bool = True, rules: bool = True) \ + -> None: """ Recreation of data structures as done after every recomputation of dependencies as well as at startup. diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py index cd32e83..a6d32f1 100644 --- a/src/hydrilla/proxy/state_impl/concrete_state.py +++ b/src/hydrilla/proxy/state_impl/concrete_state.py @@ -47,6 +47,7 @@ from .. import state as st from .. import policies from .. import simple_dependency_satisfying as sds from . import base +from . import rules from . import items from . import repos from . import payloads @@ -143,7 +144,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): repo_id = repo_id ) - self.rebuild_structures() + self.rebuild_structures(rules=False) def prune_orphans(self) -> None: with self.cursor() as cursor: @@ -163,7 +164,7 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): unlocked_required_mappings = unlocked_required_mappings ) - self.rebuild_structures() + self.rebuild_structures(rules=False) def pull_missing_files(self) -> None: with self.cursor() as cursor: @@ -182,6 +183,29 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): ui_factory ) + # Put script blocking/allowing rules in policy tree. + cursor.execute('SELECT pattern, allow_scripts FROM rules;') + + for pattern, allow_scripts in cursor.fetchall(): + for parsed_pattern in url_patterns.parse_pattern(pattern): + factory: policies.PolicyFactory + if allow_scripts: + factory = policies.RuleAllowPolicyFactory( + builtin = False, + pattern = parsed_pattern + ) + else: + factory = policies.RuleBlockPolicyFactory( + builtin = False, + pattern = parsed_pattern + ) + + new_policy_tree = new_policy_tree.register( + parsed_pattern = parsed_pattern, + item = factory + ) + + # Put script payload rules in policy tree. cursor.execute( ''' SELECT @@ -202,15 +226,10 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): ''' ) - rows = cursor.fetchall() - new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {} - for row in rows: - (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, - enabled_status, - identifier) = row - + for (payload_id_int, pattern, eval_allowed, cors_bypass_allowed, + enabled_status, identifier) in cursor.fetchall(): payload_ref = payloads.ConcretePayloadRef(str(payload_id_int), self) previous_data = self.payloads_data.get(payload_ref) @@ -245,10 +264,16 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): self.policy_tree = new_policy_tree self.payloads_data = new_payloads_data - def rebuild_structures(self) -> None: + def rebuild_structures(self, *, payloads: bool = True, rules: bool = True) \ + -> None: + # The `payloads` and `rules` args will be useful for optimization but + # for now we're not yet using them. with self.cursor() as cursor: self._rebuild_structures(cursor) + def rule_store(self) -> st.RuleStore: + return rules.ConcreteRuleStore(self) + def repo_store(self) -> st.RepoStore: return repos.ConcreteRepoStore(self) diff --git a/src/hydrilla/proxy/state_impl/repos.py b/src/hydrilla/proxy/state_impl/repos.py index 383d147..4afd86f 100644 --- a/src/hydrilla/proxy/state_impl/repos.py +++ b/src/hydrilla/proxy/state_impl/repos.py @@ -235,7 +235,7 @@ class ConcreteRepoRef(st.RepoRef): except sqlite3.IntegrityError: raise st.RepoNameTaken() - self.state.rebuild_structures() + self.state.rebuild_structures(rules=False) def refresh(self) -> None: with self.state.cursor(transaction=True) as cursor: diff --git a/src/hydrilla/proxy/state_impl/rules.py b/src/hydrilla/proxy/state_impl/rules.py new file mode 100644 index 0000000..bd9480d --- /dev/null +++ b/src/hydrilla/proxy/state_impl/rules.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +# Haketilo proxy data and configuration (RuleRef and RuleStore subtypes). +# +# This file is part of Hydrilla&Haketilo. +# +# Copyright (C) 2022 Wojtek Kosior +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# +# I, Wojtek Kosior, thereby promise not to sue for violation of this +# file's license. Although I request that you do not make use this code +# in a proprietary program, I am not going to enforce this in court. + +""" +This module provides an interface to interact with script allowing/blocking +rules configured inside Haketilo. +""" + +# Enable using with Python 3.7. +from __future__ import annotations + +import sqlite3 +import typing as t +import dataclasses as dc + +from ... import url_patterns +from .. import state as st +from . import base + + +def ensure_rule_not_deleted(cursor: sqlite3.Cursor, rule_id: str) -> None: + cursor.execute('SELECT COUNT(*) from rules where rule_id = ?;', (rule_id,)) + + (rule_present,), = cursor.fetchall() + + if not rule_present: + raise st.MissingItemError() + +def sanitize_rule_pattern(pattern: str) -> str: + pattern = pattern.strip() + + try: + assert pattern + return url_patterns.normalize_pattern(pattern) + except: + raise st.RulePatternInvalid() + + +@dc.dataclass(frozen=True, unsafe_hash=True) +class ConcreteRuleRef(st.RuleRef): + state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False) + + def remove(self) -> None: + with self.state.cursor(transaction=True) as cursor: + ensure_rule_not_deleted(cursor, self.id) + + cursor.execute('DELETE FROM rules WHERE rule_id = ?;', self.id) + + self.state.rebuild_structures(payloads=False) + + def update( + self, + *, + pattern: t.Optional[str] = None, + allow: t.Optional[bool] = None + ) -> None: + if pattern is not None: + pattern = sanitize_rule_pattern(pattern) + + if pattern is None and allow is None: + return + + with self.state.cursor(transaction=True) as cursor: + ensure_rule_not_deleted(cursor, self.id) + + if allow is not None: + cursor.execute( + 'UPDATE rules SET allow_scripts = ? WHERE rule_id = ?;', + (allow, self.id) + ) + + if pattern is not None: + cursor.execute( + 'DELETE FROM rules WHERE pattern = ? AND rule_id != ?;', + (pattern, self.id) + ) + + cursor.execute( + 'UPDATE rules SET pattern = ? WHERE rule_id = ?;', + (pattern, self.id) + ) + + self.state.rebuild_structures(payloads=False) + + def get_display_info(self) -> st.RuleDisplayInfo: + with self.state.cursor() as cursor: + cursor.execute( + 'SELECT pattern, allow_scripts FROM rules WHERE rule_id = ?;', + (self.id,) + ) + + rows = cursor.fetchall() + + if rows == []: + raise st.MissingItemError() + + (pattern, allow), = rows + + return st.RuleDisplayInfo(self, pattern, allow) + + +@dc.dataclass(frozen=True) +class ConcreteRuleStore(st.RuleStore): + state: base.HaketiloStateWithFields + + def get(self, id: str) -> st.RuleRef: + return ConcreteRuleRef(str(int(id)), self.state) + + def add(self, pattern: str, allow: bool) -> st.RuleRef: + pattern = sanitize_rule_pattern(pattern) + + with self.state.cursor(transaction=True) as cursor: + cursor.execute( + ''' + INSERT INTO rules(pattern, allow_scripts) + VALUES (?, ?) + ON CONFLICT (pattern) + DO UPDATE SET allow_scripts = excluded.allow_scripts; + ''', + (pattern, allow) + ) + + cursor.execute( + 'SELECT rule_id FROM rules WHERE pattern = ?;', + (pattern,) + ) + + (rule_id,), = cursor.fetchall() + + return ConcreteRuleRef(str(rule_id), self.state) + + def get_display_infos(self, allow: t.Optional[bool] = None) \ + -> t.Sequence[st.RuleDisplayInfo]: + with self.state.cursor() as cursor: + cursor.execute( + ''' + SELECT + rule_id, pattern, allow_scripts + FROM + rules + WHERE + COALESCE(allow_scripts = ?, TRUE) + ORDER BY + pattern; + ''', + (allow,) + ) + + rows = cursor.fetchall() + + result = [] + for rule_id, pattern, allow_scripts in rows: + ref = ConcreteRuleRef(str(rule_id), self.state) + + result.append(st.RuleDisplayInfo(ref, pattern, allow_scripts)) + + return result |