# SPDX-License-Identifier: GPL-3.0-or-later
# Haketilo proxy data and configuration (instantiatable HaketiloState subtype).
#
# This file is part of Hydrilla&Haketilo.
#
# Copyright (C) 2022 Wojtek Kosior
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
#
# I, Wojtek Kosior, hereby promise not to sue for violation of this
# file's license. Although I request that you do not make use of this code
# in a proprietary program, I am not going to enforce this in court.
"""
This module contains logic for keeping track of all settings, rules, mappings
and resources.
"""
# Enable using with Python 3.7.
from __future__ import annotations
import secrets
import io
import hashlib
import typing as t
import dataclasses as dc
from pathlib import Path
import sqlite3
from ...exceptions import HaketiloException
from ...translations import smart_gettext as _
from ... import pattern_tree
from ... import url_patterns
from ... import versions
from ... import item_infos
from ..simple_dependency_satisfying import compute_payloads, ComputedPayload
from .. import state as st
from .. import policies
from . import base
# Absolute directory of this module. Its parent directory is where the
# database schema file 'tables.sql' is looked up.
here = Path(__file__).resolve().parents[0]
@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
class ConcreteRepoRef(st.RepoRef):
    """
    Repository reference used by ConcreteHaketiloState.

    None of its operations are implemented yet; each one raises
    NotImplementedError.
    """
    def remove(self, state: st.HaketiloState) -> None:
        """Remove the repository (not implemented yet)."""
        raise NotImplementedError

    def update(
            self,
            state: st.HaketiloState,
            *,
            name: t.Optional[str] = None,
            url:  t.Optional[str] = None
    ) -> ConcreteRepoRef:
        """Change the repository's name and/or URL (not implemented yet)."""
        raise NotImplementedError

    def refresh(self, state: st.HaketiloState) -> ConcreteRepoIterationRef:
        """Fetch a new iteration of the repository (not implemented yet)."""
        raise NotImplementedError
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteRepoIterationRef(st.RepoIterationRef):
    """Repository iteration reference used by ConcreteHaketiloState."""
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteMappingRef(st.MappingRef):
    """
    Mapping reference used by ConcreteHaketiloState.

    Its operations are not implemented yet and raise NotImplementedError.
    """
    def disable(self, state: st.HaketiloState) -> None:
        """Disable the mapping (not implemented yet)."""
        raise NotImplementedError

    def forget_enabled(self, state: st.HaketiloState) -> None:
        """Forget the mapping's enabled status (not implemented yet)."""
        raise NotImplementedError
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteMappingVersionRef(st.MappingVersionRef):
    """Mapping version reference used by ConcreteHaketiloState."""
    def enable(self, state: st.HaketiloState) -> None:
        """Enable this mapping version (not implemented yet)."""
        raise NotImplementedError
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteResourceRef(st.ResourceRef):
    """Resource reference used by ConcreteHaketiloState."""
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteResourceVersionRef(st.ResourceVersionRef):
    """Resource version reference used by ConcreteHaketiloState."""
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcretePayloadRef(st.PayloadRef):
    """
    Reference to a payload stored in the SQLite database of a
    ConcreteHaketiloState. Its 'id' is the stringified 'payload_id'.
    """
    def get_data(self, state: st.HaketiloState) -> st.PayloadData:
        """
        Return the in-memory PayloadData that the state keeps for this ref.

        Looks the ref up in the state's 'payloads_data' mapping (presumably
        raising KeyError for an unknown payload).
        """
        return t.cast(ConcreteHaketiloState, state).payloads_data[self]

    def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
        """Return the mapping version of this payload (not implemented yet)."""
        # Raise instead of returning a placeholder string that would pose as
        # a MappingVersionRef.
        raise NotImplementedError()

    def get_script_paths(self, state: st.HaketiloState) \
        -> t.Iterable[t.Sequence[str]]:
        """
        Return paths of all script files of this payload, in order.

        Each path is the resource identifier followed by the '/'-separated
        segments of the script file's name. Resources are ordered by their
        position among the payload's resolved dependencies and scripts by
        their position within the resource. An empty list is returned when
        the payload has no scripts.

        Raises MissingItemError if the payload does not exist.
        """
        with t.cast(ConcreteHaketiloState, state).cursor() as cursor:
            # Check payload existence separately so that "no such payload"
            # can be told apart from "payload without scripts".
            cursor.execute(
                '''
SELECT
COUNT(*)
FROM
payloads
WHERE
payload_id = ?;
''',
                (self.id,)
            )

            (payload_count,), = cursor.fetchall()
            if payload_count == 0:
                raise st.MissingItemError()

            # Inner joins are enough here: resources without script ('W')
            # file uses simply contribute no rows.
            cursor.execute(
                '''
SELECT
i.identifier, fu.name
FROM
resolved_depended_resources AS rdd
JOIN item_versions AS iv
ON rdd.resource_item_id = iv.item_version_id
JOIN items AS i
USING (item_id)
JOIN file_uses AS fu
USING (item_version_id)
WHERE
rdd.payload_id = ? AND
fu.type = 'W'
ORDER BY
rdd.idx, fu.idx;
''',
                (self.id,)
            )

            rows = cursor.fetchall()

        return [
            (resource_identifier, *file_name.split('/'))
            for resource_identifier, file_name in rows
        ]

    def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
        -> t.Optional[st.FileData]:
        """
        Return the contents and MIME type of one script file of this payload.

        'path' is the resource identifier followed by the segments of the
        file's name, as produced by get_script_paths(). Returns None when
        this payload has no such script ('W' type file use).

        Raises MissingItemError if 'path' is empty.
        """
        if len(path) == 0:
            raise st.MissingItemError()

        resource_identifier, *file_name_segments = path

        file_name = '/'.join(file_name_segments)

        with t.cast(ConcreteHaketiloState, state).cursor() as cursor:
            cursor.execute(
                '''
SELECT
f.data, fu.mime_type
FROM
payloads AS p
JOIN resolved_depended_resources AS rdd
USING (payload_id)
JOIN item_versions AS iv
ON rdd.resource_item_id = iv.item_version_id
JOIN items AS i
USING (item_id)
JOIN file_uses AS fu
USING (item_version_id)
JOIN files AS f
USING (file_id)
WHERE
p.payload_id = ? AND
i.identifier = ? AND
fu.name = ? AND
fu.type = 'W';
''',
                (self.id, resource_identifier, file_name)
            )

            result = cursor.fetchall()

        if result == []:
            return None

        (data, mime_type), = result

        return st.FileData(type=mime_type, name=file_name, contents=data)
