# SPDX-License-Identifier: GPL-3.0-or-later

# Classes/protocols for representing HTTP requests and responses data.
#
# This file is part of Hydrilla&Haketilo.
#
# Copyright (C) 2022 Wojtek Kosior
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
#
# I, Wojtek Kosior, thereby promise not to sue for violation of this
# file's license. Although I request that you do not make use of this
# code in a proprietary program, I am not going to enforce this in
# court.

"""
.....
"""

import re
import cgi
import dataclasses as dc
import typing as t
import sys

if sys.version_info >= (3, 8):
    from typing import Protocol
else:
    from typing_extensions import Protocol

import mitmproxy.http

from .. import url_patterns


# Type variable for the `default` argument of the headers' `get()` methods.
# It is constrained to `str` or `None`, which lets the overloads express that
# `get()` returns either a `str` or the default value that was supplied.
DefaultGetValue = t.TypeVar('DefaultGetValue', str, None)

class _MitmproxyHeadersWrapper():
    """
    Adapter giving access to a `mitmproxy.http.Headers` object through the
    methods declared by the `IHeaders` protocol.

    The wrapped object is kept in the `headers` attribute and every method
    delegates to it.
    """
    def __init__(self, headers: mitmproxy.http.Headers) -> None:
        self.headers = headers

    def __getitem__(self, key: str) -> str:
        """Return the value stored under `key` in the wrapped headers."""
        return self.headers[key]

    def get_all(self, key: str) -> t.Iterable[str]:
        """Return all values stored under `key` in the wrapped headers."""
        return self.headers.get_all(key)

    @t.overload
    def get(self, key: str) -> t.Optional[str]:
        ...
    @t.overload
    def get(self, key: str, default: DefaultGetValue) \
        -> t.Union[str, DefaultGetValue]:
        ...
    def get(self, key, default = None):
        """
        Return the value stored under `key` or `default` when the wrapped
        headers report no value (i.e. their `get()` returns `None`).
        """
        value = self.headers.get(key)

        if value is None:
            return default
        else:
            return t.cast(str, value)

    # `t.Tuple` rather than `tuple[...]`: the annotation is evaluated when the
    # method is defined and built-in generics only exist since Python 3.9.
    def items(self) -> t.Iterable[t.Tuple[str, str]]:
        """Return (name, value) string pairs, one per header field."""
        return self.headers.items(multi=True)

    def items_bin(self) -> t.Iterable[t.Tuple[bytes, bytes]]:
        """Return the pairs yielded by `items()` encoded to bytes."""
        # mitmproxy decodes raw header bytes as UTF-8 with the
        # "surrogateescape" error handler. Use the same handler here so that
        # header bytes which are not valid UTF-8 round-trip instead of raising
        # UnicodeEncodeError.
        return tuple(
            (key.encode('utf-8', 'surrogateescape'),
             val.encode('utf-8', 'surrogateescape'))
            for key, val in self.items()
        )

class IHeaders(Protocol):
    """
    Structural interface for accessing a collection of HTTP headers.

    Only accessors are declared: lookup of a single value, lookup of all
    values stored under a name and iteration over (name, value) pairs, either
    as strings or as bytes.
    """
    def __getitem__(self, key: str) -> str:  ...

    def get_all(self, key: str) -> t.Iterable[str]: ...

    # `get()` returns `None` when called without a default and otherwise may
    # return the supplied default (a `str` or `None`).
    @t.overload
    def get(self, key: str) -> t.Optional[str]:
        ...
    @t.overload
    def get(self, key: str, default: DefaultGetValue) \
        -> t.Union[str, DefaultGetValue]:
        ...

    # `t.Tuple` is used instead of `tuple[...]` because these annotations are
    # evaluated at definition time and built-in generics need Python 3.9+.
    def items(self) -> t.Iterable[t.Tuple[str, str]]: ...

    def items_bin(self) -> t.Iterable[t.Tuple[bytes, bytes]]: ...

# Every representation of headers accepted by `make_headers()`: an iterable of
# (name, value) pairs given as bytes or as strings, a mitmproxy headers object
# or anything that already implements `IHeaders`.
#
# `t.Tuple` is used instead of `tuple[...]` because this expression is
# evaluated at import time and built-in generics only exist since Python 3.9.
_AnyHeaders = t.Union[
    t.Iterable[t.Tuple[bytes, bytes]],
    t.Iterable[t.Tuple[str, str]],
    mitmproxy.http.Headers,
    IHeaders
]

