aboutsummaryrefslogtreecommitdiff
# SPDX-License-Identifier: GPL-3.0-or-later

# Haketilo proxy data and configuration (ResourceStore and MappingStore
# implementations).
#
# This file is part of Hydrilla&Haketilo.
#
# Copyright (C) 2022 Wojtek Kosior
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
#
# I, Wojtek Kosior, thereby promise not to sue for violation of this
# file's license. Although I request that you do not make use of this
# code in a proprietary program, I am not going to enforce this in
# court.

"""
This module provides an interface for interacting with mappings and resources
inside Haketilo.
"""

import sqlite3
import typing as t
import dataclasses as dc

from contextlib import contextmanager
from urllib.parse import urljoin

from ... import item_infos
from .. import state as st
from . import base


def _get_item_id(cursor: sqlite3.Cursor, item_type: str, identifier: str) \
    -> str:
    """
    Look up the database id of the item with the given type and identifier.

    `item_type` is the single-letter type stored in the `items` table (callers
    in this module pass 'M' for mappings and 'R' for resources). The id is
    returned as a string.

    Raises st.MissingItemError when no such item exists.
    """
    query = 'SELECT item_id FROM items WHERE identifier = ? AND type = ?;'
    cursor.execute(query, (identifier, item_type))

    matches = cursor.fetchall()
    if not matches:
        raise st.MissingItemError()

    [[found_id]] = matches
    return str(found_id)


def _get_parent_item_id(cursor: sqlite3.Cursor, version_id: str) -> str:
    """
    Return (as a string) the id of the item that owns item version
    `version_id`.

    Raises st.MissingItemError when no such version exists.
    """
    parameters = (version_id,)
    cursor.execute(
        '''
        SELECT
                item_id
        FROM
                item_versions
        WHERE
                item_version_id = ?;
        ''',
        parameters
    )

    matches = cursor.fetchall()
    if not matches:
        raise st.MissingItemError()

    [[parent_id]] = matches
    return str(parent_id)


def _set_installed_status(cursor: sqlite3.Cursor, id: str, new_status: str) \
    -> None:
    """Write `new_status` into the `installed` column of item version `id`."""
    statement = \
        'UPDATE item_versions SET installed = ? WHERE item_version_id = ?;'
    cursor.execute(statement, (new_status, id))

def _get_statuses(cursor: sqlite3.Cursor, id: str) -> tuple[str, str]:
    """
    Fetch the `(installed, active)` status pair of item version `id`.

    Raises st.MissingItemError when no such version exists.
    """
    cursor.execute(
        '''
        SELECT
                installed, active
        FROM
                item_versions
        WHERE
                item_version_id = ?;
        ''',
        (id,)
    )

    result = cursor.fetchall()
    if not result:
        raise st.MissingItemError()

    [(installed, active)] = result
    return (installed, active)

# Type variable constrained to the two concrete item version reference
# classes of this module. They are named by string because both classes are
# defined further down in the file.
VersionRefVar = t.TypeVar(
    'VersionRefVar', 'ConcreteResourceVersionRef', 'ConcreteMappingVersionRef'
)

def _install_version(ref: VersionRefVar) -> None:
    """
    Mark the item version referenced by `ref` as installed.

    If the version's installed status is already 'I', nothing is done.
    Otherwise the status is set to 'I' and the state is asked to pull missing
    files. Both happen while the cursor opened with `transaction=True` is
    active.
    """
    with ref.state.cursor(transaction=True) as cursor:
        installed_status, _active = _get_statuses(cursor, ref.id)

        if installed_status != 'I':
            _set_installed_status(cursor, ref.id, 'I')
            ref.state.pull_missing_files()

def _uninstall_version(ref: VersionRefVar) -> t.Optional[VersionRefVar]:
    """
    Try to mark the item version referenced by `ref` as not installed.

    A version whose installed status is already 'N', or whose active status
    is 'R', is left as is and `ref` is returned. Otherwise the version is
    marked 'N' and orphan items are soft-pruned. If the version's active
    status was anything other than 'N', dependencies are also recomputed.

    Returns `ref` when the version still exists in `item_versions`
    afterwards, or None when it was removed in the process.
    """
    with ref.state.cursor(transaction=True) as cursor:
        installed_status, active_status = _get_statuses(cursor, ref.id)

        # Nothing to do for an already-uninstalled version; a version with
        # active status 'R' is deliberately kept installed.
        if installed_status == 'N' or active_status == 'R':
            return ref

        _set_installed_status(cursor, ref.id, 'N')

        ref.state.soft_prune_orphan_items()

        if active_status != 'N':
            ref.state.recompute_dependencies()

        # Pruning may have deleted the version we just uninstalled.
        cursor.execute(
            'SELECT COUNT(*) FROM item_versions WHERE item_version_id = ?;',
            (ref.id,)
        )

        [(remaining,)] = cursor.fetchall()

        if remaining:
            return ref

        return None