# @dc.dataclass(frozen=True, unsafe_hash=True)
# class ConcretePayloadRef(st.PayloadRef):
# computed_payload: ComputedPayload = dc.field(hash=False, compare=False)
# def get_data(self, state: st.HaketiloState) -> st.PayloadData:
# return t.cast(ConcreteHaketiloState, state).payloads_data[self.id]
# def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
# return 'to implement'
# def get_script_paths(self, state: st.HaketiloState) \
# -> t.Iterator[t.Sequence[str]]:
# for resource_info in self.computed_payload.resources:
# for file_spec in resource_info.scripts:
# yield (resource_info.identifier, *file_spec.name.split('/'))
# def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
# -> t.Optional[st.FileData]:
# if len(path) == 0:
# raise st.MissingItemError()
# resource_identifier, *file_name_segments = path
# file_name = '/'.join(file_name_segments)
# script_sha256 = ''
# matched_resource_info = False
# for resource_info in self.computed_payload.resources:
# if resource_info.identifier == resource_identifier:
# matched_resource_info = True
# for script_spec in resource_info.scripts:
# if script_spec.name == file_name:
# script_sha256 = script_spec.sha256
# break
# if not matched_resource_info:
# raise st.MissingItemError(resource_identifier)
# if script_sha256 == '':
# return None
# store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir
# files_dir_path = store_dir_path / 'temporary_malcontent' / 'file'
# file_path = files_dir_path / 'sha256' / script_sha256
# return st.FileData(
# type = 'application/javascript',
# name = file_name,
# contents = file_path.read_bytes()
# )
def register_payload(
        policy_tree: base.PolicyTree,
        pattern:     url_patterns.ParsedPattern,
        payload_key: st.PayloadKey,
        token:       str
) -> base.PolicyTree:
    """
    Register the policy factories of one payload and return the new tree.

    Two factories are registered, both with builtin=False and the given
    payload key:
    * a PayloadPolicyFactory under 'pattern' itself, and
    * a PayloadResourcePolicyFactory under 'pattern' extended with the path
      segments 'token' and '***'.

    The tree is used in a functional style: the value returned by each
    `register()` call is kept, and the last one is returned.
    """
    payload_factory = policies.PayloadPolicyFactory(
        builtin     = False,
        payload_key = payload_key
    )
    tree = policy_tree.register(pattern, payload_factory)

    resource_factory = policies.PayloadResourcePolicyFactory(
        builtin     = False,
        payload_key = payload_key
    )
    resources_pattern = pattern.path_append(token, '***')

    return tree.register(resources_pattern, resource_factory)
