# SPDX-License-Identifier: GPL-3.0-or-later
# Haketilo proxy data and configuration (instantiatable HaketiloState subtype).
#
# This file is part of Hydrilla&Haketilo.
#
# Copyright (C) 2022 Wojtek Kosior
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
#
# I, Wojtek Kosior, thereby promise not to sue for violation of this
# file's license. Although I request that you do not make use of this code
# in a proprietary program, I am not going to enforce this in court.
"""
This module contains logic for keeping track of all settings, rules, mappings
and resources.
"""
# Enable using with Python 3.7.
from __future__ import annotations
import secrets
import io
import hashlib
import typing as t
import dataclasses as dc
from pathlib import Path
import sqlite3
from ...exceptions import HaketiloException
from ...translations import smart_gettext as _
from ... import pattern_tree
from ... import url_patterns
from ... import versions
from ... import item_infos
from ..simple_dependency_satisfying import compute_payloads, ComputedPayload
from .. import state as st
from .. import policies
from . import base
# Directory that contains this module. Its parent directory is where the
# 'tables.sql' schema script is looked up when the database gets initialized.
here = Path(__file__).resolve().parent
@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
class ConcreteRepoRef(st.RepoRef):
    """
    Repository reference handed out by ConcreteHaketiloState.

    All operations are placeholders for now - each of them raises
    NotImplementedError.
    """
    def remove(self, state: st.HaketiloState) -> None:
        """Placeholder; always raises NotImplementedError."""
        raise NotImplementedError

    def update(
            self,
            state: st.HaketiloState,
            *,
            name: t.Optional[str] = None,
            url:  t.Optional[str] = None
    ) -> ConcreteRepoRef:
        """Placeholder; always raises NotImplementedError."""
        raise NotImplementedError

    def refresh(self, state: st.HaketiloState) -> ConcreteRepoIterationRef:
        """Placeholder; always raises NotImplementedError."""
        raise NotImplementedError
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteRepoIterationRef(st.RepoIterationRef):
    """
    Repository iteration reference handed out by ConcreteHaketiloState.

    Adds nothing on top of its base class.
    """
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteMappingRef(st.MappingRef):
    """
    Mapping reference handed out by ConcreteHaketiloState.

    Both operations are placeholders for now and raise NotImplementedError.
    """
    def disable(self, state: st.HaketiloState) -> None:
        """Placeholder; always raises NotImplementedError."""
        raise NotImplementedError

    def forget_enabled(self, state: st.HaketiloState) -> None:
        """Placeholder; always raises NotImplementedError."""
        raise NotImplementedError
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteMappingVersionRef(st.MappingVersionRef):
    """
    Mapping version reference handed out by ConcreteHaketiloState.

    Its only operation is a placeholder that raises NotImplementedError.
    """
    def enable(self, state: st.HaketiloState) -> None:
        """Placeholder; always raises NotImplementedError."""
        raise NotImplementedError
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteResourceRef(st.ResourceRef):
    """
    Resource reference handed out by ConcreteHaketiloState.

    Adds nothing on top of its base class.
    """
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteResourceVersionRef(st.ResourceVersionRef):
    """
    Resource version reference handed out by ConcreteHaketiloState.

    Adds nothing on top of its base class.
    """
