From 55b95d70b24cfa1b4703ae442a3c6d1781cc95aa Mon Sep 17 00:00:00 2001 From: Wojtek Kosior Date: Thu, 20 Oct 2022 21:57:47 +0200 Subject: [proxy] rework internal HTTP headers representation --- src/hydrilla/proxy/addon.py | 93 ++++------- src/hydrilla/proxy/http_messages.py | 202 +++++++++++++++++++----- src/hydrilla/proxy/policies/__init__.py | 2 +- src/hydrilla/proxy/policies/base.py | 15 +- src/hydrilla/proxy/policies/misc.py | 15 +- src/hydrilla/proxy/policies/payload.py | 94 +++-------- src/hydrilla/proxy/policies/payload_resource.py | 61 ++++--- src/hydrilla/proxy/policies/rule.py | 23 +-- src/hydrilla/proxy/policies/web_ui.py | 2 +- src/hydrilla/proxy/state_impl/concrete_state.py | 9 ++ src/hydrilla/proxy/web_ui/root.py | 4 +- src/hydrilla/url_patterns.py | 3 + 12 files changed, 305 insertions(+), 218 deletions(-) (limited to 'src') diff --git a/src/hydrilla/proxy/addon.py b/src/hydrilla/proxy/addon.py index ae03ecc..de864fc 100644 --- a/src/hydrilla/proxy/addon.py +++ b/src/hydrilla/proxy/addon.py @@ -46,49 +46,25 @@ from mitmproxy.script import concurrent from ..exceptions import HaketiloException from ..translations import smart_gettext as _ -from ..url_patterns import parse_url, ParsedUrl +from .. import url_patterns from .state_impl import ConcreteHaketiloState from . import state from . import policies from . import http_messages -DefaultGetValue = t.TypeVar('DefaultGetValue', str, None) - -class MitmproxyHeadersWrapper(): - """....""" - def __init__(self, headers: http.Headers) -> None: - """....""" - self.headers = headers - - __getitem__ = lambda self, key: self.headers[key] - get_all = lambda self, key: self.headers.get_all(key) - - @t.overload - def get(self, key: str) -> t.Optional[str]: - ... - @t.overload - def get(self, key: str, default: DefaultGetValue) \ - -> t.Union[str, DefaultGetValue]: - ... - def get(self, key, default = None): - value = self.headers.get(key) - - if value is None: - return default - else: - return t.cast(str, value) - - def items(self) -> t.Iterable[tuple[str, str]]: - """....""" - return self.headers.items(multi=True) - - class LoggerToMitmproxy(state.Logger): def warn(self, msg: str) -> None: ctx.log.warn(f'Haketilo: {msg}') +def safe_parse_url(url: str) -> url_patterns.ParsedUrl: + try: + return url_patterns.parse_url(url) + except url_patterns.HaketiloURLException: + return url_patterns.dummy_url + + @dc.dataclass class FlowHandling: flow: http.HTTPFlow @@ -114,11 +90,10 @@ class FlowHandling: if self._bl_response_info is None: assert self.flow.response is not None - headers = self.flow.response.headers - self._bl_response_info = http_messages.BodylessResponseInfo( - url = parse_url(self.flow.request.url), + self._bl_response_info = http_messages.BodylessResponseInfo.make( + url = safe_parse_url(self.flow.request.url), status_code = self.flow.response.status_code, - headers = MitmproxyHeadersWrapper(headers) + headers = self.flow.response.headers ) return self._bl_response_info @@ -131,12 +106,15 @@ class FlowHandling: return self.bl_response_info.with_body(body) @staticmethod - def make(flow: http.HTTPFlow, policy: policies.Policy, url: ParsedUrl) \ - -> 'FlowHandling': - bl_request_info = http_messages.BodylessRequestInfo( + def make( + flow: http.HTTPFlow, + policy: policies.Policy, + url: url_patterns.ParsedUrl + ) -> 'FlowHandling': + bl_request_info = http_messages.BodylessRequestInfo.make( url = url, method = flow.request.method, - headers = MitmproxyHeadersWrapper(flow.request.headers) + headers = flow.request.headers ) return FlowHandling(flow, policy, bl_request_info) @@ -157,10 +135,6 @@ class PassedOptions: self.haketilo_launch_browser is not None) -magical_mitm_it_url_reg = re.compile(r'^http://mitm.it(/.*)?$') -dummy_url = parse_url('http://dummy.replacement.url') - - @dc.dataclass class HaketiloAddon: initial_options: PassedOptions = PassedOptions() @@ -257,16 +231,13 @@ class HaketiloAddon: handling = self.handling_dict.get(id(flow)) if handling is None: - parsed_url = dummy_url - - if magical_mitm_it_url_reg.match(flow.request.url): - policy = policies.DoNothingPolicy() + try: + parsed_url = url_patterns.parse_url(flow.request.url) + except url_patterns.HaketiloURLException as e: + policy = policies.ErrorBlockPolicy(builtin=True, error=e) + parsed_url = url_patterns.dummy_url else: - try: - parsed_url = parse_url(flow.request.url) - policy = self.state.select_policy(parsed_url) - except HaketiloException as e: - policy = policies.ErrorBlockPolicy(builtin=True, error=e) + policy = self.state.select_policy(parsed_url) handling = FlowHandling.make(flow, policy, parsed_url) @@ -330,16 +301,18 @@ class HaketiloAddon: result = handling.policy.consume_request(handling.request_info) if result is not None: - if isinstance(result, http_messages.ProducedRequest): - flow.request.url = result.url + mitmproxy_headers = http.Headers(result.headers.items_bin()) + + if isinstance(result, http_messages.RequestInfo): + flow.request.url = result.url.orig_url flow.request.method = result.method - flow.request.headers = http.Headers(result.headers) + flow.request.headers = mitmproxy_headers flow.request.set_content(result.body or None) else: - # isinstance(result, http_messages.ProducedResponse) + # isinstance(result, http_messages.ResponseInfo) flow.response = http.Response.make( status_code = result.status_code, - headers = http.Headers(result.headers), + headers = mitmproxy_headers, content = result.body ) @@ -370,8 +343,10 @@ class HaketiloAddon: response_info = handling.response_info ) if result is not None: + headers_bin = result.headers.items_bin() + flow.response.status_code = result.status_code - flow.response.headers = http.Headers(result.headers) + flow.response.headers = http.Headers(headers_bin) flow.response.set_content(result.body) self.forget_flow_handling(flow) diff --git a/src/hydrilla/proxy/http_messages.py b/src/hydrilla/proxy/http_messages.py index dbf2c63..1bed103 100644 --- a/src/hydrilla/proxy/http_messages.py +++ b/src/hydrilla/proxy/http_messages.py @@ -29,6 +29,7 @@ ..... """ +import re import dataclasses as dc import typing as t import sys @@ -38,13 +39,42 @@ if sys.version_info >= (3, 8): else: from typing_extensions import Protocol +import mitmproxy.http + from .. import url_patterns DefaultGetValue = t.TypeVar('DefaultGetValue', str, None) +class _MitmproxyHeadersWrapper(): + def __init__(self, headers: mitmproxy.http.Headers) -> None: + self.headers = headers + + __getitem__ = lambda self, key: self.headers[key] + get_all = lambda self, key: self.headers.get_all(key) + + @t.overload + def get(self, key: str) -> t.Optional[str]: + ... + @t.overload + def get(self, key: str, default: DefaultGetValue) \ + -> t.Union[str, DefaultGetValue]: + ... + def get(self, key, default = None): + value = self.headers.get(key) + + if value is None: + return default + else: + return t.cast(str, value) + + def items(self) -> t.Iterable[tuple[str, str]]: + return self.headers.items(multi=True) + + def items_bin(self) -> t.Iterable[tuple[bytes, bytes]]: + return tuple((key.encode(), val.encode()) for key, val in self.items()) + class IHeaders(Protocol): - """....""" def __getitem__(self, key: str) -> str: ... def get_all(self, key: str) -> t.Iterable[str]: ... @@ -59,65 +89,165 @@ class IHeaders(Protocol): def items(self) -> t.Iterable[tuple[str, str]]: ... -def encode_headers_items(headers: t.Iterable[tuple[str, str]]) \ - -> t.Iterable[tuple[bytes, bytes]]: - """....""" - for name, value in headers: - yield name.encode(), value.encode() + def items_bin(self) -> t.Iterable[tuple[bytes, bytes]]: ... + +_AnyHeaders = t.Union[ + t.Iterable[tuple[bytes, bytes]], + t.Iterable[tuple[str, str]], + mitmproxy.http.Headers, + IHeaders +] + +def make_headers(headers: _AnyHeaders) -> IHeaders: + if not isinstance(headers, mitmproxy.http.Headers): + if isinstance(headers, t.Iterable): + headers = tuple(headers) + if not headers or isinstance(headers[0][0], str): + headers = ((key.encode(), val.encode()) for key, val in headers) + + headers = mitmproxy.http.Headers(headers) + else: + # isinstance(headers, IHeaders) + return headers + + return _MitmproxyHeadersWrapper(headers) + + +_AnyUrl = t.Union[str, url_patterns.ParsedUrl] + +def make_parsed_url(url: t.Union[str, url_patterns.ParsedUrl]) \ + -> url_patterns.ParsedUrl: + return url_patterns.parse_url(url) if isinstance(url, str) else url + + +# For details of 'Content-Type' header's structure, see: +# https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.1 +content_type_reg = re.compile(r''' +^ +(?P[\w-]+/[\w-]+) +\s* +(?: + ; + (?:[^;]*;)* # match possible parameter other than "charset" +) +\s* +charset= # no whitespace allowed in parameter as per RFC +(?P + [\w-]+ + | + "[\w-]+" # quotes are optional per RFC +) +(?:;[^;]+)* # match possible parameter other than "charset" +$ # forbid possible dangling characters after closing '"' +''', re.VERBOSE | re.IGNORECASE) @dc.dataclass(frozen=True) -class ProducedRequest: - """....""" - url: str - method: str - headers: t.Iterable[tuple[bytes, bytes]] - body: bytes +class HasHeadersMixin: + headers: IHeaders + + def deduce_content_type(self) -> tuple[t.Optional[str], t.Optional[str]]: + content_type = self.headers.get('content-type') + if content_type is None: + return (None, None) + + match = content_type_reg.match(content_type) + if match is None: + return (None, None) + + mime, encoding = match.group('mime'), match.group('encoding') + + if encoding is not None: + encoding = encoding.lower() + + return mime, encoding + @dc.dataclass(frozen=True) -class BodylessRequestInfo: +class _BaseRequestInfoFields: url: url_patterns.ParsedUrl method: str headers: IHeaders +@dc.dataclass(frozen=True) +class BodylessRequestInfo(HasHeadersMixin, _BaseRequestInfoFields): def with_body(self, body: bytes) -> 'RequestInfo': return RequestInfo(self.url, self.method, self.headers, body) + @staticmethod + def make( + url: t.Union[str, url_patterns.ParsedUrl], + method: str, + headers: _AnyHeaders + ) -> 'BodylessRequestInfo': + url = make_parsed_url(url) + return BodylessRequestInfo(url, method, make_headers(headers)) + @dc.dataclass(frozen=True) -class RequestInfo(BodylessRequestInfo): +class RequestInfo(HasHeadersMixin, _BaseRequestInfoFields): body: bytes - def make_produced_request(self) -> ProducedRequest: - return ProducedRequest( - url = self.url.orig_url, - method = self.method, - headers = encode_headers_items(self.headers.items()), - body = self.body - ) + @staticmethod + def make( + url: _AnyUrl = url_patterns.dummy_url, + method: str = 'GET', + headers: _AnyHeaders = (), + body: bytes = b'' + ) -> 'RequestInfo': + return BodylessRequestInfo.make(url, method, headers).with_body(body) -@dc.dataclass(frozen=True) -class ProducedResponse: - """....""" - status_code: int - headers: t.Iterable[tuple[bytes, bytes]] - body: bytes @dc.dataclass(frozen=True) -class BodylessResponseInfo: - """....""" +class _BaseResponseInfoFields: url: url_patterns.ParsedUrl status_code: int headers: IHeaders +@dc.dataclass(frozen=True) +class BodylessResponseInfo(HasHeadersMixin, _BaseResponseInfoFields): def with_body(self, body: bytes) -> 'ResponseInfo': return ResponseInfo(self.url, self.status_code, self.headers, body) + @staticmethod + def make( + url: t.Union[str, url_patterns.ParsedUrl], + status_code: int, + headers: _AnyHeaders + ) -> 'BodylessResponseInfo': + url = make_parsed_url(url) + return BodylessResponseInfo(url, status_code, make_headers(headers)) + @dc.dataclass(frozen=True) -class ResponseInfo(BodylessResponseInfo): +class ResponseInfo(HasHeadersMixin, _BaseResponseInfoFields): body: bytes - def make_produced_response(self) -> ProducedResponse: - return ProducedResponse( - status_code = self.status_code, - headers = encode_headers_items(self.headers.items()), - body = self.body - ) + @staticmethod + def make( + url: _AnyUrl = url_patterns.dummy_url, + status_code: int = 404, + headers: _AnyHeaders = (), + body: bytes = b'' + ) -> 'ResponseInfo': + bl_info = BodylessResponseInfo.make(url, status_code, headers) + return bl_info.with_body(body) + + +def is_likely_a_page( + request_info: t.Union[BodylessRequestInfo, RequestInfo], + response_info: t.Union[BodylessResponseInfo, ResponseInfo] +) -> bool: + fetch_dest = request_info.headers.get('sec-fetch-dest') + if fetch_dest is None: + if 'html' in request_info.headers.get('accept', ''): + fetch_dest = 'document' + else: + fetch_dest = 'unknown' + + if fetch_dest not in ('document', 'iframe', 'frame', 'embed', 'object'): + return False + + mime, encoding = response_info.deduce_content_type() + + # Right now out of all response headers we're only taking Content-Type into + # account. In the future we might also want to consider the + # Content-Disposition header. + return mime is not None and 'html' in mime diff --git a/src/hydrilla/proxy/policies/__init__.py b/src/hydrilla/proxy/policies/__init__.py index e958cbd..2276177 100644 --- a/src/hydrilla/proxy/policies/__init__.py +++ b/src/hydrilla/proxy/policies/__init__.py @@ -13,6 +13,6 @@ from .payload_resource import PayloadResourcePolicyFactory from .rule import RuleBlockPolicyFactory, RuleAllowPolicyFactory from .misc import FallbackAllowPolicy, FallbackBlockPolicy, ErrorBlockPolicy, \ - DoNothingPolicy + MitmItPagePolicyFactory from .web_ui import WebUIMainPolicyFactory, WebUILandingPolicyFactory diff --git a/src/hydrilla/proxy/policies/base.py b/src/hydrilla/proxy/policies/base.py index fcdbf9d..8ea792f 100644 --- a/src/hydrilla/proxy/policies/base.py +++ b/src/hydrilla/proxy/policies/base.py @@ -48,9 +48,9 @@ class PolicyPriority(int, enum.Enum): _TWO = 2 _THREE = 3 -ProducedMessage = t.Union[ - http_messages.ProducedRequest, - http_messages.ProducedResponse +MessageInfo = t.Union[ + http_messages.RequestInfo, + http_messages.ResponseInfo ] class Policy(ABC): @@ -75,7 +75,7 @@ class Policy(ABC): return self._process_response def consume_request(self, request_info: http_messages.RequestInfo) \ - -> t.Optional[ProducedMessage]: + -> t.Optional[MessageInfo]: raise NotImplementedError( 'This kind of policy does not consume requests.' ) @@ -84,7 +84,7 @@ class Policy(ABC): self, request_info: http_messages.RequestInfo, response_info: http_messages.ResponseInfo - ) -> t.Optional[http_messages.ProducedResponse]: + ) -> t.Optional[http_messages.ResponseInfo]: raise NotImplementedError( 'This kind of policy does not consume responses.' ) @@ -109,6 +109,11 @@ class PolicyFactory(ABC): sorting_keys.get(other.__class__.__name__, 999) sorting_order = ( + 'WebUIMainPolicyFactory', + 'WebUILandingPolicyFactory', + + 'MitmItPagePolicyFactory', + 'PayloadResourcePolicyFactory', 'PayloadPolicyFactory', diff --git a/src/hydrilla/proxy/policies/misc.py b/src/hydrilla/proxy/policies/misc.py index 71692b3..81875a2 100644 --- a/src/hydrilla/proxy/policies/misc.py +++ b/src/hydrilla/proxy/policies/misc.py @@ -58,16 +58,19 @@ class ErrorBlockPolicy(BlockPolicy): builtin: bool = True -class DoNothingPolicy(base.Policy): + +class MitmItPagePolicy(base.Policy): """ A special policy class for handling of the magical mitm.it domain. It causes - request and response not to be modified in any way, and also (unlike + request and response not to be modified in any way and also (unlike FallbackAllowPolicy) prevents them from being streamed. """ _process_request: t.ClassVar[bool] = True _process_response: t.ClassVar[bool] = True anticache: t.ClassVar[bool] = False + priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._THREE + def consume_request(self, request_info: http_messages.RequestInfo) -> None: return None @@ -79,3 +82,11 @@ class DoNothingPolicy(base.Policy): return None builtin: bool = True + +@dc.dataclass(frozen=True, unsafe_hash=True) +class MitmItPagePolicyFactory(base.PolicyFactory): + builtin: bool = True + + def make_policy(self, haketilo_state: state.HaketiloState) \ + -> MitmItPagePolicy: + return MitmItPagePolicy() diff --git a/src/hydrilla/proxy/policies/payload.py b/src/hydrilla/proxy/policies/payload.py index 5b71af7..b89a1c1 100644 --- a/src/hydrilla/proxy/policies/payload.py +++ b/src/hydrilla/proxy/policies/payload.py @@ -31,7 +31,6 @@ import dataclasses as dc import typing as t -import re from urllib.parse import urlencode @@ -91,45 +90,6 @@ class PayloadAwarePolicyFactory(base.PolicyFactory): return super().__lt__(other) -# For details of 'Content-Type' header's structure, see: -# https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.1 -content_type_reg = re.compile(r''' -^ -(?P[\w-]+/[\w-]+) -\s* -(?: - ; - (?:[^;]*;)* # match possible parameter other than "charset" -) -\s* -charset= # no whitespace allowed in parameter as per RFC -(?P - [\w-]+ - | - "[\w-]+" # quotes are optional per RFC -) -(?:;[^;]+)* # match possible parameter other than "charset" -$ # forbid possible dangling characters after closing '"' -''', re.VERBOSE | re.IGNORECASE) - -def deduce_content_type(headers: http_messages.IHeaders) \ - -> tuple[t.Optional[str], t.Optional[str]]: - """....""" - content_type = headers.get('content-type') - if content_type is None: - return (None, None) - - match = content_type_reg.match(content_type) - if match is None: - return (None, None) - - mime, encoding = match.group('mime'), match.group('encoding') - - if encoding is not None: - encoding = encoding.lower() - - return mime, encoding - UTF8_BOM = b'\xEF\xBB\xBF' BOMs = ( (UTF8_BOM, 'utf-8'), @@ -174,15 +134,17 @@ class PayloadInjectPolicy(PayloadAwarePolicy): )) def _modify_headers(self, response_info: http_messages.ResponseInfo) \ - -> t.Iterable[tuple[bytes, bytes]]: - """....""" - for header_name, header_value in response_info.headers.items(): - if header_name.lower() not in csp.header_names_and_dispositions: - yield header_name.encode(), header_value.encode() + -> http_messages.IHeaders: + new_headers = [] + + for key, val in response_info.headers.items(): + if key.lower() not in csp.header_names_and_dispositions: + new_headers.append((key, val)) new_csp = self._new_csp(response_info.url) + new_headers.append(('Content-Security-Policy', new_csp)) - yield b'Content-Security-Policy', new_csp.encode() + return http_messages.make_headers(new_headers) def _script_urls(self, url: ParsedUrl) -> t.Iterable[str]: """....""" @@ -231,22 +193,18 @@ class PayloadInjectPolicy(PayloadAwarePolicy): def _consume_response_unsafe( self, + request_info: http_messages.RequestInfo, response_info: http_messages.ResponseInfo - ) -> http_messages.ProducedResponse: - """....""" - new_response = response_info.make_produced_response() - + ) -> http_messages.ResponseInfo: new_headers = self._modify_headers(response_info) + new_response = dc.replace(response_info, headers=new_headers) - new_response = dc.replace(new_response, headers=new_headers) - - mime, encoding = deduce_content_type(response_info.headers) - if mime is None or 'html' not in mime.lower(): + if not http_messages.is_likely_a_page(request_info, response_info): return new_response data = response_info.body - if data is None: - data = b'' + + _, encoding = response_info.deduce_content_type() # A UTF BOM overrides encoding specified by the header. for bom, encoding_name in BOMs: @@ -261,9 +219,9 @@ class PayloadInjectPolicy(PayloadAwarePolicy): self, request_info: http_messages.RequestInfo, response_info: http_messages.ResponseInfo - ) -> http_messages.ProducedResponse: + ) -> http_messages.ResponseInfo: try: - return self._consume_response_unsafe(response_info) + return self._consume_response_unsafe(request_info, response_info) except Exception as e: # TODO: actually describe the errors import traceback @@ -274,10 +232,10 @@ class PayloadInjectPolicy(PayloadAwarePolicy): e.__traceback__ ) - return http_messages.ProducedResponse( - 500, - ((b'Content-Type', b'text/plain; charset=utf-8'),), - '\n'.join(error_info_list).encode() + return http_messages.ResponseInfo.make( + status_code = 500, + headers = (('Content-Type', 'text/plain; charset=utf-8'),), + body = '\n'.join(error_info_list).encode() ) @@ -292,7 +250,7 @@ class AutoPayloadInjectPolicy(PayloadInjectPolicy): self, request_info: http_messages.RequestInfo, response_info: http_messages.ResponseInfo - ) -> http_messages.ProducedResponse: + ) -> http_messages.ResponseInfo: try: if self.payload_data.ref.has_problems(): raise _PayloadHasProblemsError() @@ -317,9 +275,9 @@ class AutoPayloadInjectPolicy(PayloadInjectPolicy): redirect_url = 'https://hkt.mitm.it/auto_install_error?' + query msg = 'Error occured when installing payload. Redirecting.' - return http_messages.ProducedResponse( + return http_messages.ResponseInfo.make( status_code = 303, - headers = [(b'Location', redirect_url.encode())], + headers = [('Location', redirect_url)], body = msg.encode() ) @@ -332,7 +290,7 @@ class PayloadSuggestPolicy(PayloadAwarePolicy): priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE def consume_request(self, request_info: http_messages.RequestInfo) \ - -> http_messages.ProducedResponse: + -> http_messages.ResponseInfo: query = self._payload_details_to_signed_query_string( _salt = 'package_suggestion', next_url = request_info.url.orig_url @@ -341,9 +299,9 @@ class PayloadSuggestPolicy(PayloadAwarePolicy): redirect_url = 'https://hkt.mitm.it/package_suggestion?' + query msg = 'A package was found that could be used on this site. Redirecting.' - return http_messages.ProducedResponse( + return http_messages.ResponseInfo.make( status_code = 303, - headers = [(b'Location', redirect_url.encode())], + headers = [('Location', redirect_url)], body = msg.encode() ) diff --git a/src/hydrilla/proxy/policies/payload_resource.py b/src/hydrilla/proxy/policies/payload_resource.py index ae6a490..10a43e6 100644 --- a/src/hydrilla/proxy/policies/payload_resource.py +++ b/src/hydrilla/proxy/policies/payload_resource.py @@ -219,9 +219,9 @@ def merge_response_headers( ) -ProducedAny = t.Union[ - http_messages.ProducedResponse, - http_messages.ProducedRequest +MessageInfo = t.Union[ + http_messages.ResponseInfo, + http_messages.RequestInfo ] @dc.dataclass(frozen=True) @@ -251,31 +251,30 @@ class PayloadResourcePolicy(PayloadAwarePolicy): == ('api', 'unrestricted_http') def _make_file_resource_response(self, path: tuple[str, ...]) \ - -> http_messages.ProducedResponse: - """....""" + -> http_messages.ResponseInfo: try: file_data = self.payload_data.ref.get_file_data(path) except state.MissingItemError: return resource_blocked_response if file_data is None: - return http_messages.ProducedResponse( - 404, - [(b'Content-Type', b'text/plain; charset=utf-8')], - _('api.file_not_found').encode() + return http_messages.ResponseInfo.make( + status_code = 404, + headers = [('Content-Type', 'text/plain; charset=utf-8')], + body =_('api.file_not_found').encode() ) - return http_messages.ProducedResponse( - 200, - ((b'Content-Type', file_data.mime_type.encode()),), - file_data.contents + return http_messages.ResponseInfo.make( + status_code = 200, + headers = [('Content-Type', file_data.mime_type)], + body = file_data.contents ) def _make_api_response( self, path: tuple[str, ...], request_info: http_messages.RequestInfo - ) -> ProducedAny: + ) -> MessageInfo: if path[0] == 'page_init_script.js': with jinja_lock: template = jinja_env.get_template('page_init_script.js.jinja') @@ -288,10 +287,10 @@ class PayloadResourcePolicy(PayloadAwarePolicy): haketilo_version = encode_string_for_js(ver_str) ) - return http_messages.ProducedResponse( - 200, - ((b'Content-Type', b'application/javascript'),), - js.encode() + return http_messages.ResponseInfo.make( + status_code = 200, + headers = [('Content-Type', 'application/javascript')], + body = js.encode() ) if path[0] == 'unrestricted_http': @@ -315,13 +314,10 @@ class PayloadResourcePolicy(PayloadAwarePolicy): extra_headers = extra_headers ) - result_headers_bytes = \ - [(h.encode(), v.encode()) for h, v in result_headers] - - return http_messages.ProducedRequest( + return http_messages.RequestInfo.make( url = target_url, method = request_info.method, - headers = result_headers_bytes, + headers = result_headers, body = request_info.body ) except: @@ -330,7 +326,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy): return resource_blocked_response def consume_request(self, request_info: http_messages.RequestInfo) \ - -> ProducedAny: + -> MessageInfo: resource_path = self.extract_resource_path(request_info.url) if resource_path == (): @@ -346,7 +342,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy): self, request_info: http_messages.RequestInfo, response_info: http_messages.ResponseInfo - ) -> http_messages.ProducedResponse: + ) -> http_messages.ResponseInfo: """ This method shall only be called for responses to unrestricted HTTP API requests. Its purpose is to sanitize response headers and smuggle their @@ -375,17 +371,17 @@ class PayloadResourcePolicy(PayloadAwarePolicy): extra_headers = extra_headers ) - return http_messages.ProducedResponse( + return http_messages.ResponseInfo.make( status_code = response_info.status_code, - headers = [(h.encode(), v.encode()) for h, v in merged_headers], + headers = merged_headers, body = response_info.body, ) -resource_blocked_response = http_messages.ProducedResponse( - 403, - [(b'Content-Type', b'text/plain; charset=utf-8')], - _('api.resource_not_enabled_for_access').encode() +resource_blocked_response = http_messages.ResponseInfo.make( + status_code = 403, + headers = [('Content-Type', 'text/plain; charset=utf-8')], + body = _('api.resource_not_enabled_for_access').encode() ) @dc.dataclass(frozen=True) @@ -396,8 +392,7 @@ class BlockedResponsePolicy(base.Policy): priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._THREE def consume_request(self, request_info: http_messages.RequestInfo) \ - -> http_messages.ProducedResponse: - """....""" + -> http_messages.ResponseInfo: return resource_blocked_response diff --git a/src/hydrilla/proxy/policies/rule.py b/src/hydrilla/proxy/policies/rule.py index 2e9443e..8272d2f 100644 --- a/src/hydrilla/proxy/policies/rule.py +++ b/src/hydrilla/proxy/policies/rule.py @@ -50,13 +50,14 @@ class BlockPolicy(base.Policy): priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._TWO def _modify_headers(self, response_info: http_messages.ResponseInfo) \ - -> t.Iterable[tuple[bytes, bytes]]: - """....""" + -> http_messages.IHeaders: + new_headers = [] + csp_policies = csp.extract(response_info.headers) - for header_name, header_value in response_info.headers.items(): - if header_name.lower() not in csp.header_names_and_dispositions: - yield header_name.encode(), header_value.encode() + for key, val in response_info.headers.items(): + if key.lower() not in csp.header_names_and_dispositions: + new_headers.append((key, val)) for policy in csp_policies: if policy.disposition != 'enforce': @@ -68,7 +69,7 @@ class BlockPolicy(base.Policy): policy = dc.replace(policy, directives=directives.finish()) - yield policy.header_name.encode(), policy.serialize().encode() + new_headers.append((policy.header_name, policy.serialize())) extra_csp = ';'.join(( "script-src 'none'", @@ -76,19 +77,19 @@ class BlockPolicy(base.Policy): "script-src-attr 'none'" )) - yield b'Content-Security-Policy', extra_csp.encode() + new_headers.append(('Content-Security-Policy', extra_csp)) + + return http_messages.make_headers(new_headers) def consume_response( self, request_info: http_messages.RequestInfo, response_info: http_messages.ResponseInfo - ) -> http_messages.ProducedResponse: - new_response = response_info.make_produced_response() - + ) -> http_messages.ResponseInfo: new_headers = self._modify_headers(response_info) - return dc.replace(new_response, headers=new_headers) + return dc.replace(response_info, headers=new_headers) @dc.dataclass(frozen=True) class RuleAllowPolicy(AllowPolicy): diff --git a/src/hydrilla/proxy/policies/web_ui.py b/src/hydrilla/proxy/policies/web_ui.py index f35b0b7..284d062 100644 --- a/src/hydrilla/proxy/policies/web_ui.py +++ b/src/hydrilla/proxy/policies/web_ui.py @@ -50,7 +50,7 @@ class WebUIPolicy(base.Policy): ui_domain: web_ui.UIDomain def consume_request(self, request_info: http_messages.RequestInfo) \ - -> http_messages.ProducedResponse: + -> http_messages.ResponseInfo: return web_ui.process_request( request_info = request_info, state = self.haketilo_state, diff --git a/src/hydrilla/proxy/state_impl/concrete_state.py b/src/hydrilla/proxy/state_impl/concrete_state.py index 83522cf..c28e360 100644 --- a/src/hydrilla/proxy/state_impl/concrete_state.py +++ b/src/hydrilla/proxy/state_impl/concrete_state.py @@ -223,6 +223,15 @@ class ConcreteHaketiloState(base.HaketiloStateWithFields): item = web_ui_landing_factory ) + mitm_it_page_pattern = 'http://mitm.it/***' + mitm_it_page_factory = policies.MitmItPagePolicyFactory() + + parsed_pattern, = url_patterns.parse_pattern(mitm_it_page_pattern) + new_policy_tree = new_policy_tree.register( + parsed_pattern = parsed_pattern, + item = mitm_it_page_factory + ) + # Put script blocking/allowing rules in policy tree. cursor.execute('SELECT pattern, allow_scripts FROM rules;') diff --git a/src/hydrilla/proxy/web_ui/root.py b/src/hydrilla/proxy/web_ui/root.py index 57dc958..3120d0e 100644 --- a/src/hydrilla/proxy/web_ui/root.py +++ b/src/hydrilla/proxy/web_ui/root.py @@ -191,7 +191,7 @@ def process_request( request_info: http_messages.RequestInfo, state: st.HaketiloState, ui_domain: _app.UIDomain = _app.UIDomain.MAIN -) -> http_messages.ProducedResponse: +) -> http_messages.ResponseInfo: path = '/'.join(('', *request_info.url.path_segments)) if (request_info.url.has_trailing_slash): path += '/' @@ -218,7 +218,7 @@ def process_request( in flask_response.headers ] - return http_messages.ProducedResponse( + return http_messages.ResponseInfo.make( status_code = flask_response.status_code, headers = headers_bytes, body = flask_response.data diff --git a/src/hydrilla/url_patterns.py b/src/hydrilla/url_patterns.py index cc68820..5e62a28 100644 --- a/src/hydrilla/url_patterns.py +++ b/src/hydrilla/url_patterns.py @@ -240,3 +240,6 @@ def normalize_pattern(url_pattern: str) -> str: reconstructed = replace_scheme_regex.sub('http*', reconstructed) return reconstructed + + +dummy_url = parse_url('http://dummy.replacement.url') -- cgit v1.2.3