aboutsummaryrefslogtreecommitdiff
# SPDX-License-Identifier: GPL-3.0-or-later

# Base definitions of policies for altering HTTP requests and responses.
#
# This file is part of Hydrilla&Haketilo.
#
# Copyright (C) 2022 Wojtek Kosior
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
#
# I, Wojtek Kosior, thereby promise not to sue for violation of this
# file's license. Although I request that you do not make use of this
# code in a proprietary program, I am not going to enforce this in
# court.

"""
.....
"""

import enum
import re
import threading
import dataclasses as dc
import typing as t

from abc import ABC, abstractmethod
from hashlib import sha256
from base64 import b64encode

import jinja2

from immutables import Map

from ... import translations
from ... import url_patterns
from ... import common_jinja_templates
from .. import state
from .. import http_messages
from .. import csp


# Jinja environment used to render informational (popup) pages. Templates come
# from this package's 'info_pages_templates' directory, combined with other
# loaders through common_jinja_templates.combine_with_loaders().
_info_loader = jinja2.PackageLoader(
    __package__, package_path='info_pages_templates'
)
_combined_loader = common_jinja_templates.combine_with_loaders(
    [_info_loader]
)
_jinja_info_env = jinja2.Environment(
    loader=_combined_loader,
    # File templates whose names end in '.html.jinja' get autoescaped.
    autoescape=jinja2.select_autoescape(enabled_extensions=['html.jinja']),
    lstrip_blocks=True,
    extensions=['jinja2.ext.i18n', 'jinja2.ext.do'],
)
_jinja_info_env.globals.update(url_patterns=url_patterns)
# Guards _jinja_info_env, whose installed translations get swapped per call.
_jinja_info_lock = threading.Lock()


# Jinja environment for the scripts injected into pages. Their output is
# JavaScript rather than HTML, so autoescaping stays off.
_jinja_script_loader = jinja2.PackageLoader(
    __package__, package_path='injectable_scripts'
)
_jinja_script_env = jinja2.Environment(
    loader=_jinja_script_loader,
    autoescape=False,
    lstrip_blocks=True,
    extensions=['jinja2.ext.do'],
)
_jinja_script_lock = threading.Lock()

def get_script_template(template_file_name: str) -> jinja2.Template:
    """
    Look up an injectable-script template by its file name.

    The lookup in the shared script environment is done while holding
    _jinja_script_lock.
    """
    _jinja_script_lock.acquire()
    try:
        found_template = _jinja_script_env.get_template(template_file_name)
    finally:
        _jinja_script_lock.release()
    return found_template


# Per-thread storage for data tied to the response currently being processed
# on that thread.
response_work_data = threading.local()

def response_nonce() -> str:
    """
    Give the nonce belonging to the response processed on this thread.

    Every call made during one consume_response() yields the same
    unpredictable string unique to that response; it serves as a nonce for
    script elements. Raises AttributeError if no nonce was stored for the
    current thread.
    """
    work_data = response_work_data
    return work_data.nonce


class PolicyPriority(int, enum.Enum):
    """
    Integer priority levels that Policy subclasses declare through their
    ``priority`` class attribute.
    """
    # NOTE(review): how priorities are compared (e.g. whether a higher value
    # wins) is decided outside this file — confirm before relying on it.
    _ONE   = 1
    _TWO   = 2
    _THREE = 3


class MsgProcessOpt(enum.Enum):
    """
    Explicit decision on whether a policy processes a kind of HTTP message.

    Used as the value of Policy's ``_process_request`` and
    ``_process_response`` class attributes; the member's boolean ``value``
    tells whether processing happens.
    """
    MUST     = True
    MUST_NOT = False


# Either kind of HTTP message that Policy.consume_request() may hand back.
MessageInfo = t.Union[http_messages.RequestInfo, http_messages.ResponseInfo]