# Presumably maps payload ids to their data; it is not referenced in the
# visible part of this module (TODO confirm whether it is still needed).
DataById = t.Mapping[str, st.PayloadData]

# Either of the two item info classes handled by this module. It lets
# functions such as read_items() and get_infos_of_type() be typed as
# returning instances of the same class they were given.
AnyInfoVar = t.TypeVar(
    'AnyInfoVar', item_infos.ResourceInfo, item_infos.MappingInfo
)
def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
    -> t.Iterator[tuple[AnyInfoVar, str]]:
    """
    Yield (info, definition) pairs for all items of 'item_class' on disk.

    The expected layout is
    '<malcontent_path>/<item_class.type_name>/<identifier>/<version>', where
    each version entry is a file holding the item's textual definition, which
    is parsed with 'item_class.load()'. Nothing is yielded when the type
    directory does not exist. Non-directory entries directly inside the type
    directory are skipped.

    The definition's identifier and version are asserted to match the
    directory and file names they were read from.
    """
    item_type_path = malcontent_path / item_class.type_name
    if not item_type_path.is_dir():
        return

    for item_path in item_type_path.iterdir():
        if not item_path.is_dir():
            continue

        for item_version_path in item_path.iterdir():
            # Decode explicitly instead of relying on the locale's preferred
            # encoding, which differs between platforms.
            definition = item_version_path.read_text(encoding='utf-8')
            item_info = item_class.load(io.StringIO(definition))

            # Sanity checks that the on-disk layout matches the definition.
            assert item_info.identifier == item_path.name
            assert versions.version_string(item_info.version) == \
                item_version_path.name

            yield item_info, definition
def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int:
    """
    Return the id of the last iteration of repository 'repo_name'.

    The repository and the iteration row are inserted first when missing. A
    repository created here gets an empty URL, is marked as deleted and has
    'next_iteration' set to 2, so its last iteration is number 1.
    """
    # Make sure the repository row exists (INSERT OR IGNORE).
    cursor.execute(
        '''
INSERT OR IGNORE INTO repos(name, url, deleted, next_iteration)
VALUES(?, '', TRUE, 2);
''',
        (repo_name,)
    )

    cursor.execute(
        '''
SELECT
repo_id, next_iteration - 1
FROM
repos
WHERE
name = ?;
''',
        (repo_name,)
    )
    [(repo_id, iteration)] = cursor.fetchall()

    iteration_key = (repo_id, iteration)

    cursor.execute(
        '''
INSERT OR IGNORE INTO repo_iterations(repo_id, iteration)
VALUES(?, ?);
''',
        iteration_key
    )

    cursor.execute(
        '''
SELECT
repo_iteration_id
FROM
repo_iterations
WHERE
repo_id = ? AND iteration = ?;
''',
        iteration_key
    )
    [(repo_iteration_id,)] = cursor.fetchall()

    return repo_iteration_id
def get_or_make_item(cursor: sqlite3.Cursor, type: str, identifier: str) -> int:
    """
    Return the id of the item with the given type and identifier, inserting
    it first if needed.

    'type' is 'resource' or 'mapping' and is stored as 'R' or 'M'
    respectively. Any other value raises KeyError.
    """
    letters = {'resource': 'R', 'mapping': 'M'}
    key = (letters[type], identifier)

    cursor.execute(
        '''
INSERT OR IGNORE INTO items(type, identifier)
VALUES(?, ?);
''',
        key
    )

    cursor.execute(
        '''
SELECT
item_id
FROM
items
WHERE
type = ? AND identifier = ?;
''',
        key
    )
    [(item_id,)] = cursor.fetchall()

    return item_id