def _get_file(ref: VersionRefVar, name: str, file_type: str = 'L') \
    -> st.FileData:
    """
    Read a stored file used by an installed item version.

    `file_type` is matched against `file_uses.type`; this module uses 'L' for
    license files and 'W' for resource files. The item type letter used in
    the query is the upper-cased first character of `ref.type.value`.

    Raises st.MissingItemError when the version is not installed, does not
    use a file with this name and type, or the file's data is NULL.
    """
    with ref.state.cursor() as cursor:
        type_letter = ref.type.value[0].upper()

        cursor.execute(
            '''
            SELECT
                    f.data, fu.mime_type
            FROM
                         item_versions AS iv
                    JOIN items         AS i  USING (item_id)
                    JOIN file_uses     AS fu USING (item_version_id)
                    JOIN files         AS f  USING (file_id)
            WHERE
                    (iv.item_version_id = ? AND iv.installed = 'I') AND
                    i.type = ?                                      AND
                    (fu.name = ? AND fu.type = ?)                   AND
                    f.data IS NOT NULL;
            ''',
            (ref.id, type_letter, name, file_type)
        )

        matches = cursor.fetchall()

    if not matches:
        raise st.MissingItemError()

    [(data, mime_type)] = matches

    return st.FileData(mime_type, name, data)


def _get_upstream_file_url(
        ref:       VersionRefVar,
        name:      str,
        file_type: str = 'L'
) -> str:
    """
    Compute the upstream URL of a file used by an item version.

    The URL is obtained by resolving 'file/sha256/<hash>' (with urljoin)
    against the URL of the repo the version came from. `file_type` has the
    same meaning as in `file_uses.type` ('L' license, 'W' resource file).

    Raises st.MissingItemError when nothing matches, including when the
    version's repo has no URL.
    """
    with ref.state.cursor() as cursor:
        cursor.execute(
            '''
            SELECT
                    f.sha256, r.url
            FROM
                         item_versions   AS iv
                    JOIN repo_iterations AS ri USING(repo_iteration_id)
                    JOIN repos           AS r  USING(repo_id)
                    JOIN file_uses       AS fu USING(item_version_id)
                    JOIN files           AS f  USING(file_id)
            WHERE
                    iv.item_version_id = ?        AND
                    (fu.name = ? AND fu.type = ?) AND
                    r.url IS NOT NULL;
            ''',
            (ref.id, name, file_type)
        )

        matches = cursor.fetchall()

    if not matches:
        raise st.MissingItemError()

    [(file_hash, base_url)] = matches

    relative_path = f'file/sha256/{file_hash}'
    return urljoin(base_url, relative_path)