@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcretePayloadRef(st.PayloadRef):
    """
    Payload reference that carries the computed payload it refers to.
    """
    # The computed payload takes no part in hashing nor comparison of refs.
    computed_payload: ComputedPayload = dc.field(hash=False, compare=False)

    def get_data(self, state: st.HaketiloState) -> st.PayloadData:
        """
        Look this payload's data up in `state.payloads_data` by `self.id`.

        `state` is treated as a ConcreteHaketiloState (cast only, no runtime
        check).
        """
        return t.cast(ConcreteHaketiloState, state).payloads_data[self.id]

    def get_mapping(self, state: st.HaketiloState) -> st.MappingVersionRef:
        """Not implemented yet; always raises NotImplementedError."""
        # A placeholder string used to be returned here, which violated the
        # declared return type. Fail loudly like the other unimplemented stubs.
        raise NotImplementedError()

    def get_script_paths(self, state: st.HaketiloState) \
        -> t.Iterator[t.Sequence[str]]:
        """
        Yield one path per script of every resource of the computed payload.

        Each path is a tuple made of the resource identifier followed by the
        '/'-separated segments of the script's name.
        """
        for resource_info in self.computed_payload.resources:
            for file_spec in resource_info.scripts:
                yield (resource_info.identifier, *file_spec.name.split('/'))

    def get_file_data(self, state: st.HaketiloState, path: t.Sequence[str]) \
        -> t.Optional[st.FileData]:
        """
        Return the script designated by `path` or None if there's no such one.

        `path` has the form produced by `get_script_paths()`: the resource
        identifier followed by the script name's segments.

        Raises st.MissingItemError when `path` is empty or when no resource of
        the computed payload has the requested identifier. None is returned
        when the resource exists but has no script with the requested name.
        """
        if not path:
            raise st.MissingItemError()

        resource_identifier, *file_name_segments = path

        file_name = '/'.join(file_name_segments)

        # Empty string means "no matching script found (yet)".
        script_sha256 = ''

        matched_resource_info = False

        for resource_info in self.computed_payload.resources:
            if resource_info.identifier != resource_identifier:
                continue

            matched_resource_info = True

            for script_spec in resource_info.scripts:
                if script_spec.name == file_name:
                    script_sha256 = script_spec.sha256
                    break

        if not matched_resource_info:
            raise st.MissingItemError(resource_identifier)

        if script_sha256 == '':
            return None

        # Script contents are read from the on-disk store, where files are
        # named after their sha256 sums.
        store_dir_path = t.cast(ConcreteHaketiloState, state).store_dir
        files_dir_path = store_dir_path / 'temporary_malcontent' / 'file'
        file_path = files_dir_path / 'sha256' / script_sha256

        return st.FileData(
            type     = 'application/javascript',
            name     = file_name,
            contents = file_path.read_bytes()
        )
# Pattern tree specialization whose stored items are policy factories.
PolicyTree = pattern_tree.PatternTree[policies.PolicyFactory]
def register_payload(
        policy_tree: PolicyTree,
        payload_key: st.PayloadKey,
        token:       str
) -> PolicyTree:
    """
    Register the policy factories of a payload and return the resulting tree.

    Two non-builtin factories get registered: a PayloadPolicyFactory under the
    payload's own pattern and a PayloadResourcePolicyFactory under that pattern
    with `token` and '***' appended to its path. The tree returned by the
    second `register()` call is the result.
    """
    pattern = payload_key.pattern

    tree_with_payload = policy_tree.register(
        pattern,
        policies.PayloadPolicyFactory(builtin=False, payload_key=payload_key)
    )

    resource_factory = policies.PayloadResourcePolicyFactory(
        builtin     = False,
        payload_key = payload_key
    )

    resources_pattern = payload_key.pattern.path_append(token, '***')

    return tree_with_payload.register(resources_pattern, resource_factory)
# Mapping from a string id to payload data - presumably the type of the
# `payloads_data` attribute, which is keyed by payload id (TODO confirm).
DataById = t.Mapping[str, st.PayloadData]

