diff options
Diffstat (limited to 'src/hydrilla/proxy/http_messages.py')
-rw-r--r-- | src/hydrilla/proxy/http_messages.py | 202 |
1 files changed, 166 insertions, 36 deletions
diff --git a/src/hydrilla/proxy/http_messages.py b/src/hydrilla/proxy/http_messages.py index dbf2c63..1bed103 100644 --- a/src/hydrilla/proxy/http_messages.py +++ b/src/hydrilla/proxy/http_messages.py @@ -29,6 +29,7 @@ ..... """ +import re import dataclasses as dc import typing as t import sys @@ -38,13 +39,42 @@ if sys.version_info >= (3, 8): else: from typing_extensions import Protocol +import mitmproxy.http + from .. import url_patterns DefaultGetValue = t.TypeVar('DefaultGetValue', str, None) +class _MitmproxyHeadersWrapper(): + def __init__(self, headers: mitmproxy.http.Headers) -> None: + self.headers = headers + + __getitem__ = lambda self, key: self.headers[key] + get_all = lambda self, key: self.headers.get_all(key) + + @t.overload + def get(self, key: str) -> t.Optional[str]: + ... + @t.overload + def get(self, key: str, default: DefaultGetValue) \ + -> t.Union[str, DefaultGetValue]: + ... + def get(self, key, default = None): + value = self.headers.get(key) + + if value is None: + return default + else: + return t.cast(str, value) + + def items(self) -> t.Iterable[tuple[str, str]]: + return self.headers.items(multi=True) + + def items_bin(self) -> t.Iterable[tuple[bytes, bytes]]: + return tuple((key.encode(), val.encode()) for key, val in self.items()) + class IHeaders(Protocol): - """....""" def __getitem__(self, key: str) -> str: ... def get_all(self, key: str) -> t.Iterable[str]: ... @@ -59,65 +89,165 @@ class IHeaders(Protocol): def items(self) -> t.Iterable[tuple[str, str]]: ... -def encode_headers_items(headers: t.Iterable[tuple[str, str]]) \ - -> t.Iterable[tuple[bytes, bytes]]: - """....""" - for name, value in headers: - yield name.encode(), value.encode() + def items_bin(self) -> t.Iterable[tuple[bytes, bytes]]: ... 
# Every representation of a header collection accepted by make_headers().
_AnyHeaders = t.Union[
    t.Iterable[tuple[bytes, bytes]],
    t.Iterable[tuple[str, str]],
    mitmproxy.http.Headers,
    IHeaders
]

def make_headers(headers: _AnyHeaders) -> IHeaders:
    """
    Convert any supported header representation to an IHeaders object.

    A mitmproxy Headers object gets wrapped in _MitmproxyHeadersWrapper. An
    iterable of (name, value) pairs - either both `str` or both `bytes` - is
    first turned into a mitmproxy Headers object and then wrapped. Anything
    else is taken to already implement IHeaders and is returned unchanged.
    """
    if isinstance(headers, mitmproxy.http.Headers):
        return _MitmproxyHeadersWrapper(headers)

    if not isinstance(headers, t.Iterable):
        # Neither a mitmproxy object nor an iterable of pairs - by elimination
        # this has to be an IHeaders implementation already.
        return headers

    # Materialize the iterable so that its first item can be inspected without
    # consuming a one-shot iterator.
    pairs = tuple(headers)

    # Only the first pair is looked at to tell `str` pairs from `bytes` pairs.
    if not pairs or isinstance(pairs[0][0], str):
        pairs = tuple((key.encode(), val.encode()) for key, val in pairs)

    return _MitmproxyHeadersWrapper(mitmproxy.http.Headers(pairs))


# A URL given either as text or in already-parsed form.
_AnyUrl = t.Union[str, url_patterns.ParsedUrl]

def make_parsed_url(url: _AnyUrl) -> url_patterns.ParsedUrl:
    """
    Return `url` as a ParsedUrl, passing a `str` through
    url_patterns.parse_url() and returning anything else as it is.
    """
    if isinstance(url, str):
        return url_patterns.parse_url(url)

    return url


# For details of 'Content-Type' header's structure, see:
# https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.1
#
# The "charset" parameter is optional. When it is absent the 'encoding' group
# does not participate in the match (i.e. it is None). A header with a
# malformed "charset" parameter is not matched at all.
content_type_reg = re.compile(r'''
^
(?P<mime>[\w.+-]+/[\w.+-]+) # also accept names like "application/xhtml+xml"
\s*
(?:
    # Variant 1: parameter list that includes "charset".
    ;
    (?:[^;]*;)*     # match possible parameters preceding "charset"
    \s*
    charset=        # no whitespace allowed in parameter as per RFC
    (?P<encoding>
        [\w-]+
        |
        "[\w-]+"    # quotes are optional per RFC
    )
    (?:;[^;]+)*     # match possible parameters following "charset"
  |
    # Variant 2: no "charset" at all (and possibly no parameters at all). The
    # lookahead keeps a malformed "charset" parameter from being accepted here
    # as if it were some other parameter.
    (?:;(?!\s*charset=)[^;]*)*
)
$                   # forbid possible dangling characters after closing '"'
''', re.VERBOSE | re.IGNORECASE)

@dc.dataclass(frozen=True)
class HasHeadersMixin:
    """Mixin for dataclasses that carry HTTP headers in a `headers` field."""
    headers: IHeaders

    def deduce_content_type(self) -> tuple[t.Optional[str], t.Optional[str]]:
        """
        Extract the MIME type and charset from the 'Content-Type' header.

        Returns a `(mime, encoding)` tuple. Both items are None when the header
        is missing or cannot be parsed. `encoding` alone is None when the
        header carries no "charset" parameter. A charset is returned lowercased
        and without the optional surrounding quotes; `mime` is returned exactly
        as it was spelled in the header.
        """
        content_type = self.headers.get('content-type')
        if content_type is None:
            return (None, None)

        match = content_type_reg.match(content_type)
        if match is None:
            return (None, None)

        mime, encoding = match.group('mime'), match.group('encoding')

        if encoding is not None:
            # The regex captures the optional quotes together with the charset
            # name; they are not part of the name itself.
            encoding = encoding.strip('"').lower()

        return mime, encoding


@dc.dataclass(frozen=True)
class _BaseRequestInfoFields:
    """Fields shared by request descriptions with and without a body."""
    url:     url_patterns.ParsedUrl
    method:  str
    headers: IHeaders

@dc.dataclass(frozen=True)
class BodylessRequestInfo(HasHeadersMixin, _BaseRequestInfoFields):
    """Description of an HTTP request whose body is not (yet) known."""
    def with_body(self, body: bytes) -> 'RequestInfo':
        """Return a RequestInfo with this object's fields plus `body`."""
        return RequestInfo(self.url, self.method, self.headers, body)

    @staticmethod
    def make(
            url:     _AnyUrl,
            method:  str,
            headers: _AnyHeaders
    ) -> 'BodylessRequestInfo':
        """
        Build a BodylessRequestInfo, normalizing `url` with make_parsed_url()
        and `headers` with make_headers().
        """
        parsed_url = make_parsed_url(url)
        return BodylessRequestInfo(parsed_url, method, make_headers(headers))

