aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/state_impl/rules.py
blob: 0cdcf2c949e33255610f88a33aa11293f7562240 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# SPDX-License-Identifier: GPL-3.0-or-later

# Haketilo proxy data and configuration (RuleRef and RuleStore subtypes).
#
# This file is part of Hydrilla&Haketilo.
#
# Copyright (C) 2022 Wojtek Kosior
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
#
# I, Wojtek Kosior, thereby promise not to sue for violation of this
# file's license. Although I request that you do not make use of this
# code in a proprietary program, I am not going to enforce this in
# court.

"""
This module provides an interface to interact with script allowing/blocking
rules configured inside Haketilo.
"""

# Enable using with Python 3.7.
from __future__ import annotations

import sqlite3
import typing as t
import dataclasses as dc

from ... import url_patterns
from .. import state as st
from . import base


def ensure_rule_not_deleted(cursor: sqlite3.Cursor, rule_id: str) -> None:
    """
    Verify that the rule with the given id still exists in the database.

    Counts the rows of the ``rules`` table whose rule_id matches and raises
    st.MissingItemError when there are none.
    """
    cursor.execute('SELECT COUNT(*) from rules where rule_id = ?;', (rule_id,))

    # COUNT(*) always yields exactly one row with one column.
    [(matching_rows,)] = cursor.fetchall()

    if matching_rows == 0:
        raise st.MissingItemError()

def sanitize_rule_pattern(pattern: str) -> str:
    """
    Strip surrounding whitespace from a rule pattern and normalize it.

    The stripped pattern is passed to url_patterns.normalize_pattern() and
    its result is returned.

    Raises st.RulePatternInvalid if the pattern is empty after stripping or
    if url_patterns.normalize_pattern() raises an exception (which is
    chained as the cause).
    """
    pattern = pattern.strip()

    # An explicit check instead of ``assert`` - assertions are removed when
    # Python runs with -O, which would let empty patterns through.
    if not pattern:
        raise st.RulePatternInvalid()

    try:
        return url_patterns.normalize_pattern(pattern)
    except Exception as ex:
        # Catch Exception rather than using a bare ``except:`` so that
        # KeyboardInterrupt/SystemExit are not converted into a pattern error.
        raise st.RulePatternInvalid() from ex


@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteRuleRef(st.RuleRef):
    """
    Reference to a single script allowing/blocking rule stored in the
    ``rules`` table of Haketilo's sqlite database.
    """
    # State object used to obtain database cursors and to rebuild structures
    # after a modification; excluded from hashing and equality comparisons.
    state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False)

    def remove(self) -> None:
        """
        Delete the referenced rule from the database.

        Raises st.MissingItemError if the rule no longer exists.
        """
        with self.state.cursor(transaction=True) as cursor:
            ensure_rule_not_deleted(cursor, self.id)

            # The parameters must be a sequence of bindings. Passing the bare
            # id string would bind each of its characters separately and fail
            # for any multi-digit id.
            cursor.execute('DELETE FROM rules WHERE rule_id = ?;', (self.id,))

            self.state.rebuild_structures(payloads=False)

    def update(
            self,
            *,
            pattern: t.Optional[str]  = None,
            allow:   t.Optional[bool] = None
    ) -> None:
        """
        Modify the referenced rule.

        :param pattern: new URL pattern, sanitized with
            sanitize_rule_pattern(); the pattern is left unchanged when None
        :param allow: new value of the rule's allow_scripts column; left
            unchanged when None

        If another rule already uses the new pattern, that other rule is
        deleted so the referenced rule can take over the pattern. When both
        arguments are None, nothing is done.

        Raises st.RulePatternInvalid if the new pattern is invalid and
        st.MissingItemError if the rule no longer exists.
        """
        # Validate before opening a transaction so that invalid input never
        # touches the database.
        if pattern is not None:
            pattern = sanitize_rule_pattern(pattern)

        if pattern is None and allow is None:
            return

        with self.state.cursor(transaction=True) as cursor:
            ensure_rule_not_deleted(cursor, self.id)

            if allow is not None:
                cursor.execute(
                    'UPDATE rules SET allow_scripts = ? WHERE rule_id = ?;',
                    (allow, self.id)
                )

            if pattern is not None:
                # Remove any other rule holding the target pattern first.
                cursor.execute(
                    'DELETE FROM rules WHERE pattern = ? AND rule_id != ?;',
                    (pattern, self.id)
                )

                cursor.execute(
                    'UPDATE rules SET pattern = ? WHERE rule_id = ?;',
                    (pattern, self.id)
                )

            self.state.rebuild_structures(payloads=False)

    def get_display_info(self) -> st.RuleDisplayInfo:
        """
        Return the pattern and allow_scripts value of the referenced rule,
        wrapped in st.RuleDisplayInfo.

        Raises st.MissingItemError if the rule does not exist.
        """
        with self.state.cursor() as cursor:
            cursor.execute(
                'SELECT pattern, allow_scripts FROM rules WHERE rule_id = ?;',
                (self.id,)
            )

            rows = cursor.fetchall()

        if not rows:
            raise st.MissingItemError()

        (pattern, allow), = rows

        return st.RuleDisplayInfo(self, pattern, allow)