# Type variable constrained to the two item info classes, so that helpers can
# be typed to return infos of the same class they were asked about.
AnyInfoVar = t.TypeVar(
    'AnyInfoVar',
    item_infos.ResourceInfo,
    item_infos.MappingInfo
)
# def newest_item_path(item_dir: Path) -> t.Optional[Path]:
# available_versions = tuple(
# versions.parse_normalize_version(ver_path.name)
# for ver_path in item_dir.iterdir()
# if ver_path.is_file()
# )
# if available_versions == ():
# return None
# newest_version = max(available_versions)
# version_path = item_dir / versions.version_string(newest_version)
# assert version_path.is_file()
# return version_path
def read_items(malcontent_path: Path, item_class: t.Type[AnyInfoVar]) \
    -> t.Iterator[tuple[AnyInfoVar, str]]:
    """
    Yield `(item_info, definition_text)` for item definitions found on disk.

    Definitions are looked for in
    `malcontent_path / item_class.type_name / <identifier> / <version>`.
    Nothing is yielded if the type directory does not exist; entries of it that
    are not directories are skipped. Each definition file is loaded with
    `item_class.load()` and is asserted to agree with the identifier and
    version encoded in its path.
    """
    type_dir = malcontent_path / item_class.type_name

    if type_dir.is_dir():
        item_dirs = (entry for entry in type_dir.iterdir() if entry.is_dir())

        for item_dir in item_dirs:
            for version_file in item_dir.iterdir():
                definition = version_file.read_text()
                loaded = item_class.load(io.StringIO(definition))

                # Directory name must be the identifier and file name must be
                # the version of the item defined inside.
                assert loaded.identifier == item_dir.name
                version_name = versions.version_string(loaded.version)
                assert version_name == version_file.name

                yield loaded, definition
def get_or_make_repo_iteration(cursor: sqlite3.Cursor, repo_name: str) -> int:
    """
    Return the id of the latest iteration of the named repo, creating rows.

    A missing repo gets inserted with empty url, marked as deleted and with
    `next_iteration` equal to 2. The iteration numbered `next_iteration - 1`
    is then inserted for that repo unless already present and its
    `repo_iteration_id` is returned.
    """
    name_params = (repo_name,)

    cursor.execute(
        '''
        INSERT OR IGNORE INTO repos(name, url, deleted, next_iteration)
        VALUES(?, '', TRUE, 2);
        ''',
        name_params
    )

    cursor.execute(
        '''
        SELECT
                repo_id, next_iteration - 1
        FROM
                repos
        WHERE
                name = ?;
        ''',
        name_params
    )

    # Exactly one matching row is expected; unpacking fails otherwise.
    [iteration_params] = cursor.fetchall()
    repo_id, last_iteration = iteration_params

    cursor.execute(
        '''
        INSERT OR IGNORE INTO repo_iterations(repo_id, iteration)
        VALUES(?, ?);
        ''',
        (repo_id, last_iteration)
    )

    cursor.execute(
        '''
        SELECT
                repo_iteration_id
        FROM
                repo_iterations
        WHERE
                repo_id = ? AND iteration = ?;
        ''',
        (repo_id, last_iteration)
    )

    [(found_id,)] = cursor.fetchall()

    return found_id
def get_or_make_item(cursor: sqlite3.Cursor, type: str, identifier: str) -> int:
    """
    Return the id of the item of given type and identifier, creating its row.

    `type` must be either 'resource' or 'mapping' (KeyError is raised
    otherwise); it is stored as the single letter 'R' or 'M'.
    """
    letters = {'resource': 'R', 'mapping': 'M'}
    item_key = (letters[type], identifier)

    cursor.execute(
        '''
        INSERT OR IGNORE INTO items(type, identifier)
        VALUES(?, ?);
        ''',
        item_key
    )

    cursor.execute(
        '''
        SELECT
                item_id
        FROM
                items
        WHERE
                type = ? AND identifier = ?;
        ''',
        item_key
    )

    # Exactly one matching row is expected; unpacking fails otherwise.
    [(found_id,)] = cursor.fetchall()

    return found_id