def get_or_make_item_version(
        cursor:            sqlite3.Cursor,
        item_id:           int,
        repo_iteration_id: int,
        version:           versions.VerTuple,
        definition:        str
) -> int:
    """
    Return the id of the 'item_versions' row for the given item, version and
    repository iteration, inserting it together with 'definition' first.

    'version' is stored as the string produced by versions.version_string().
    The insertion uses INSERT OR IGNORE, so an already present row is
    presumably kept as is (depends on the table's uniqueness constraints).
    """
    version_str = versions.version_string(version)
    lookup = (item_id, version_str, repo_iteration_id)

    cursor.execute(
        '''
INSERT OR IGNORE INTO item_versions(
item_id,
version,
repo_iteration_id,
definition
)
VALUES(?, ?, ?, ?);
''',
        (*lookup, definition)
    )

    cursor.execute(
        '''
SELECT
item_version_id
FROM
item_versions
WHERE
item_id = ? AND version = ? AND repo_iteration_id = ?;
''',
        lookup
    )
    [(item_version_id,)] = cursor.fetchall()

    return item_version_id
def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None:
    """
    Insert a 'mapping_statuses' row for the item with 'enabled' set to 'N'.

    Uses INSERT OR IGNORE. Presumably 'item_id' is unique in that table, so
    an existing status is left untouched.
    """
    statement = '''
INSERT OR IGNORE INTO mapping_statuses(item_id, enabled)
VALUES(?, 'N');
'''
    cursor.execute(statement, (item_id,))
def get_or_make_file(cursor: sqlite3.Cursor, sha256: str, file_bytes: bytes) \
    -> int:
    """
    Return the id of the stored file with the given SHA-256 hex digest,
    first storing 'file_bytes' under that digest if it is not yet present.

    The digest is not verified against 'file_bytes' here; callers are
    responsible for that.
    """
    digest_param = (sha256,)

    cursor.execute(
        '''
INSERT OR IGNORE INTO files(sha256, data)
VALUES(?, ?)
''',
        digest_param + (file_bytes,)
    )

    cursor.execute(
        '''
SELECT
file_id
FROM
files
WHERE
sha256 = ?;
''',
        digest_param
    )
    [(file_id,)] = cursor.fetchall()

    return file_id
def make_file_use(
        cursor:          sqlite3.Cursor,
        item_version_id: int,
        file_id:         int,
        name:            str,
        type:            str,
        mime_type:       str,
        idx:             int
) -> None:
    """
    Record that an item version uses a stored file.

    'type' is the kind of use as a one-letter code (this module uses 'L' for
    source/copyright files and 'W' for scripts) and 'idx' is the file's
    position within its group. Uses INSERT OR IGNORE.
    """
    row = (item_version_id, file_id, name, type, mime_type, idx)

    cursor.execute(
        '''
INSERT OR IGNORE INTO file_uses(
item_version_id,
file_id,
name,
type,
mime_type,
idx
)
VALUES(?, ?, ?, ?, ?, ?);
''',
        row
    )
def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \
    -> t.Mapping[AnyInfoVar, int]:
    """
    Load every stored version of the items of the given info type.

    Each definition is parsed with 'info_type.load()', passing the name and
    iteration of the repository it came from. Items are selected by the
    first letter of 'info_type.type_name', upper-cased ('R'/'M').

    Returns a mapping from each loaded info object to the id of its row in
    'item_versions'. The version id (rather than the item id) is returned
    because these ids get stored as 'payloads.mapping_item_id' and
    'resolved_depended_resources.resource_item_id', and this module joins
    those columns against 'item_versions.item_version_id'.
    """
    cursor.execute(
        '''
SELECT
iv.item_version_id, iv.definition, r.name, ri.iteration
FROM
item_versions AS iv
JOIN items AS i USING (item_id)
JOIN repo_iterations AS ri USING (repo_iteration_id)
JOIN repos AS r USING (repo_id)
WHERE
i.type = ?;
''',
        (info_type.type_name[0].upper(),)
    )

    result: dict[AnyInfoVar, int] = {}

    for item_version_id, definition, repo_name, repo_iteration \
        in cursor.fetchall():
        definition_io = io.StringIO(definition)
        info = info_type.load(definition_io, repo_name, repo_iteration)
        result[info] = item_version_id

    return result