@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteMappingRef(st.MappingRef):
    """
    Reference to a mapping item stored in Haketilo's sqlite database.

    The `state` field is excluded from the generated __hash__ and __eq__.
    """
    state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False)

    def _get_status_data(self, cursor: sqlite3.Cursor) \
        -> tuple[str, t.Optional[str], t.Optional[int]]:
        """
        Read this mapping's row from the `mapping_statuses` table.

        Returns the stored `(enabled, frozen, active_version_id)` triple.
        Both `frozen` and `active_version_id` may be NULL (None), because
        `update_status()` writes None into them.

        Raises st.MissingItemError if the mapping has no status row.
        """
        # The selected columns are qualified with `ms.`, so the table has to
        # be given that alias. Without it sqlite rejects the query with
        # "no such column: ms.enabled".
        cursor.execute(
            '''
            SELECT
                    ms.enabled, ms.frozen, ms.active_version_id
            FROM
                    mapping_statuses AS ms
            WHERE
                    ms.item_id = ?;
            ''',
            (self.id,)
        )

        rows = cursor.fetchall()

        if rows == []:
            raise st.MissingItemError()

        (enabled_status, frozen_status, active_version_id), = rows

        return (enabled_status, frozen_status, active_version_id)


    def update_status(
            self,
            enabled:                st.EnabledStatus,
            frozen:                 t.Optional[st.FrozenStatus] = None,
            version_id_to_activate: t.Optional[str]             = None
    ) -> None:
        """
        Change the enabled/frozen status (and possibly the active version) of
        this mapping.

        `frozen` may only be given together with st.EnabledStatus.ENABLED.
        `version_id_to_activate` may not be combined with
        st.FrozenStatus.NOT_FROZEN. Both rules are checked with assertions.

        Without an explicit `version_id_to_activate`:
        - an enabled mapping keeps its previously active version, if any;
        - a mapping that is not enabled gets no active version.

        Afterwards, dependencies are recomputed for this mapping, except in
        the cases the checks at the end of this method rule out.

        Raises st.MissingItemError if the mapping has no status row.
        """
        assert frozen is None or enabled == st.EnabledStatus.ENABLED
        assert version_id_to_activate is None or \
            frozen != st.FrozenStatus.NOT_FROZEN

        with self.state.cursor(transaction=True) as cursor:
            (old_enabled_status, old_frozen_status,
             old_active_version_id) = self._get_status_data(cursor)

            # NOTE(review): with an unchanged enabled status and `frozen`
            # None this returns before `version_id_to_activate` is used. The
            # callers in this module never pass one without `frozen`.
            # Confirm no other caller does.
            if enabled.value == old_enabled_status and frozen is None:
                return

            new_enabled_status = enabled.value

            new_frozen_status = None if frozen is None else frozen.value

            if version_id_to_activate is not None:
                new_active_version_id = version_id_to_activate
            elif enabled == st.EnabledStatus.ENABLED and \
                 old_active_version_id is not None:
                new_active_version_id = str(old_active_version_id)
            else:
                new_active_version_id = None

            cursor.execute(
                '''
                UPDATE
                        mapping_statuses
                SET
                        enabled           = ?,
                        frozen            = ?,
                        active_version_id = ?
                WHERE
                        item_id = ?;
                ''', (
                    new_enabled_status,
                    new_frozen_status,
                    new_active_version_id,
                    self.id
                ))

            if enabled == st.EnabledStatus.ENABLED:
                # Skip recomputation when the mapping was already enabled
                # with the same active version, and the frozen status either
                # became 'E', was 'N' before, or did not change.
                if old_enabled_status == 'E'                           and \
                   new_active_version_id == str(old_active_version_id) and \
                   (new_frozen_status == 'E' or
                    old_frozen_status == 'N' or
                    new_frozen_status == old_frozen_status):
                    return
            else:
                # Skip recomputation when there was no active version and
                # the mapping was not previously disabled.
                if old_active_version_id is None and old_enabled_status != 'D':
                    return

            self.state.recompute_dependencies([int(self.id)])

    def get_display_info(self) -> st.RichMappingDisplayInfo:
        """
        Collect this mapping's identifier, enabled/frozen status and the
        display info of all its versions, sorted by their item info.

        A version whose active status is 'R' or 'A' is reported as the
        active one.

        Raises st.MissingItemError if the mapping (or its status row) does
        not exist.
        """
        with self.state.cursor() as cursor:
            cursor.execute(
                '''
                SELECT
                        i.identifier,
                        ms.enabled, ms.frozen
                FROM
                             items            AS i
                        JOIN mapping_statuses AS ms USING (item_id)
                WHERE
                        item_id = ?;
                ''',
                (self.id,)
            )

            rows = cursor.fetchall()

            if rows == []:
                raise st.MissingItemError()

            (identifier, enabled_status, frozen_status), = rows

            cursor.execute(
                '''
                SELECT
                        item_version_id,
                        definition,
                        repo,
                        repo_iteration,
                        installed,
                        active,
                        is_orphan,
                        is_local
                FROM
                        item_versions_extra
                WHERE
                        item_id = ?;
                ''',
                (self.id,)
            )

            rows = cursor.fetchall()

        version_infos = []

        active_info: t.Optional[st.MappingVersionDisplayInfo] = None

        for (item_version_id, definition, repo, repo_iteration,
             installed_status, active_status, is_orphan, is_local) in rows:
            ref = ConcreteMappingVersionRef(str(item_version_id), self.state)

            item_info = item_infos.MappingInfo.load(
                definition,
                repo,
                repo_iteration
            )

            version_display_info = st.MappingVersionDisplayInfo(
                ref             = ref,
                info            = item_info,
                installed       = st.InstalledStatus(installed_status),
                active          = st.ActiveStatus(active_status),
                is_orphan       = is_orphan,
                is_local        = is_local
            )

            version_infos.append(version_display_info)

            if active_status in ('R', 'A'):
                active_info = version_display_info

        return st.RichMappingDisplayInfo(
            ref            = self,
            identifier     = identifier,
            enabled        = st.EnabledStatus(enabled_status),
            frozen         = st.FrozenStatus.make(frozen_status),
            active_version = active_info,
            all_versions   = sorted(version_infos, key=(lambda vi: vi.info))
        )