def get_or_make_item_version(
        cursor:            sqlite3.Cursor,
        item_id:           int,
        repo_iteration_id: int,
        version:           versions.VerTuple,
        definition:        str
) -> int:
    """
    Return the id of given version of an item, creating its row if needed.

    The row is identified by item id, version string (as produced by
    `versions.version_string()`) and repo iteration id. `definition` only gets
    stored when the row is newly inserted.
    """
    version_text = versions.version_string(version)

    cursor.execute(
        '''
        INSERT OR IGNORE INTO item_versions(
                item_id,
                version,
                repo_iteration_id,
                definition
        )
        VALUES(?, ?, ?, ?);
        ''',
        (item_id, version_text, repo_iteration_id, definition)
    )

    cursor.execute(
        '''
        SELECT
                item_version_id
        FROM
                item_versions
        WHERE
                item_id = ? AND version = ? AND repo_iteration_id = ?;
        ''',
        (item_id, version_text, repo_iteration_id)
    )

    # Exactly one matching row is expected; unpacking fails otherwise.
    [(found_id,)] = cursor.fetchall()

    return found_id
def make_mapping_status(cursor: sqlite3.Cursor, item_id: int) -> None:
    """
    Make sure the item has a row in `mapping_statuses`.

    A newly inserted row has `enabled` set to 'N'; an already existing row is
    left as it is.
    """
    statement = '''
        INSERT OR IGNORE INTO mapping_statuses(item_id, enabled)
        VALUES(?, 'N');
        '''

    cursor.execute(statement, (item_id,))
def get_or_make_file(cursor: sqlite3.Cursor, sha256: str, file_bytes: bytes) \
    -> int:
    """
    Return the id of the file with given sha256 sum, storing it if needed.

    `file_bytes` only get written when no row with this `sha256` exists yet.
    """
    cursor.execute(
        '''
        INSERT OR IGNORE INTO files(sha256, data)
        VALUES(?, ?)
        ''',
        (sha256, file_bytes)
    )

    lookup_params = (sha256,)

    cursor.execute(
        '''
        SELECT
                file_id
        FROM
                files
        WHERE
                sha256 = ?;
        ''',
        lookup_params
    )

    # Exactly one matching row is expected; unpacking fails otherwise.
    [(found_id,)] = cursor.fetchall()

    return found_id
def make_file_use(
        cursor:          sqlite3.Cursor,
        item_version_id: int,
        file_id:         int,
        name:            str,
        type:            str,
        mime_type:       str,
        idx:             int
) -> None:
    """
    Record that an item version uses a file, unless the insert gets ignored.

    The values are inserted into `file_uses` as given; the statement is an
    `INSERT OR IGNORE`, so a row conflicting with an existing one is dropped
    silently.
    """
    use_row = (item_version_id, file_id, name, type, mime_type, idx)

    cursor.execute(
        '''
        INSERT OR IGNORE INTO file_uses(
                item_version_id,
                file_id,
                name,
                type,
                mime_type,
                idx
        )
        VALUES(?, ?, ?, ?, ?, ?);
        ''',
        use_row
    )
def get_infos_of_type(cursor: sqlite3.Cursor, info_type: t.Type[AnyInfoVar],) \
    -> t.Mapping[AnyInfoVar, int]:
    """
    Load all item versions of given type from the database.

    Items are selected by the upper-cased first letter of
    `info_type.type_name`. Every stored definition is parsed with
    `info_type.load()`, which also receives the repo's name and the iteration
    number. The result maps each loaded info object to its item's id.
    """
    type_letter = info_type.type_name[0].upper()

    cursor.execute(
        '''
        SELECT
                i.item_id, iv.definition, r.name, ri.iteration
        FROM
                item_versions AS iv
                JOIN items AS i USING (item_id)
                JOIN repo_iterations AS ri USING (repo_iteration_id)
                JOIN repos AS r USING (repo_id)
        WHERE
                i.type = ?;
        ''',
        (type_letter,)
    )

    return {
        info_type.load(io.StringIO(definition), repo_name, iteration): item_id
        for item_id, definition, repo_name, iteration in cursor.fetchall()
    }