@dc.dataclass(frozen=True)
class RequestInfo(HasHeadersMixin, _BaseRequestInfoFields):
    """Description of a complete HTTP request, body included."""
    body: bytes

    @staticmethod
    def make(
            url:     _AnyUrl     = url_patterns.dummy_url,
            method:  str         = 'GET',
            headers: _AnyHeaders = (),
            body:    bytes       = b''
    ) -> 'RequestInfo':
        """
        Build a RequestInfo; every argument is optional. `url` and `headers`
        get normalized the same way as in BodylessRequestInfo.make().
        """
        return BodylessRequestInfo.make(url, method, headers).with_body(body)


@dc.dataclass(frozen=True)
class _BaseResponseInfoFields:
    """Fields shared by response descriptions with and without a body."""
    url:         url_patterns.ParsedUrl
    status_code: int
    headers:     IHeaders

@dc.dataclass(frozen=True)
class BodylessResponseInfo(HasHeadersMixin, _BaseResponseInfoFields):
    """Description of an HTTP response whose body is not (yet) known."""
    def with_body(self, body: bytes) -> 'ResponseInfo':
        """Return a ResponseInfo with this object's fields plus `body`."""
        return ResponseInfo(self.url, self.status_code, self.headers, body)

    @staticmethod
    def make(
            url:         _AnyUrl,
            status_code: int,
            headers:     _AnyHeaders
    ) -> 'BodylessResponseInfo':
        """
        Build a BodylessResponseInfo, normalizing `url` with make_parsed_url()
        and `headers` with make_headers().
        """
        parsed_url = make_parsed_url(url)
        return BodylessResponseInfo(
            parsed_url,
            status_code,
            make_headers(headers)
        )

@dc.dataclass(frozen=True)
class ResponseInfo(HasHeadersMixin, _BaseResponseInfoFields):
    """Description of a complete HTTP response, body included."""
    body: bytes

    @staticmethod
    def make(
            url:         _AnyUrl     = url_patterns.dummy_url,
            status_code: int         = 404,
            headers:     _AnyHeaders = (),
            body:        bytes       = b''
    ) -> 'ResponseInfo':
        """
        Build a ResponseInfo; every argument is optional. `url` and `headers`
        get normalized the same way as in BodylessResponseInfo.make().
        """
        bl_info = BodylessResponseInfo.make(url, status_code, headers)
        return bl_info.with_body(body)


def is_likely_a_page(
        request_info:  t.Union[BodylessRequestInfo, RequestInfo],
        response_info: t.Union[BodylessResponseInfo, ResponseInfo]
) -> bool:
    """
    Heuristically tell whether a request/response pair is a page load.

    The request has to target a document-like destination. This is read from
    the 'Sec-Fetch-Dest' header or, when that header is absent, guessed from
    "html" being mentioned in the 'Accept' header. Additionally, the
    response's 'Content-Type' has to name a MIME type containing "html".
    """
    fetch_dest = request_info.headers.get('sec-fetch-dest')
    if fetch_dest is None:
        # Media types are case-insensitive, hence the lowercasing.
        accept = request_info.headers.get('accept', '').lower()
        fetch_dest = 'document' if 'html' in accept else 'unknown'

    if fetch_dest not in ('document', 'iframe', 'frame', 'embed', 'object'):
        return False

    mime, _ = response_info.deduce_content_type()

    # Right now out of all response headers we're only taking Content-Type into
    # account. In the future we might also want to consider the
    # Content-Disposition header.
    return mime is not None and 'html' in mime.lower()