# Doctype detection is intentionally naive for now; a site that wanted to could
# craft markup that makes us get it wrong.
doctype_re = re.compile(r'^\s*<!doctype[^>]*>', flags=re.I)


# Byte order marks paired with the Python codec name each one implies.
UTF8_BOM = bytes.fromhex('EF BB BF')
BOMs = (
    (UTF8_BOM,                'utf-8'),
    (bytes((0xFE, 0xFF)),     'utf-16be'),
    (bytes((0xFF, 0xFE)),     'utf-16le'),
)


# mypy needs to be corrected:
# https://stackoverflow.com/questions/70999513/conflict-between-mix-ins-for-abstract-dataclasses/70999704#70999704
@dc.dataclass(frozen=True) # type: ignore[misc]
class Policy(ABC):
    """
    Base class of policies that decide how a proxied HTTP exchange is altered.

    Subclasses customize behavior through the class attributes below and by
    overriding hooks such as ``_csp_to_clear()``, ``_csp_to_add()``,
    ``_csp_to_extend()``, ``make_info_page()`` and ``consume_request()``.
    """
    # Requests are processed only when _process_request is MsgProcessOpt.MUST.
    # For responses, a non-None _process_response decides unconditionally;
    # None falls back to the popup/page heuristic in should_process_response().
    _process_request:  t.ClassVar[t.Optional[MsgProcessOpt]] = None
    _process_response: t.ClassVar[t.Optional[MsgProcessOpt]] = None
    anticache:         t.ClassVar[bool] = True

    priority: t.ClassVar[PolicyPriority]

    haketilo_settings: state.HaketiloGlobalSettings

    @property
    def current_popup_settings(self) -> state.PopupSettings:
        """Popup settings in effect for this policy."""
        return self.haketilo_settings.default_popup_jsallowed

    def should_process_request(
            self,
            request_info: http_messages.BodylessRequestInfo
    ) -> bool:
        """Tell whether consume_request() should be called for this request."""
        return self._process_request == MsgProcessOpt.MUST

    def should_process_response(
            self,
            request_info:  http_messages.RequestInfo,
            response_info: http_messages.AnyResponseInfo
    ) -> bool:
        """
        Tell whether consume_response() should be called for this response.

        An explicit ``_process_response`` wins; otherwise responses are
        processed when the popup is enabled and the exchange looks like a
        page load.
        """
        if self._process_response is not None:
            return self._process_response.value

        return (self.current_popup_settings.popup_enabled and
                http_messages.is_likely_a_page(request_info, response_info))

    def _get_info_template(self, template_file_name: str) -> jinja2.Template:
        """
        Load an info-page template with the configured locale's translations
        installed, falling back to the default locale when the configured one
        is unset or unsupported.
        """
        # NOTE(review): the template is rendered by the caller after the lock
        # is released; a concurrent call installing another locale's
        # translations into the shared environment could affect that render.
        with _jinja_info_lock:
            chosen_locale = self.haketilo_settings.locale
            if (chosen_locale is None or
                chosen_locale not in translations.supported_locales):
                chosen_locale = translations.default_locale

            trans = translations.translation(chosen_locale)
            _jinja_info_env.install_gettext_translations(trans) # type: ignore
            return _jinja_info_env.get_template(template_file_name)


    def _csp_to_clear(self, http_info: http_messages.FullHTTPInfo) \
        -> t.Union[t.Sequence[str], t.Literal['all']]:
        """CSP directives to remove (or 'all'); none by default."""
        return ()

    def _csp_to_add(self, http_info: http_messages.FullHTTPInfo) \
        -> t.Mapping[str, t.Sequence[str]]:
        """CSP directives to add; none by default."""
        return Map()

    def _csp_to_extend(self, http_info: http_messages.FullHTTPInfo) \
        -> t.Mapping[str, t.Sequence[str]]:
        """
        CSP sources to append to existing directives.

        When the popup is enabled and the response is likely a page, this
        response's nonce is allowed for scripts, styles and frames so that the
        popup injected by _modify_response_document() is permitted.
        """
        if not (self.current_popup_settings.popup_enabled and
                http_info.is_likely_a_page):
            return Map()

        nonce_source = f"'nonce-{response_nonce()}'"
        directives = (
            'script-src',
            'script-src-elem',
            'style-src',
            'frame-src'
        )
        return {directive: [nonce_source] for directive in directives}

    def _modify_response_headers(self, http_info: http_messages.FullHTTPInfo) \
        -> http_messages.IHeaders:
        """
        Apply this policy's CSP changes to the response headers.

        Returns the original headers object unchanged (same identity) when
        there is nothing to clear, add or extend.
        """
        csp_to_clear  = self._csp_to_clear(http_info)
        csp_to_add    = self._csp_to_add(http_info)
        csp_to_extend = self._csp_to_extend(http_info)

        if len(csp_to_clear) + len(csp_to_extend) + len(csp_to_add) == 0:
            return http_info.response_info.headers

        return csp.modify(
            headers = http_info.response_info.headers,
            clear   = csp_to_clear,
            add     = csp_to_add,
            extend  = csp_to_extend
        )

    def _modify_response_document(
            self,
            http_info: http_messages.FullHTTPInfo,
            encoding:  t.Optional[str]
    ) -> t.Union[str, bytes]:
        """
        Inject the popup script into an HTML document.

        :param encoding: codec to decode the body with; None means UTF-8.
        :returns: the original body bytes (same identity) when the popup is
            disabled, otherwise the modified document as a str.
        """
        popup_settings = self.current_popup_settings

        if not popup_settings.popup_enabled:
            return http_info.response_info.body

        nonce = response_nonce()

        popup_page = self.make_info_page(http_info)
        if popup_page is None:
            template = self._get_info_template(
                'special_page_info.html.jinja'
            )
            popup_page = template.render(
                url = http_info.request_info.url.orig_url
            )

        template = get_script_template('popup.js.jinja')
        popup_script = template.render(
            popup_page_b64 = b64encode(popup_page.encode()).decode(),
            nonce_b64      = b64encode(nonce.encode()).decode(),
            # TODO: add an option to configure popup style in the web UI.
            # Then start passing the real style value.
            #popup_style    = popup_settings.style.value
            popup_style    = 'D'
        )

        if encoding is None:
            encoding = 'utf-8'

        body_bytes = http_info.response_info.body
        try:
            body = body_bytes.decode(encoding, errors='replace')
        except LookupError:
            # The charset named by the server is not a text codec Python
            # knows; fall back to UTF-8 instead of failing the whole response.
            body = body_bytes.decode('utf-8', errors='replace')

        # A byte order mark decodes to U+FEFF, which '\s' in doctype_re does
        # not match; leaving it would hide the doctype and put our script in
        # front of it. _modify_response_body() prepends its own UTF-8 BOM.
        if body.startswith('\ufeff'):
            body = body[1:]

        match = doctype_re.match(body)
        doctype_decl_len = 0 if match is None else match.end()

        doctype_decl = body[:doctype_decl_len]
        doc_rest     = body[doctype_decl_len:]
        script_tag   = f'<script nonce="{nonce}">{popup_script}</script>'

        return doctype_decl + script_tag + doc_rest

    def _modify_response_body(self, http_info: http_messages.FullHTTPInfo) \
        -> bytes:
        """
        Return the new response body.

        Non-page responses and unmodified documents are returned as the
        original bytes object; modified documents are re-encoded as UTF-8
        with a BOM.
        """
        if not http_info.is_likely_a_page:
            return http_info.response_info.body

        data = http_info.response_info.body

        _, encoding = http_info.response_info.deduce_content_type()

        # A UTF BOM overrides encoding specified by the header.
        for bom, encoding_name in BOMs:
            if data.startswith(bom):
                encoding = encoding_name
                break

        new_data = self._modify_response_document(http_info, encoding)

        if isinstance(new_data, str):
            # Prepending a three-byte Byte Order Mark (BOM) will force the
            # browser to decode this as UTF-8 regardless of the 'Content-Type'
            # header. See
            # https://www.w3.org/International/tests/repository/html5/the-input-byte-stream/results-basics#precedence
            new_data = UTF8_BOM + new_data.encode()

        return new_data

    def consume_request(self, request_info: http_messages.RequestInfo) \
        -> t.Optional[MessageInfo]:
        """Process a request; policies that handle requests override this."""
        # We're not using @abstractmethod because not every Policy needs it and
        # we don't want to force child classes into implementing dummy methods.
        raise NotImplementedError(
            'This kind of policy does not consume requests.'
        )

    def consume_response(self, http_info: http_messages.FullHTTPInfo) \
        -> t.Optional[http_messages.ResponseInfo]:
        """
        Modify a response according to this policy.

        :returns: None when neither headers nor body changed, a modified
            ResponseInfo otherwise, or a 500 plain-text response carrying the
            traceback if modification raised.
        """
        import secrets

        # Fresh, unpredictable nonce for this response; response_nonce() reads
        # it back (on this thread) while headers and body are being modified.
        response_work_data.nonce = secrets.token_urlsafe(16)

        try:
            new_headers = self._modify_response_headers(http_info)
            new_body    = self._modify_response_body(http_info)
        except Exception as e:
            # In the future we might want to actually describe eventual errors.
            # For now, we're just printing the stack trace.
            import traceback

            error_info_list = traceback.format_exception(
                type(e),
                e,
                e.__traceback__
            )

            return http_messages.ResponseInfo.make(
                status_code = 500,
                headers     = (('Content-Type', 'text/plain; charset=utf-8'),),
                # format_exception() items already end with a newline.
                body        = ''.join(error_info_list).encode()
            )

        if (new_headers is http_info.response_info.headers and
            new_body is http_info.response_info.body):
            return None

        return dc.replace(
            http_info.response_info,
            headers = new_headers,
            body    = new_body
        )

    def make_info_page(self, http_info: http_messages.FullHTTPInfo) \
        -> t.Optional[str]:
        """
        HTML of the popup's info page; None makes _modify_response_document()
        use the generic 'special_page_info.html.jinja' page instead.
        """
        return None