@dc.dataclass
class ConcreteHaketiloState(base.HaketiloStateWithFields):
    """
    Instantiatable HaketiloState that keeps its data in an SQLite database.

    Construction prepares the database schema, imports the items found in
    the store's 'temporary_malcontent' directory, and computes payloads
    together with the in-memory policy tree used by select_policy().
    """
    def __post_init__(self) -> None:
        """Initialize the database and build the in-memory structures."""
        self._prepare_database()

        self._populate_database_with_stuff_from_temporary_malcontent_dir()

        with self.cursor(transaction=True) as cursor:
            self.recompute_payloads(cursor)

    def _prepare_database(self) -> None:
        """
        Make the database ready for use.

        If the 'general' table is absent, the schema from 'tables.sql'
        (located in the parent of this module's directory) is executed.
        Otherwise the stored Haketilo version is checked to be '3.0b1'.
        Finally, foreign key support is checked and enforcement turned on.

        Raises HaketiloException if the schema version is unknown or if the
        SQLite library does not support foreign keys.
        """
        cursor = self.connection.cursor()

        try:
            cursor.execute(
                '''
SELECT
COUNT(name)
FROM
sqlite_master
WHERE
name = 'general' AND type = 'table';
'''
            )

            (db_initialized,), = cursor.fetchall()

            if not db_initialized:
                schema_path = here.parent / 'tables.sql'
                cursor.executescript(schema_path.read_text(encoding='utf-8'))
            else:
                cursor.execute(
                    '''
SELECT
haketilo_version
FROM
general;
'''
                )

                (db_haketilo_version,) = cursor.fetchone()
                if db_haketilo_version != '3.0b1':
                    raise HaketiloException(_('err.unknown_db_schema'))

            # SQLite builds without foreign key support return no row at all
            # for this pragma (instead of 0 or 1).
            cursor.execute('PRAGMA FOREIGN_KEYS;')
            if cursor.fetchall() == []:
                raise HaketiloException(_('err.proxy.no_sqlite_foreign_keys'))

            cursor.execute('PRAGMA FOREIGN_KEYS=ON;')
        finally:
            cursor.close()

    def _populate_database_with_stuff_from_temporary_malcontent_dir(self) \
        -> None:
        """
        Import resources and mappings from the store's 'temporary_malcontent'
        directory into the database, in a single transaction.

        For every item version this records its repository iteration, item,
        version and (for mappings) mapping status, stores the referenced
        files read from 'file/sha256/<digest>', and records the file uses:
        source/copyright files as type 'L' and resource scripts as type 'W'.
        The helpers used here all insert with INSERT OR IGNORE.
        """
        malcontent_dir_path = self.store_dir / 'temporary_malcontent'
        files_by_sha256_path = malcontent_dir_path / 'file' / 'sha256'

        with self.cursor(transaction=True) as cursor:
            for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]:
                info: item_infos.AnyInfo
                for info, definition in read_items(
                        malcontent_dir_path,
                        info_type # type: ignore
                ):
                    repo_iteration_id = get_or_make_repo_iteration(
                        cursor,
                        info.repo
                    )

                    item_id = get_or_make_item(
                        cursor,
                        info.type_name,
                        info.identifier
                    )

                    item_version_id = get_or_make_item_version(
                        cursor,
                        item_id,
                        repo_iteration_id,
                        info.version,
                        definition
                    )

                    if info_type is item_infos.MappingInfo:
                        make_mapping_status(cursor, item_id)

                    # Store each referenced file once; remember its id and
                    # contents by SHA-256 digest.
                    file_ids_bytes = {}

                    file_specifiers = [*info.source_copyright]
                    if isinstance(info, item_infos.ResourceInfo):
                        file_specifiers.extend(info.scripts)

                    for file_spec in file_specifiers:
                        file_path = files_by_sha256_path / file_spec.sha256
                        file_bytes = file_path.read_bytes()

                        sha256 = hashlib.sha256(file_bytes).hexdigest()
                        # Integrity check of the file read from the store.
                        assert sha256 == file_spec.sha256

                        file_id = get_or_make_file(cursor, sha256, file_bytes)

                        file_ids_bytes[sha256] = (file_id, file_bytes)

                    for idx, file_spec in enumerate(info.source_copyright):
                        file_id, file_bytes = file_ids_bytes[file_spec.sha256]
                        if file_bytes.isascii():
                            mime = 'text/plain'
                        else:
                            mime = 'application/octet-stream'

                        make_file_use(
                            cursor,
                            item_version_id = item_version_id,
                            file_id         = file_id,
                            name            = file_spec.name,
                            type            = 'L',
                            mime_type       = mime,
                            idx             = idx
                        )

                    if isinstance(info, item_infos.MappingInfo):
                        continue

                    for idx, file_spec in enumerate(info.scripts):
                        # Index the tuple rather than unpacking into '_':
                        # assigning '_' would make the module's gettext alias
                        # a local name throughout this whole method.
                        file_id = file_ids_bytes[file_spec.sha256][0]
                        make_file_use(
                            cursor,
                            item_version_id = item_version_id,
                            file_id         = file_id,
                            name            = file_spec.name,
                            type            = 'W',
                            mime_type       = 'application/javascript',
                            idx             = idx
                        )

    def recompute_payloads(self, cursor: sqlite3.Cursor) -> None:
        """
        Recompute all payloads from the items stored in the database.

        The contents of the 'payloads' and 'resolved_depended_resources'
        tables are replaced with the result of compute_payloads(), after
        which the in-memory policy tree and payload data are rebuilt.

        Must be called inside a transaction (asserted).
        """
        assert self.connection.in_transaction

        resources = get_infos_of_type(cursor, item_infos.ResourceInfo)
        mappings  = get_infos_of_type(cursor, item_infos.MappingInfo)

        payloads = compute_payloads(resources.keys(), mappings.keys())

        # Clear rows referencing payloads before the payloads themselves so
        # that this works whether or not the schema cascades the deletion.
        cursor.execute('DELETE FROM resolved_depended_resources;')
        cursor.execute('DELETE FROM payloads;')

        for mapping_info, by_pattern in payloads.items():
            mapping_item_id = mappings[mapping_info]

            for pattern, payload in by_pattern.items():
                cursor.execute(
                    '''
INSERT INTO payloads(
mapping_item_id,
pattern,
eval_allowed,
cors_bypass_allowed
)
VALUES (?, ?, ?, ?);
''',
                    (
                        mapping_item_id,
                        pattern.orig_url,
                        payload.allows_eval,
                        payload.allows_cors_bypass
                    )
                )

                cursor.execute(
                    '''
SELECT
payload_id
FROM
payloads
WHERE
mapping_item_id = ? AND pattern = ?;
''',
                    (mapping_item_id, pattern.orig_url)
                )

                (payload_id_int,), = cursor.fetchall()

                # 'idx' preserves the order of the resolved resources.
                for res_num, resource_info in enumerate(payload.resources):
                    cursor.execute(
                        '''
INSERT INTO resolved_depended_resources(
payload_id,
resource_item_id,
idx
)
VALUES(?, ?, ?);
''',
                        (payload_id_int, resources[resource_info], res_num)
                    )

        self._rebuild_structures(cursor)

    def _rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
        """
        Rebuild the policy tree and payload data from the 'payloads' table.

        The results replace 'self.policy_tree' and 'self.payloads_data'.
        A payload whose ref was already known keeps its previous unique
        token; other payloads get a fresh random one.
        """
        cursor.execute(
            '''
SELECT
p.payload_id, p.pattern, p.eval_allowed,
p.cors_bypass_allowed,
ms.enabled,
i.identifier
FROM
payloads AS p
JOIN item_versions AS iv
ON p.mapping_item_id = iv.item_version_id
JOIN items AS i USING (item_id)
JOIN mapping_statuses AS ms USING (item_id);
'''
        )

        new_policy_tree = base.PolicyTree()

        new_payloads_data: dict[st.PayloadRef, st.PayloadData] = {}

        for row in cursor.fetchall():
            (payload_id_int, pattern, eval_allowed, cors_bypass_allowed,
             enabled_status,
             identifier) = row

            payload_ref = ConcretePayloadRef(str(payload_id_int))

            previous_data = self.payloads_data.get(payload_ref)
            if previous_data is not None:
                token = previous_data.unique_token
            else:
                token = secrets.token_urlsafe(8)

            payload_key = st.PayloadKey(payload_ref, identifier)

            for parsed_pattern in url_patterns.parse_pattern(pattern):
                new_policy_tree = register_payload(
                    new_policy_tree,
                    parsed_pattern,
                    payload_key,
                    token
                )

            # NOTE(review): only the path segments of the last parsed variant
            # of 'pattern' are kept; this assumes all variants share the same
            # path. Confirm.
            pattern_path_segments = parsed_pattern.path_segments

            # NOTE(review): 'explicitly_enabled' is currently hard-coded to
            # True and 'enabled_status' is unused; the commented-out line
            # shows the apparently intended expression.
            payload_data = st.PayloadData(
                payload_ref           = payload_ref,
                #explicitly_enabled   = enabled_status == 'E',
                explicitly_enabled    = True,
                unique_token          = token,
                pattern_path_segments = pattern_path_segments,
                eval_allowed          = eval_allowed,
                cors_bypass_allowed   = cors_bypass_allowed
            )

            new_payloads_data[payload_ref] = payload_data

        self.policy_tree   = new_policy_tree
        self.payloads_data = new_payloads_data

    def get_repo(self, repo_id: str) -> st.RepoRef:
        """Return a reference to the repository with the given id."""
        return ConcreteRepoRef(repo_id)

    def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef:
        """Return a reference to the repository iteration with the given id."""
        return ConcreteRepoIterationRef(repo_iteration_id)

    def get_mapping(self, mapping_id: str) -> st.MappingRef:
        """Return a reference to the mapping with the given id."""
        return ConcreteMappingRef(mapping_id)

    def get_mapping_version(self, mapping_version_id: str) \
        -> st.MappingVersionRef:
        """Return a reference to the mapping version with the given id."""
        return ConcreteMappingVersionRef(mapping_version_id)

    def get_resource(self, resource_id: str) -> st.ResourceRef:
        """Return a reference to the resource with the given id."""
        return ConcreteResourceRef(resource_id)

    def get_resource_version(self, resource_version_id: str) \
        -> st.ResourceVersionRef:
        """Return a reference to the resource version with the given id."""
        return ConcreteResourceVersionRef(resource_version_id)

    def get_payload(self, payload_id: str) -> st.PayloadRef:
        """
        Return a reference to the payload with the given id.

        Like the other get_*() methods, this does not check that the item
        exists. The previous placeholder returned a plain string instead of
        a PayloadRef.
        """
        return ConcretePayloadRef(payload_id)

    def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \
        -> st.RepoRef:
        """Add a new repository (not implemented yet)."""
        raise NotImplementedError()

    def get_settings(self) -> st.HaketiloGlobalSettings:
        """Return global settings (currently hard-coded values)."""
        return st.HaketiloGlobalSettings(
            mapping_use_mode      = st.MappingUseMode.AUTO,
            default_allow_scripts = True,
            repo_refresh_seconds  = 0
        )

    def update_settings(
            self,
            *,
            mapping_use_mode:      t.Optional[st.MappingUseMode] = None,
            default_allow_scripts: t.Optional[bool]              = None,
            repo_refresh_seconds:  t.Optional[int]               = None
    ) -> None:
        """Change global settings (not implemented yet)."""
        raise NotImplementedError()

    def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy:
        """
        Choose the policy to apply to 'url'.

        A policy is created from every factory registered in the policy tree
        for patterns matching 'url'. The first one found with the highest
        priority (which must be greater than 0) wins. If anything fails while
        doing so, an ErrorBlockPolicy carrying the exception is returned.
        When no policy qualifies, the fallback allow or block policy is
        returned depending on the 'default_allow_scripts' setting.
        """
        # Take a consistent snapshot of the tree; it is replaced as a whole
        # when structures are rebuilt.
        with self.lock:
            policy_tree = self.policy_tree

        try:
            best_priority: int = 0
            best_policy: t.Optional[policies.Policy] = None

            for factories_set in policy_tree.search(url):
                for stored_factory in sorted(factories_set):
                    factory = stored_factory.item

                    policy = factory.make_policy(self)

                    if policy.priority > best_priority:
                        best_priority = policy.priority
                        best_policy = policy
        except Exception as e:
            return policies.ErrorBlockPolicy(
                builtin = True,
                error   = e
            )

        if best_policy is not None:
            return best_policy

        if self.get_settings().default_allow_scripts:
            return policies.FallbackAllowPolicy()
        else:
            return policies.FallbackBlockPolicy()

    @staticmethod
    def make(store_dir: Path) -> 'ConcreteHaketiloState':
        """
        Create a state whose database is 'sqlite3.db' inside 'store_dir'.

        The connection is opened in autocommit mode (isolation_level=None),
        leaving transaction handling to this class. With
        check_same_thread=False it may be used from threads other than the
        creating one.
        """
        connection = sqlite3.connect(
            str(store_dir / 'sqlite3.db'),
            isolation_level   = None,
            check_same_thread = False
        )

        return ConcreteHaketiloState(
            store_dir  = store_dir,
            connection = connection
        )