@dc.dataclass(frozen=True)
class ConcreteMappingStore(st.MappingStore):
    """Access point to the mappings kept in Haketilo's database."""
    state: base.HaketiloStateWithFields

    def get(self, id: str) -> st.MappingRef:
        """
        Return a reference to the mapping with the given id.

        The id is normalized with int(), so a non-numeric value raises
        ValueError.
        """
        normalized_id = str(int(id))
        return ConcreteMappingRef(normalized_id, self.state)

    def get_display_infos(self) -> t.Sequence[st.MappingDisplayInfo]:
        """
        List display info of every mapping that has at least one row in
        `item_versions`, ordered by identifier.

        Each entry carries the mapping's enabled/frozen status. When
        `mapping_statuses.active_version_id` points to a version, the entry
        also carries that version's display info.
        """
        with self.state.cursor() as cursor:
            cursor.execute(
                '''
                WITH available_item_ids AS (
                        SELECT DISTINCT item_id FROM item_versions
                )
                SELECT
                        i.item_id,
                        i.identifier,
                        ive.item_version_id,
                        ive.definition,
                        ive.repo,
                        ive.repo_iteration,
                        ive.installed,
                        ive.active,
                        ive.is_orphan,
                        ive.is_local,
                        ms.enabled,
                        ms.frozen
                FROM
                                  items                 AS i
                        JOIN      mapping_statuses      AS ms
                                USING (item_id)
                        LEFT JOIN item_versions_extra   AS ive
                                ON ms.active_version_id = ive.item_version_id
                WHERE
                        i.item_id IN available_item_ids;
                '''
            )

            fetched = cursor.fetchall()

        def build_entry(row: t.Sequence[t.Any]) -> st.MappingDisplayInfo:
            (item_id, identifier, version_id, definition, repo,
             repo_iteration, installed_status, active_status, is_orphan,
             is_local, enabled_status, frozen_status) = row

            mapping_ref = ConcreteMappingRef(str(item_id), self.state)

            version_info: t.Optional[st.MappingVersionDisplayInfo]
            if version_id is None:
                # LEFT JOIN found no active version for this mapping.
                version_info = None
            else:
                version_info = st.MappingVersionDisplayInfo(
                    ref       = ConcreteMappingVersionRef(
                        id    = str(version_id),
                        state = self.state
                    ),
                    info      = item_infos.MappingInfo.load(
                        definition,
                        repo,
                        repo_iteration
                    ),
                    installed = st.InstalledStatus(installed_status),
                    active    = st.ActiveStatus(active_status),
                    is_orphan = is_orphan,
                    is_local  = is_local
                )

            return st.MappingDisplayInfo(
                ref            = mapping_ref,
                identifier     = identifier,
                enabled        = st.EnabledStatus(enabled_status),
                frozen         = st.FrozenStatus.make(frozen_status),
                active_version = version_info
            )

        entries = [build_entry(row) for row in fetched]
        return sorted(entries, key=(lambda entry: entry.identifier))

    def get_by_identifier(self, identifier: str) -> st.MappingRef:
        """
        Return a reference to the mapping with the given identifier.

        Raises st.MissingItemError when there is no such mapping.
        """
        with self.state.cursor() as cursor:
            mapping_id = _get_item_id(cursor, 'M', identifier)

        return ConcreteMappingRef(mapping_id, self.state)

    def get_by_identifier(self, identifier: str) -> st.MappingRef:
        with self.state.cursor() as cursor:
            item_id = _get_item_id(cursor, 'M', identifier)

        return ConcreteMappingRef(item_id, self.state)