@dc.dataclass(frozen=True, unsafe_hash=True) # type: ignore[misc]
class PolicyFactory(ABC):
    """
    Abstract producer of Policy objects.

    Factories order by the position of their class name in sorting_keys;
    class names missing from it rank 999 and therefore sort after the
    listed ones.
    """
    builtin: bool

    @abstractmethod
    def make_policy(self, haketilo_state: state.HaketiloState) \
        -> t.Optional[Policy]:
        """Build a policy from the given Haketilo state; may return None."""
        ...

    def __lt__(self, other: 'PolicyFactory'):
        """Compare factories by the rank of their classes' names."""
        unlisted_rank = 999
        own_rank   = sorting_keys.get(self.__class__.__name__,  unlisted_rank)
        other_rank = sorting_keys.get(other.__class__.__name__, unlisted_rank)
        return own_rank < other_rank

# Policy factory class names, from the one sorted first to the one sorted last
# (consumed by PolicyFactory.__lt__() through sorting_keys).
sorting_order = (
    'WebUIMainPolicyFactory',
    'WebUILandingPolicyFactory',

    'MitmItPagePolicyFactory',

    'PayloadResourcePolicyFactory',

    'PayloadPolicyFactory',

    'RuleBlockPolicyFactory',
    'RuleAllowPolicyFactory',

    'FallbackPolicyFactory'
)

# Class name -> its index in sorting_order.
sorting_keys = Map({
    class_name: position
    for position, class_name in enumerate(sorting_order)
})