def make_headers(headers: _AnyHeaders) -> IHeaders:
    """
    Convert any supported representation of headers to an `IHeaders` object.

    A `mitmproxy.http.Headers` object gets wrapped. Any other iterable is
    treated as a series of (name, value) pairs from which a
    `mitmproxy.http.Headers` object is built and wrapped; whether the pairs
    hold strings or bytes is decided by looking at the first pair only.
    Anything else is assumed to implement `IHeaders` already and is returned
    unchanged.
    """
    if isinstance(headers, mitmproxy.http.Headers):
        return _MitmproxyHeadersWrapper(headers)

    if not isinstance(headers, t.Iterable):
        # Not iterable, so by elimination it is an `IHeaders` implementation.
        return headers

    # Materialize the pairs; the iterable might be a one-shot iterator.
    fields = tuple(headers)

    if not fields or isinstance(fields[0][0], str):
        # "surrogateescape" is the error handler mitmproxy uses when decoding
        # header bytes, so strings obtained from it can be encoded back even
        # when the original bytes were not valid UTF-8.
        fields = tuple(
            (key.encode('utf-8', 'surrogateescape'),
             val.encode('utf-8', 'surrogateescape'))
            for key, val in fields
        )

    return _MitmproxyHeadersWrapper(mitmproxy.http.Headers(fields))


# A URL given either as a string or in already parsed form.
_AnyUrl = t.Union[str, url_patterns.ParsedUrl]

def make_parsed_url(url: _AnyUrl) -> url_patterns.ParsedUrl:
    """
    Return `url` in parsed form.

    A string gets parsed with `url_patterns.parse_url()`; any other value is
    returned unchanged.
    """
    return url_patterns.parse_url(url) if isinstance(url, str) else url


@dc.dataclass(frozen=True)
class HasHeadersMixin:
    """Mixin with helpers for objects that carry HTTP headers."""
    headers: IHeaders

    # `t.Tuple` rather than `tuple[...]`: the annotation is evaluated when the
    # method is defined and built-in generics only exist since Python 3.9.
    def deduce_content_type(self) \
        -> t.Tuple[t.Optional[str], t.Optional[str]]:
        """
        Extract the MIME type and charset from the Content-Type header.

        Returns a `(mime, encoding)` pair. Both items are `None` when there is
        no Content-Type header; `encoding` alone is `None` when the header
        carries no `charset` parameter. The encoding name is lower-cased while
        the MIME type is returned the way `cgi.parse_header()` produced it.
        """
        content_type_header = self.headers.get('content-type')
        if content_type_header is None:
            return (None, None)

        # NOTE(review): the `cgi` module is deprecated since Python 3.11 and
        # removed in Python 3.13 (PEP 594); this call will need a replacement
        # before the module can be used there.
        mime, options = cgi.parse_header(content_type_header)

        encoding = options.get('charset')
        if encoding is not None:
            encoding = encoding.lower()

        return mime, encoding


@dc.dataclass(frozen=True)
class _BaseRequestInfoFields:
    """
    Fields common to `BodylessRequestInfo` and `RequestInfo`, which both
    inherit them from this class.
    """
    # URL of the request, in parsed form.
    url:         url_patterns.ParsedUrl
    # HTTP method name (`RequestInfo.make()` defaults it to 'GET').
    method:      str
    # Headers of the request.
    headers:     IHeaders

@dc.dataclass(frozen=True)
class BodylessRequestInfo(HasHeadersMixin, _BaseRequestInfoFields):
    """Immutable description of an HTTP request without its body."""
    def with_body(self, body: bytes) -> 'RequestInfo':
        """Return a `RequestInfo` with this object's fields and `body`."""
        return RequestInfo(self.url, self.method, self.headers, body)

    @staticmethod
    def make(
            url:     _AnyUrl,
            method:  str,
            headers: _AnyHeaders
    ) -> 'BodylessRequestInfo':
        """
        Build an instance from loosely typed arguments.

        `url` is normalized with `make_parsed_url()` and `headers` with
        `make_headers()`.
        """
        # Use a separate name instead of rebinding `url`, whose declared type
        # also admits plain strings.
        parsed_url = make_parsed_url(url)
        return BodylessRequestInfo(parsed_url, method, make_headers(headers))