@dc.dataclass(frozen=True)
class ConcreteRuleStore(st.RuleStore):
    """
    Store giving access to the script allowing/blocking rules kept in the
    ``rules`` table of Haketilo's sqlite database.
    """
    # State object used to obtain database cursors.
    state: base.HaketiloStateWithFields

    def get(self, id: str) -> st.RuleRef:
        """
        Return a reference to the rule with the given id.

        The id is round-tripped through int() to give it a canonical textual
        form, so a non-numeric id raises ValueError. No database lookup is
        performed here.
        """
        canonical_id = str(int(id))
        return ConcreteRuleRef(canonical_id, self.state)

    def add(self, pattern: str, allow: bool) -> st.RuleRef:
        """
        Create a rule for the given pattern and return a reference to it.

        The pattern is first sanitized with sanitize_rule_pattern(), which may
        raise st.RulePatternInvalid. If a rule with the same pattern already
        exists, only its allow_scripts value is overwritten and a reference
        to that existing rule is returned.
        """
        normalized = sanitize_rule_pattern(pattern)

        with self.state.cursor(transaction=True) as cursor:
            cursor.execute(
                '''
                INSERT INTO rules(pattern, allow_scripts)
                VALUES (?, ?)
                ON CONFLICT (pattern)
                DO UPDATE SET allow_scripts = excluded.allow_scripts;
                ''',
                (normalized, allow)
            )

            cursor.execute(
                'SELECT rule_id FROM rules WHERE pattern = ?;',
                (normalized,)
            )

            # Patterns are unique (see the ON CONFLICT clause), so exactly
            # one row is expected here.
            [(new_rule_id,)] = cursor.fetchall()

            return ConcreteRuleRef(str(new_rule_id), self.state)

    def get_display_infos(
            self,
            allow: t.Optional[bool] = None
    ) -> t.Sequence[st.RuleDisplayInfo]:
        """
        Return display information for stored rules, ordered by pattern.

        :param allow: when None, all rules are returned; otherwise only the
            rules whose allow_scripts value equals it
        """
        with self.state.cursor() as cursor:
            # With a NULL parameter the comparison yields NULL and COALESCE
            # turns it into TRUE, which disables the filter.
            cursor.execute(
                '''
                SELECT
                        rule_id, pattern, allow_scripts
                FROM
                        rules
                WHERE
                        COALESCE(allow_scripts = ?, TRUE)
                ORDER BY
                        pattern;
                ''',
                (allow,)
            )

            fetched = cursor.fetchall()

        return [
            st.RuleDisplayInfo(
                ConcreteRuleRef(str(row_id), self.state),
                row_pattern,
                row_allow
            )
            for row_id, row_pattern, row_allow in fetched
        ]