@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteMappingVersionRef(st.MappingVersionRef):
    """Reference to a single version of a mapping stored in the database."""
    state: base.HaketiloStateWithFields

    def install(self) -> None:
        """Mark this mapping version as installed."""
        _install_version(self)

    def uninstall(self) -> t.Optional['ConcreteMappingVersionRef']:
        """
        Attempt to mark this mapping version as not installed.

        Returns this ref if the version still exists afterwards, or None if
        it was pruned from the database.
        """
        return _uninstall_version(self)

    def ensure_depended_items_installed(self) -> None:
        """
        Mark this mapping version as installed. Do the same for every item
        version listed in `resolved_depended_resources` for this version's
        payloads. Then let the state pull missing files.

        All of this runs under a cursor opened with `transaction=True`.
        """
        statements = (
                '''
                UPDATE
                        item_versions
                SET
                        installed = 'I'
                WHERE
                        item_version_id = ?;
                ''',
                '''
                WITH depended_resource_ids AS (
                        SELECT
                                rdd.resource_item_id
                        FROM
                                     payloads                    AS p
                                JOIN resolved_depended_resources AS rdd
                                        USING (payload_id)
                        WHERE
                                p.mapping_item_id = ?
                )
                UPDATE
                        item_versions
                SET
                        installed = 'I'
                WHERE
                        item_version_id IN depended_resource_ids;
                ''',
        )

        with self.state.cursor(transaction=True) as cursor:
            # Both statements take this version's id as their only parameter.
            for statement in statements:
                cursor.execute(statement, (self.id,))

            self.state.pull_missing_files()

    @contextmanager
    def _mapping_ref(self) -> t.Iterator[ConcreteMappingRef]:
        """
        Yield a ref to the mapping that owns this version.

        The ref is yielded while the cursor (opened with
        `transaction=True`) used to look it up is still open.
        """
        with self.state.cursor(transaction=True) as cursor:
            parent = ConcreteMappingRef(
                _get_parent_item_id(cursor, self.id),
                self.state
            )
            yield parent

    def update_mapping_status(
            self,
            enabled: st.EnabledStatus,
            frozen:  t.Optional[st.FrozenStatus] = None
    ) -> None:
        """
        Update the status of the mapping owning this version.

        This version is passed as the one to activate only when the mapping
        is being enabled ('E') with a frozen status other than 'N'.
        """
        with self._mapping_ref() as mapping_ref:
            activate_this_version = (
                enabled.value == 'E' and
                frozen is not None and
                frozen.value != 'N'
            )

            mapping_ref.update_status(
                enabled,
                frozen,
                self.id if activate_this_version else None
            )

    def get_license_file(self, name: str) -> st.FileData:
        """Return the contents of license file `name` of this version."""
        return _get_file(self, name, file_type='L')

    def get_upstream_license_file_url(self, name: str) -> str:
        """Return the upstream repository URL of license file `name`."""
        return _get_upstream_file_url(self, name, file_type='L')

    def get_required_mapping(self, identifier: str) \
        -> 'ConcreteMappingVersionRef':
        """
        Return the version of mapping `identifier` resolved as required by
        this mapping version.

        Raises st.MissingItemError if there is none.
        """
        with self.state.cursor() as cursor:
            cursor.execute(
                '''
                SELECT
                        iv2.item_version_id
                FROM
                             item_versions              AS iv1
                        JOIN resolved_required_mappings AS rrm
                                ON iv1.item_version_id =
                                   rrm.requiring_mapping_id
                        JOIN item_versions              AS iv2
                                ON rrm.required_mapping_id =
                                   iv2.item_version_id
                        JOIN items                      AS i
                                ON iv2.item_id = i.item_id
                WHERE
                        iv1.item_version_id = ? AND
                        i.identifier = ?;
                ''',
                (self.id, identifier)
            )

            matches = cursor.fetchall()

        if not matches:
            raise st.MissingItemError()

        [(required_version_id,)] = matches

        return ConcreteMappingVersionRef(str(required_version_id), self.state)

    def get_payload_resource(self, pattern: str, identifier: str) \
        -> 'ConcreteResourceVersionRef':
        """
        Return the version of resource `identifier` resolved for this
        mapping version's payload under `pattern`.

        Raises st.MissingItemError if there is none.
        """
        with self.state.cursor() as cursor:
            cursor.execute(
                '''
                SELECT
                        iv.item_version_id
                FROM
                             payloads                    AS p
                        JOIN resolved_depended_resources AS rdd
                                USING(payload_id)
                        JOIN item_versions               AS iv
                                ON rdd.resource_item_id = iv.item_version_id
                        JOIN items                       AS i
                                USING (item_id)
                WHERE
                        (p.mapping_item_id = ? AND p.pattern = ?) AND
                        i.identifier = ?;
                ''',
                (self.id, pattern, identifier)
            )

            matches = cursor.fetchall()

        if not matches:
            raise st.MissingItemError()

        [(resource_version_id,)] = matches

        return ConcreteResourceVersionRef(str(resource_version_id), self.state)

    def get_item_display_info(self) -> st.RichMappingDisplayInfo:
        """Return the display info of the mapping this version belongs to."""
        with self._mapping_ref() as mapping_ref:
            display_info = mapping_ref.get_display_info()

        return display_info