@dc.dataclass(frozen=True)
class RequestInfo(HasHeadersMixin, _BaseRequestInfoFields):
    """Immutable description of an HTTP request together with its body."""
    body: bytes

    @staticmethod
    def make(
            url:     _AnyUrl     = url_patterns.dummy_url,
            method:  str         = 'GET',
            headers: _AnyHeaders = (),
            body:    bytes       = b''
    ) -> 'RequestInfo':
        """
        Build an instance from loosely typed arguments, all of them optional.

        The bodyless variant is constructed first (it takes care of
        normalizing `url` and `headers`) and `body` is attached to it.
        """
        bodyless = BodylessRequestInfo.make(url, method, headers)
        return bodyless.with_body(body)

# Request description with or without the body.
AnyRequestInfo = t.Union[BodylessRequestInfo, RequestInfo]


@dc.dataclass(frozen=True)
class _BaseResponseInfoFields:
    """
    Fields common to `BodylessResponseInfo` and `ResponseInfo`, which both
    inherit them from this class.
    """
    # URL the response is associated with, in parsed form.
    url:         url_patterns.ParsedUrl
    # HTTP status code (`ResponseInfo.make()` defaults it to 404).
    status_code: int
    # Headers of the response.
    headers:     IHeaders

@dc.dataclass(frozen=True)
class BodylessResponseInfo(HasHeadersMixin, _BaseResponseInfoFields):
    """Immutable description of an HTTP response without its body."""
    def with_body(self, body: bytes) -> 'ResponseInfo':
        """Return a `ResponseInfo` with this object's fields and `body`."""
        return ResponseInfo(self.url, self.status_code, self.headers, body)

    @staticmethod
    def make(
            url:         _AnyUrl,
            status_code: int,
            headers:     _AnyHeaders
    ) -> 'BodylessResponseInfo':
        """
        Build an instance from loosely typed arguments.

        `url` is normalized with `make_parsed_url()` and `headers` with
        `make_headers()`.
        """
        # Use a separate name instead of rebinding `url`, whose declared type
        # also admits plain strings.
        parsed_url = make_parsed_url(url)
        return BodylessResponseInfo(
            parsed_url,
            status_code,
            make_headers(headers)
        )

@dc.dataclass(frozen=True)
class ResponseInfo(HasHeadersMixin, _BaseResponseInfoFields):
    """Immutable description of an HTTP response together with its body."""
    body: bytes

    @staticmethod
    def make(
            url:         _AnyUrl     = url_patterns.dummy_url,
            status_code: int         = 404,
            headers:     _AnyHeaders = (),
            body:        bytes       = b''
    ) -> 'ResponseInfo':
        """
        Build an instance from loosely typed arguments, all of them optional.

        The bodyless variant is constructed first (it takes care of
        normalizing `url` and `headers`) and `body` is attached to it.
        """
        return BodylessResponseInfo.make(
            url,
            status_code,
            headers
        ).with_body(body)

# Response description with or without the body.
AnyResponseInfo = t.Union[BodylessResponseInfo, ResponseInfo]


def is_likely_a_page(
        request_info:  AnyRequestInfo,
        response_info: AnyResponseInfo
) -> bool:
    """
    Guess whether a request/response pair represents loading of a page.

    The request must be destined for a document-like context according to its
    Sec-Fetch-Dest header or, when that header is absent, must mention "html"
    in its Accept header. Additionally, the MIME type from the response's
    Content-Type header must contain "html".
    """
    fetch_dest = request_info.headers.get('sec-fetch-dest')
    if fetch_dest is None:
        # No Sec-Fetch-Dest header; fall back to looking at what the client
        # declared to accept. Media types are case-insensitive, hence the
        # lower-casing.
        if 'html' in request_info.headers.get('accept', '').lower():
            fetch_dest = 'document'
        else:
            fetch_dest = 'unknown'

    if fetch_dest not in ('document', 'iframe', 'frame', 'embed', 'object'):
        return False

    # The charset is irrelevant to this decision.
    mime, _ = response_info.deduce_content_type()

    # Right now out of all response headers we're only taking Content-Type into
    # account. In the future we might also want to consider the
    # Content-Disposition header.
    #
    # Media types are case-insensitive, so e.g. "Text/HTML" has to match, too.
    return mime is not None and 'html' in mime.lower()


@dc.dataclass(frozen=True)
class FullHTTPInfo:
    """Immutable pairing of a complete request with its complete response."""
    request_info:  RequestInfo
    response_info: ResponseInfo

    @property
    def is_likely_a_page(self) -> bool:
        """Result of the module-level `is_likely_a_page()` for this pair."""
        # Inside the method body this name refers to the module-level
        # function, not to this property.
        request, response = self.request_info, self.response_info
        return is_likely_a_page(request_info=request, response_info=response)