@dc.dataclass
class ConcreteHaketiloState(base.HaketiloStateWithFields):
    """
    Instantiatable Haketilo state that keeps its data in an SQLite database.

    Upon creation the database is prepared, item definitions and files found
    under `store_dir / 'temporary_malcontent'` are imported into it and the
    payload-related structures (`payloads_data`, `policy_tree`) are built.
    """
    def __post_init__(self) -> None:
        """Prepare the database, import on-disk items, build structures."""
        self._prepare_database()

        self._populate_database_with_stuff_from_temporary_malcontent_dir()

        with self.cursor(transaction=True) as cursor:
            self.rebuild_structures(cursor)

    def _prepare_database(self) -> None:
        """
        Create the tables if needed, check the schema and enable foreign keys.

        An uninitialized database (one without the 'general' table) gets the
        'tables.sql' script executed on it. An initialized one must report
        Haketilo version '3.0b1'.

        Raises HaketiloException if the schema version is not recognized or if
        the `foreign_keys` pragma yields no rows.
        """
        cursor = self.connection.cursor()

        try:
            cursor.execute(
                '''
                SELECT
                        COUNT(name)
                FROM
                        sqlite_master
                WHERE
                        name = 'general' AND type = 'table';
                '''
            )

            (db_initialized,), = cursor.fetchall()

            if not db_initialized:
                cursor.executescript((here.parent / 'tables.sql').read_text())
            else:
                cursor.execute(
                    '''
                    SELECT
                            haketilo_version
                    FROM
                            general;
                    '''
                )

                version_row = cursor.fetchone()

                # An empty 'general' table yields no row. Report it as an
                # unrecognized schema rather than fail on tuple unpacking.
                if version_row is None or version_row[0] != '3.0b1':
                    raise HaketiloException(_('err.unknown_db_schema'))

            # No rows in response to the pragma query are taken to mean that
            # foreign keys are not supported at all.
            cursor.execute('PRAGMA FOREIGN_KEYS;')
            if cursor.fetchall() == []:
                raise HaketiloException(_('err.proxy.no_sqlite_foreign_keys'))

            cursor.execute('PRAGMA FOREIGN_KEYS=ON;')
        finally:
            cursor.close()

    def _populate_database_with_stuff_from_temporary_malcontent_dir(self) \
        -> None:
        """
        Import items and their files from the temporary malcontent directory.

        For every resource and mapping definition found, rows for its repo
        iteration, item and item version are created as needed. Mappings also
        get a status row. Files referenced by the item (copyright files and,
        for resources, scripts) are read from the 'file/sha256' subdirectory,
        stored in the database and linked to the item version with use type
        'L' (copyright files) or 'W' (scripts).
        """
        malcontent_dir_path = self.store_dir / 'temporary_malcontent'
        files_by_sha256_path = malcontent_dir_path / 'file' / 'sha256'

        with self.cursor(transaction=True) as cursor:
            for info_type in [item_infos.ResourceInfo, item_infos.MappingInfo]:
                info: item_infos.AnyInfo
                for info, definition in read_items(
                        malcontent_dir_path,
                        info_type # type: ignore
                ):
                    repo_iteration_id = get_or_make_repo_iteration(
                        cursor,
                        info.repo
                    )

                    item_id = get_or_make_item(
                        cursor,
                        info.type_name,
                        info.identifier
                    )

                    item_version_id = get_or_make_item_version(
                        cursor,
                        item_id,
                        repo_iteration_id,
                        info.version,
                        definition
                    )

                    if info_type is item_infos.MappingInfo:
                        make_mapping_status(cursor, item_id)

                    # Maps sha256 sums to (file id, file contents) pairs.
                    file_ids_bytes: dict[str, tuple[int, bytes]] = {}

                    file_specifiers = [*info.source_copyright]
                    if isinstance(info, item_infos.ResourceInfo):
                        file_specifiers.extend(info.scripts)

                    for file_spec in file_specifiers:
                        file_path = files_by_sha256_path / file_spec.sha256
                        file_bytes = file_path.read_bytes()

                        sha256 = hashlib.sha256(file_bytes).hexdigest()
                        # NOTE(review): this integrity check vanishes under
                        # `python -O`; consider raising a proper exception.
                        assert sha256 == file_spec.sha256

                        file_id = get_or_make_file(cursor, sha256, file_bytes)

                        file_ids_bytes[sha256] = (file_id, file_bytes)

                    for idx, file_spec in enumerate(info.source_copyright):
                        file_id, file_bytes = file_ids_bytes[file_spec.sha256]
                        if file_bytes.isascii():
                            mime = 'text/plain'
                        else:
                            mime = 'application/octet-stream'

                        make_file_use(
                            cursor,
                            item_version_id = item_version_id,
                            file_id         = file_id,
                            name            = file_spec.name,
                            type            = 'L',
                            mime_type       = mime,
                            idx             = idx
                        )

                    if isinstance(info, item_infos.MappingInfo):
                        continue

                    for idx, file_spec in enumerate(info.scripts):
                        # Only the id is needed here. Index instead of
                        # unpacking into `_`, which would shadow the gettext
                        # alias imported at module level.
                        file_id = file_ids_bytes[file_spec.sha256][0]
                        make_file_use(
                            cursor,
                            item_version_id = item_version_id,
                            file_id         = file_id,
                            name            = file_spec.name,
                            type            = 'W',
                            mime_type       = 'application/javascript',
                            idx             = idx
                        )

    def rebuild_structures(self, cursor: sqlite3.Cursor) -> None:
        """
        Recompute payloads and refresh both database and in-memory structures.

        Must be called inside a database transaction. All rows of the
        `payloads` table are replaced with freshly computed ones (together
        with their resolved resource dependencies), `self.payloads_data` is
        replaced with a new dictionary keyed by payload id and the policy
        factories of every payload get registered in `self.policy_tree`.
        """
        assert self.connection.in_transaction

        resources = get_infos_of_type(cursor, item_infos.ResourceInfo)
        mappings = get_infos_of_type(cursor, item_infos.MappingInfo)

        payloads = compute_payloads(resources.keys(), mappings.keys())

        payloads_data = {}

        cursor.execute('DELETE FROM payloads;')

        for mapping_info, by_pattern in payloads.items():
            for pattern, payload in by_pattern.items():
                cursor.execute(
                    '''
                    INSERT INTO payloads(
                            mapping_item_id,
                            pattern,
                            eval_allowed,
                            cors_bypass_allowed
                    )
                    VALUES (?, ?, ?, ?);
                    ''',
                    (
                        mappings[mapping_info],
                        pattern.orig_url,
                        payload.allows_eval,
                        payload.allows_cors_bypass
                    )
                )

                cursor.execute(
                    '''
                    SELECT
                            payload_id
                    FROM
                            payloads
                    WHERE
                            mapping_item_id = ? AND pattern = ?;
                    ''',
                    (mappings[mapping_info], pattern.orig_url)
                )

                (payload_id_int,), = cursor.fetchall()

                for res_num, resource_info in enumerate(payload.resources):
                    cursor.execute(
                        '''
                        INSERT INTO resolved_depended_resources(
                                payload_id,
                                resource_item_id,
                                idx
                        )
                        VALUES(?, ?, ?);
                        ''',
                        (payload_id_int, resources[resource_info], res_num)
                    )

                # Payload ids are exposed as strings outside the database.
                payload_id = str(payload_id_int)

                ref = ConcretePayloadRef(payload_id, payload)

                data = st.PayloadData(
                    payload_ref         = ref,
                    mapping_installed   = True,
                    explicitly_enabled  = True,
                    unique_token        = secrets.token_urlsafe(16),
                    pattern             = pattern,
                    eval_allowed        = payload.allows_eval,
                    cors_bypass_allowed = payload.allows_cors_bypass
                )

                payloads_data[payload_id] = data

                key = st.PayloadKey(
                    payload_ref        = ref,
                    mapping_identifier = mapping_info.identifier,
                    pattern            = pattern
                )

                # NOTE(review): factories are registered on top of the current
                # policy tree; nothing here removes those registered by an
                # earlier rebuild - confirm this is intended.
                self.policy_tree = register_payload(
                    self.policy_tree,
                    key,
                    data.unique_token
                )

        self.payloads_data = payloads_data

    def get_repo(self, repo_id: str) -> st.RepoRef:
        """Make a repo reference with given id (existence is not checked)."""
        return ConcreteRepoRef(repo_id)

    def get_repo_iteration(self, repo_iteration_id: str) -> st.RepoIterationRef:
        """Make a repo iteration reference with given id (no checks)."""
        return ConcreteRepoIterationRef(repo_iteration_id)

    def get_mapping(self, mapping_id: str) -> st.MappingRef:
        """Make a mapping reference with given id (no checks)."""
        return ConcreteMappingRef(mapping_id)

    def get_mapping_version(self, mapping_version_id: str) \
        -> st.MappingVersionRef:
        """Make a mapping version reference with given id (no checks)."""
        return ConcreteMappingVersionRef(mapping_version_id)

    def get_resource(self, resource_id: str) -> st.ResourceRef:
        """Make a resource reference with given id (no checks)."""
        return ConcreteResourceRef(resource_id)

    def get_resource_version(self, resource_version_id: str) \
        -> st.ResourceVersionRef:
        """Make a resource version reference with given id (no checks)."""
        return ConcreteResourceVersionRef(resource_version_id)

    def get_payload(self, payload_id: str) -> st.PayloadRef:
        """Not implemented yet; always raises NotImplementedError."""
        # A placeholder string used to be returned here, which violated the
        # declared return type. Fail loudly like the other unimplemented stubs.
        raise NotImplementedError()

    def add_repo(self, name: t.Optional[str], url: t.Optional[str]) \
        -> st.RepoRef:
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError()

    def get_settings(self) -> st.HaketiloGlobalSettings:
        """
        Return global settings. These are currently hard-coded: automatic
        mapping use mode, scripts allowed by default and repo refresh period
        of 0 seconds.
        """
        return st.HaketiloGlobalSettings(
            mapping_use_mode      = st.MappingUseMode.AUTO,
            default_allow_scripts = True,
            repo_refresh_seconds  = 0
        )

    def update_settings(
            self,
            *,
            mapping_use_mode:      t.Optional[st.MappingUseMode] = None,
            default_allow_scripts: t.Optional[bool]              = None,
            repo_refresh_seconds:  t.Optional[int]               = None
    ) -> None:
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError()

    def select_policy(self, url: url_patterns.ParsedUrl) -> policies.Policy:
        """
        Choose the policy to apply to given URL.

        Every factory found in the policy tree for `url` is asked to make a
        policy and the one with the highest priority wins; only priorities
        greater than 0 are considered. If anything raises while searching or
        making policies, an ErrorBlockPolicy carrying the exception is
        returned. When no policy got chosen, the fallback allow or block
        policy is returned according to the `default_allow_scripts` setting.
        """
        # Only the reference to the tree is taken under the lock; searching
        # is then done on that reference.
        with self.lock:
            policy_tree = self.policy_tree

        try:
            best_priority: int = 0
            best_policy: t.Optional[policies.Policy] = None

            for factories_set in policy_tree.search(url):
                for stored_factory in sorted(factories_set):
                    factory = stored_factory.item

                    policy = factory.make_policy(self)

                    if policy.priority > best_priority:
                        best_priority = policy.priority
                        best_policy = policy
        except Exception as e:
            # Deliberately broad: any failure results in a blocking policy
            # that carries the error.
            return policies.ErrorBlockPolicy(
                builtin = True,
                error   = e
            )

        if best_policy is not None:
            return best_policy

        if self.get_settings().default_allow_scripts:
            return policies.FallbackAllowPolicy()
        else:
            return policies.FallbackBlockPolicy()

    @staticmethod
    def make(store_dir: Path) -> 'ConcreteHaketiloState':
        """
        Create a state that uses the 'sqlite3.db' database file located in
        `store_dir`.
        """
        return ConcreteHaketiloState(
            store_dir  = store_dir,
            connection = sqlite3.connect(str(store_dir / 'sqlite3.db'))
        )