@dc.dataclass(frozen=True)
class ConcreteMappingVersionStore(st.MappingVersionStore):
    """Factory of references to mapping versions stored in the database."""
    state: base.HaketiloStateWithFields

    def get(self, id: str) -> st.MappingVersionRef:
        """
        Return a reference to the mapping version with the given id.

        The id is normalized with int(), so a non-numeric value raises
        ValueError.
        """
        canonical_id = str(int(id))
        return ConcreteMappingVersionRef(canonical_id, self.state)


@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteResourceRef(st.ResourceRef):
    """
    Reference to a resource item stored in Haketilo's database.

    The `state` field is excluded from the generated __hash__ and __eq__.
    """
    state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False)

    def get_display_info(self) -> st.RichResourceDisplayInfo:
        """
        Gather this resource's identifier together with the display info of
        all its versions, sorted by their item info.

        Raises st.MissingItemError if no item with this id exists.
        """
        with self.state.cursor() as cursor:
            cursor.execute(
                'SELECT identifier FROM items WHERE item_id = ?;',
                (self.id,)
            )

            identifier_rows = cursor.fetchall()

            if not identifier_rows:
                raise st.MissingItemError()

            [(identifier,)] = identifier_rows

            cursor.execute(
                '''
                SELECT
                        item_version_id,
                        definition,
                        repo,
                        repo_iteration,
                        installed,
                        active,
                        is_orphan,
                        is_local
                FROM
                        item_versions_extra
                WHERE
                        item_id = ?;
                ''',
                (self.id,)
            )

            version_rows = cursor.fetchall()

        def describe_version(row: t.Sequence[t.Any]) \
            -> st.ResourceVersionDisplayInfo:
            (version_id, definition, repo, repo_iteration,
             installed_status, active_status, is_orphan, is_local) = row

            return st.ResourceVersionDisplayInfo(
                ref       = ConcreteResourceVersionRef(
                    str(version_id),
                    self.state
                ),
                info      = item_infos.ResourceInfo.load(
                    definition,
                    repo,
                    repo_iteration
                ),
                installed = st.InstalledStatus(installed_status),
                active    = st.ActiveStatus(active_status),
                is_orphan = is_orphan,
                is_local  = is_local
            )

        all_versions = [describe_version(row) for row in version_rows]

        return st.RichResourceDisplayInfo(
            ref          = self,
            identifier   = identifier,
            all_versions = sorted(all_versions, key=(lambda vi: vi.info))
        )


