aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/http_messages.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/hydrilla/proxy/http_messages.py')
-rw-r--r--src/hydrilla/proxy/http_messages.py202
1 files changed, 166 insertions, 36 deletions
diff --git a/src/hydrilla/proxy/http_messages.py b/src/hydrilla/proxy/http_messages.py
index dbf2c63..1bed103 100644
--- a/src/hydrilla/proxy/http_messages.py
+++ b/src/hydrilla/proxy/http_messages.py
@@ -29,6 +29,7 @@
.....
"""
+import re
import dataclasses as dc
import typing as t
import sys
@@ -38,13 +39,42 @@ if sys.version_info >= (3, 8):
else:
from typing_extensions import Protocol
+import mitmproxy.http
+
from .. import url_patterns
DefaultGetValue = t.TypeVar('DefaultGetValue', str, None)
+class _MitmproxyHeadersWrapper():
+ def __init__(self, headers: mitmproxy.http.Headers) -> None:
+ self.headers = headers
+
+ __getitem__ = lambda self, key: self.headers[key]
+ get_all = lambda self, key: self.headers.get_all(key)
+
+ @t.overload
+ def get(self, key: str) -> t.Optional[str]:
+ ...
+ @t.overload
+ def get(self, key: str, default: DefaultGetValue) \
+ -> t.Union[str, DefaultGetValue]:
+ ...
+ def get(self, key, default = None):
+ value = self.headers.get(key)
+
+ if value is None:
+ return default
+ else:
+ return t.cast(str, value)
+
+ def items(self) -> t.Iterable[tuple[str, str]]:
+ return self.headers.items(multi=True)
+
+ def items_bin(self) -> t.Iterable[tuple[bytes, bytes]]:
+ return tuple((key.encode(), val.encode()) for key, val in self.items())
+
class IHeaders(Protocol):
- """...."""
def __getitem__(self, key: str) -> str: ...
def get_all(self, key: str) -> t.Iterable[str]: ...
@@ -59,65 +89,165 @@ class IHeaders(Protocol):
def items(self) -> t.Iterable[tuple[str, str]]: ...
-def encode_headers_items(headers: t.Iterable[tuple[str, str]]) \
- -> t.Iterable[tuple[bytes, bytes]]:
- """...."""
- for name, value in headers:
- yield name.encode(), value.encode()
+ def items_bin(self) -> t.Iterable[tuple[bytes, bytes]]: ...
+
+_AnyHeaders = t.Union[
+ t.Iterable[tuple[bytes, bytes]],
+ t.Iterable[tuple[str, str]],
+ mitmproxy.http.Headers,
+ IHeaders
+]
+
+def make_headers(headers: _AnyHeaders) -> IHeaders:
+ if not isinstance(headers, mitmproxy.http.Headers):
+ if isinstance(headers, t.Iterable):
+ headers = tuple(headers)
+ if not headers or isinstance(headers[0][0], str):
+ headers = ((key.encode(), val.encode()) for key, val in headers)
+
+ headers = mitmproxy.http.Headers(headers)
+ else:
+ # isinstance(headers, IHeaders)
+ return headers
+
+ return _MitmproxyHeadersWrapper(headers)
+
+
+_AnyUrl = t.Union[str, url_patterns.ParsedUrl]
+
+def make_parsed_url(url: t.Union[str, url_patterns.ParsedUrl]) \
+ -> url_patterns.ParsedUrl:
+ return url_patterns.parse_url(url) if isinstance(url, str) else url
+
+
+# For details of 'Content-Type' header's structure, see:
+# https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.1
+content_type_reg = re.compile(r'''
+^
+(?P<mime>[\w-]+/[\w-]+)
+\s*
+(?:
+ ;
+ (?:[^;]*;)* # match possible parameter other than "charset"
+)
+\s*
+charset= # no whitespace allowed in parameter as per RFC
+(?P<encoding>
+ [\w-]+
+ |
+ "[\w-]+" # quotes are optional per RFC
+)
+(?:;[^;]+)* # match possible parameter other than "charset"
+$ # forbid possible dangling characters after closing '"'
+''', re.VERBOSE | re.IGNORECASE)
@dc.dataclass(frozen=True)
-class ProducedRequest:
- """...."""
- url: str
- method: str
- headers: t.Iterable[tuple[bytes, bytes]]
- body: bytes
+class HasHeadersMixin:
+ headers: IHeaders
+
+ def deduce_content_type(self) -> tuple[t.Optional[str], t.Optional[str]]:
+ content_type = self.headers.get('content-type')
+ if content_type is None:
+ return (None, None)
+
+ match = content_type_reg.match(content_type)
+ if match is None:
+ return (None, None)
+
+ mime, encoding = match.group('mime'), match.group('encoding')
+
+ if encoding is not None:
+ encoding = encoding.lower()
+
+ return mime, encoding
+
@dc.dataclass(frozen=True)
-class BodylessRequestInfo:
+class _BaseRequestInfoFields:
url: url_patterns.ParsedUrl
method: str
headers: IHeaders
+@dc.dataclass(frozen=True)
+class BodylessRequestInfo(HasHeadersMixin, _BaseRequestInfoFields):
def with_body(self, body: bytes) -> 'RequestInfo':
return RequestInfo(self.url, self.method, self.headers, body)
+ @staticmethod
+ def make(
+ url: t.Union[str, url_patterns.ParsedUrl],
+ method: str,
+ headers: _AnyHeaders
+ ) -> 'BodylessRequestInfo':
+ url = make_parsed_url(url)
+ return BodylessRequestInfo(url, method, make_headers(headers))
+
@dc.dataclass(frozen=True)
-class RequestInfo(BodylessRequestInfo):
+class RequestInfo(HasHeadersMixin, _BaseRequestInfoFields):
body: bytes
- def make_produced_request(self) -> ProducedRequest:
- return ProducedRequest(
- url = self.url.orig_url,
- method = self.method,
- headers = encode_headers_items(self.headers.items()),
- body = self.body
- )
+ @staticmethod
+ def make(
+ url: _AnyUrl = url_patterns.dummy_url,
+ method: str = 'GET',
+ headers: _AnyHeaders = (),
+ body: bytes = b''
+ ) -> 'RequestInfo':
+ return BodylessRequestInfo.make(url, method, headers).with_body(body)
-@dc.dataclass(frozen=True)
-class ProducedResponse:
- """...."""
- status_code: int
- headers: t.Iterable[tuple[bytes, bytes]]
- body: bytes
@dc.dataclass(frozen=True)
-class BodylessResponseInfo:
- """...."""
+class _BaseResponseInfoFields:
url: url_patterns.ParsedUrl
status_code: int
headers: IHeaders
+@dc.dataclass(frozen=True)
+class BodylessResponseInfo(HasHeadersMixin, _BaseResponseInfoFields):
def with_body(self, body: bytes) -> 'ResponseInfo':
return ResponseInfo(self.url, self.status_code, self.headers, body)
+ @staticmethod
+ def make(
+ url: t.Union[str, url_patterns.ParsedUrl],
+ status_code: int,
+ headers: _AnyHeaders
+ ) -> 'BodylessResponseInfo':
+ url = make_parsed_url(url)
+ return BodylessResponseInfo(url, status_code, make_headers(headers))
+
@dc.dataclass(frozen=True)
-class ResponseInfo(BodylessResponseInfo):
+class ResponseInfo(HasHeadersMixin, _BaseResponseInfoFields):
body: bytes
- def make_produced_response(self) -> ProducedResponse:
- return ProducedResponse(
- status_code = self.status_code,
- headers = encode_headers_items(self.headers.items()),
- body = self.body
- )
+ @staticmethod
+ def make(
+ url: _AnyUrl = url_patterns.dummy_url,
+ status_code: int = 404,
+ headers: _AnyHeaders = (),
+ body: bytes = b''
+ ) -> 'ResponseInfo':
+ bl_info = BodylessResponseInfo.make(url, status_code, headers)
+ return bl_info.with_body(body)
+
+
+def is_likely_a_page(
+ request_info: t.Union[BodylessRequestInfo, RequestInfo],
+ response_info: t.Union[BodylessResponseInfo, ResponseInfo]
+) -> bool:
+ fetch_dest = request_info.headers.get('sec-fetch-dest')
+ if fetch_dest is None:
+ if 'html' in request_info.headers.get('accept', ''):
+ fetch_dest = 'document'
+ else:
+ fetch_dest = 'unknown'
+
+ if fetch_dest not in ('document', 'iframe', 'frame', 'embed', 'object'):
+ return False
+
+ mime, encoding = response_info.deduce_content_type()
+
+ # Right now out of all response headers we're only taking Content-Type into
+ # account. In the future we might also want to consider the
+ # Content-Disposition header.
+ return mime is not None and 'html' in mime