@dc.dataclass(frozen=True)
class ConcreteResourceStore(st.ResourceStore):
    """Access point to the resources kept in Haketilo's database."""
    state: base.HaketiloStateWithFields

    def get(self, id: str) -> st.ResourceRef:
        """
        Return a reference to the resource with the given id.

        The id is normalized with int(), so a non-numeric value raises
        ValueError.
        """
        normalized_id = str(int(id))
        return ConcreteResourceRef(normalized_id, self.state)

    def get_display_infos(self) -> t.Sequence[st.ResourceDisplayInfo]:
        """List every item of type 'R', ordered by identifier."""
        with self.state.cursor() as cursor:
            cursor.execute(
                "SELECT item_id, identifier FROM items WHERE type = 'R';"
            )

            fetched = cursor.fetchall()

        infos = [
            st.ResourceDisplayInfo(
                ConcreteResourceRef(str(item_id), self.state),
                identifier
            )
            for item_id, identifier in fetched
        ]
        infos.sort(key=lambda info: info.identifier)
        return infos

    def get_by_identifier(self, identifier: str) -> st.ResourceRef:
        """
        Return a reference to the resource with the given identifier.

        Raises st.MissingItemError when there is no such resource.
        """
        with self.state.cursor() as cursor:
            resource_id = _get_item_id(cursor, 'R', identifier)

        return ConcreteResourceRef(resource_id, self.state)


@dc.dataclass(frozen=True, unsafe_hash=True)
class ConcreteResourceVersionRef(st.ResourceVersionRef):
    """Reference to a single version of a resource stored in the database."""
    state: base.HaketiloStateWithFields

    def install(self) -> None:
        """Mark this resource version as installed."""
        return _install_version(self)

    def uninstall(self) -> t.Optional['ConcreteResourceVersionRef']:
        """
        Attempt to mark this resource version as not installed.

        Returns this ref if the version is still present in the database
        afterwards, or None if it was pruned.
        """
        return _uninstall_version(self)

    def get_license_file(self, name: str) -> st.FileData:
        """Return the contents of license file `name` of this version."""
        return _get_file(self, name, 'L')

    def get_resource_file(self, name: str) -> st.FileData:
        """Return the contents of resource file `name` of this version."""
        return _get_file(self, name, 'W')

    def get_upstream_license_file_url(self, name: str) -> str:
        """Return the upstream repository URL of license file `name`."""
        return _get_upstream_file_url(self, name, 'L')

    def get_upstream_resource_file_url(self, name: str) -> str:
        """Return the upstream repository URL of resource file `name`."""
        return _get_upstream_file_url(self, name, 'W')

    def get_dependency(self, identifier: str) -> st.ResourceVersionRef:
        """
        Return the version of resource `identifier` that is listed in
        `resolved_depended_resources` for the same payload(s) as this
        resource version.

        Raises st.MissingItemError if there is no such resource.
        """
        with self.state.cursor() as cursor:
            # `rdd1` may match this resource version in several payloads.
            # Each of them then yields its own copy of the same dependency
            # row. DISTINCT collapses those copies; otherwise the single-row
            # unpacking below would raise ValueError.
            # NOTE(review): if different payloads resolved *different*
            # versions of the dependency, more than one row remains and the
            # unpacking still raises ValueError.
            cursor.execute(
                '''
                SELECT DISTINCT
                        iv.item_version_id
                FROM
                             resolved_depended_resources AS rdd1
                        JOIN payloads                    AS p
                                ON rdd1.payload_id = p.payload_id
                        JOIN resolved_depended_resources AS rdd2
                                ON p.payload_id = rdd2.payload_id
                        JOIN item_versions               AS iv
                                ON rdd2.resource_item_id = iv.item_version_id
                        JOIN items                       AS i
                                USING (item_id)
                WHERE
                        rdd1.resource_item_id = ? AND i.identifier = ?;
                ''',
                (self.id, identifier)
            )

            rows = cursor.fetchall()

        if rows == []:
            raise st.MissingItemError()

        (dep_id,), = rows

        return ConcreteResourceVersionRef(str(dep_id), self.state)

    def get_item_display_info(self) -> st.RichResourceDisplayInfo:
        """Return the display info of the resource this version belongs to."""
        with self.state.cursor() as cursor:
            resource_id = _get_parent_item_id(cursor, self.id)
            resource_ref = ConcreteResourceRef(resource_id, self.state)
            return resource_ref.get_display_info()


@dc.dataclass(frozen=True)
class ConcreteResourceVersionStore(st.ResourceVersionStore):
    """Factory of references to resource versions stored in the database."""
    state: base.HaketiloStateWithFields

    def get(self, id: str) -> st.ResourceVersionRef:
        """
        Return a reference to the resource version with the given id.

        The id is normalized with int(), so a non-numeric value raises
        ValueError.
        """
        canonical_id = str(int(id))
        return ConcreteResourceVersionRef(canonical_id, self.state)