diff options
Diffstat (limited to 'src/hydrilla/proxy')
-rw-r--r-- | src/hydrilla/proxy/addon.py | 118 | ||||
-rw-r--r-- | src/hydrilla/proxy/http_messages.py | 12 | ||||
-rw-r--r-- | src/hydrilla/proxy/policies/base.py | 13 | ||||
-rw-r--r-- | src/hydrilla/proxy/policies/js_templates/page_init_script.js.jinja | 145 | ||||
-rw-r--r-- | src/hydrilla/proxy/policies/payload.py | 6 | ||||
-rw-r--r-- | src/hydrilla/proxy/policies/payload_resource.py | 289 | ||||
-rw-r--r-- | src/hydrilla/proxy/policies/rule.py | 2 | ||||
-rw-r--r-- | src/hydrilla/proxy/policies/web_ui.py | 2 |
8 files changed, 502 insertions, 85 deletions
diff --git a/src/hydrilla/proxy/addon.py b/src/hydrilla/proxy/addon.py index c1069bc..2185bcb 100644 --- a/src/hydrilla/proxy/addon.py +++ b/src/hydrilla/proxy/addon.py @@ -48,13 +48,13 @@ from mitmproxy.script import concurrent from ..exceptions import HaketiloException from ..translations import smart_gettext as _ -from ..url_patterns import parse_url +from ..url_patterns import parse_url, ParsedUrl from .state_impl import ConcreteHaketiloState from . import policies from . import http_messages -DefaultGetValue = t.TypeVar('DefaultGetValue', object, None) +DefaultGetValue = t.TypeVar('DefaultGetValue', str, None) class MitmproxyHeadersWrapper(): """....""" @@ -65,9 +65,14 @@ class MitmproxyHeadersWrapper(): __getitem__ = lambda self, key: self.headers[key] get_all = lambda self, key: self.headers.get_all(key) - def get(self, key: str, default: DefaultGetValue = None) \ + @t.overload + def get(self, key: str) -> t.Optional[str]: + ... + @t.overload + def get(self, key: str, default: DefaultGetValue) \ -> t.Union[str, DefaultGetValue]: - """....""" + ... + def get(self, key, default = None): value = self.headers.get(key) if value is None: @@ -79,6 +84,13 @@ class MitmproxyHeadersWrapper(): """....""" return self.headers.items(multi=True) + +@dc.dataclass(frozen=True) +class FlowHandlingData: + request_url: ParsedUrl + policy: policies.Policy + + @dc.dataclass class HaketiloAddon: """ @@ -87,8 +99,8 @@ class HaketiloAddon: configured: bool = False configured_lock: Lock = dc.field(default_factory=Lock) - flow_policies: dict[int, policies.Policy] = dc.field(default_factory=dict) - policies_lock: Lock = dc.field(default_factory=Lock) + flows_data: dict[int, FlowHandlingData] = dc.field(default_factory=dict) + flows_data_lock: Lock = dc.field(default_factory=Lock) state: t.Optional[ConcreteHaketiloState] = None @@ -121,37 +133,32 @@ class HaketiloAddon: self.configured = True - def try_get_policy(self, flow: http.HTTPFlow, fail_ok: bool = True) -> \ - t.Optional[policies.Policy]: - """....""" - with self.policies_lock: - policy = self.flow_policies.get(id(flow)) + def get_handling_data(self, flow: http.HTTPFlow) -> FlowHandlingData: + policy: policies.Policy - if policy is None: - try: - parsed_url = parse_url(flow.request.url) - except HaketiloException: - if fail_ok: - return None - else: - raise + assert self.state is not None - assert self.state is not None + with self.flows_data_lock: + handling_data = self.flows_data.get(id(flow)) - policy = self.state.select_policy(parsed_url) + if handling_data is None: + try: + parsed_url = parse_url(flow.request.url) + policy = self.state.select_policy(parsed_url) + except HaketiloException as e: + policy = policies.ErrorBlockPolicy(builtin=True, error=e) - with self.policies_lock: - self.flow_policies[id(flow)] = policy + handling_data = FlowHandlingData(parsed_url, policy) - return policy + with self.flows_data_lock: + self.flows_data[id(flow)] = handling_data - def get_policy(self, flow: http.HTTPFlow) -> policies.Policy: - return t.cast(policies.Policy, self.try_get_policy(flow, fail_ok=False)) + return handling_data - def forget_policy(self, flow: http.HTTPFlow) -> None: + def forget_handling_data(self, flow: http.HTTPFlow) -> None: """....""" - with self.policies_lock: - self.flow_policies.pop(id(flow), None) + with self.flows_data_lock: + self.flows_data.pop(id(flow), None) @contextmanager def http_safe_event_handling(self, flow: http.HTTPFlow) -> t.Iterator: @@ -172,19 +179,10 @@ class HaketiloAddon: headers = [(b'Content-Type', b'text/plain; charset=utf-8')] ) - self.forget_policy(flow) + self.forget_handling_data(flow) @concurrent def requestheaders(self, flow: http.HTTPFlow) -> None: - # TODO: don't account for mitmproxy 6 in the code - # Mitmproxy 6 causes even more strange behavior than described below. - # This cannot be easily worked around. Let's just use version 8 and - # make an APT package for it. - """ - Under mitmproxy 8 this handler deduces an appropriate policy for flow's - URL and assigns it to the flow. Under mitmproxy 6 the URL is not yet - available at this point, so the handler effectively does nothing. - """ with self.http_safe_event_handling(flow): referrer = flow.request.headers.get('referer') if referrer is not None: @@ -194,13 +192,13 @@ class HaketiloAddon: # visited before. flow.request.headers.pop('referer', None) - policy = self.try_get_policy(flow) + handling_data = self.get_handling_data(flow) + policy = handling_data.policy - if policy is not None: - if not policy.process_request: - flow.request.stream = True - if policy.anticache: - flow.request.anticache() + if not policy.should_process_request(handling_data.request_url): + flow.request.stream = True + if policy.anticache: + flow.request.anticache() @concurrent def request(self, flow: http.HTTPFlow) -> None: @@ -211,25 +209,23 @@ class HaketiloAddon: return with self.http_safe_event_handling(flow): - policy = self.get_policy(flow) + handling_data = self.get_handling_data(flow) request_info = http_messages.RequestInfo( - url = parse_url(flow.request.url), + url = handling_data.request_url, method = flow.request.method, headers = MitmproxyHeadersWrapper(flow.request.headers), body = flow.request.get_content(strict=False) or b'' ) - result = policy.consume_request(request_info) + result = handling_data.policy.consume_request(request_info) if result is not None: if isinstance(result, http_messages.ProducedRequest): - flow.request = http.Request.make( - url = result.url, - method = result.method, - headers = http.Headers(result.headers), - content = result.body - ) + flow.request.url = result.url + flow.request.method = result.method + flow.request.headers = http.Headers(result.headers) + flow.request.set_content(result.body or None) else: # isinstance(result, http_messages.ProducedResponse) flow.response = http.Response.make( @@ -245,9 +241,10 @@ class HaketiloAddon: assert flow.response is not None with self.http_safe_event_handling(flow): - policy = self.get_policy(flow) + handling_data = self.get_handling_data(flow) + policy = handling_data.policy - if not policy.process_response: + if not policy.should_process_response(handling_data.request_url): flow.response.stream = True @concurrent @@ -261,22 +258,23 @@ class HaketiloAddon: return with self.http_safe_event_handling(flow): - policy = self.get_policy(flow) + handling_data = self.get_handling_data(flow) response_info = http_messages.ResponseInfo( url = parse_url(flow.request.url), + orig_url = handling_data.request_url, status_code = flow.response.status_code, headers = MitmproxyHeadersWrapper(flow.response.headers), body = flow.response.get_content(strict=False) or b'' ) - result = policy.consume_response(response_info) + result = handling_data.policy.consume_response(response_info) if result is not None: flow.response.status_code = result.status_code flow.response.headers = http.Headers(result.headers) flow.response.set_content(result.body) - self.forget_policy(flow) + self.forget_handling_data(flow) def tls_clienthello(self, data: tls.ClientHelloData): if data.context.server.address is None: @@ -291,4 +289,4 @@ class HaketiloAddon: def error(self, flow: http.HTTPFlow) -> None: """....""" - self.forget_policy(flow) + self.forget_handling_data(flow) diff --git a/src/hydrilla/proxy/http_messages.py b/src/hydrilla/proxy/http_messages.py index 698020d..53d10db 100644 --- a/src/hydrilla/proxy/http_messages.py +++ b/src/hydrilla/proxy/http_messages.py @@ -44,7 +44,7 @@ else: from .. import url_patterns -DefaultGetValue = t.TypeVar('DefaultGetValue', object, None) +DefaultGetValue = t.TypeVar('DefaultGetValue', str, None) class IHeaders(Protocol): """....""" @@ -52,8 +52,13 @@ class IHeaders(Protocol): def get_all(self, key: str) -> t.Iterable[str]: ... - def get(self, key: str, default: DefaultGetValue = None) \ - -> t.Union[str, DefaultGetValue]: ... + @t.overload + def get(self, key: str) -> t.Optional[str]: + ... + @t.overload + def get(self, key: str, default: DefaultGetValue) \ + -> t.Union[str, DefaultGetValue]: + ... def items(self) -> t.Iterable[tuple[str, str]]: ... @@ -99,6 +104,7 @@ class ProducedResponse: class ResponseInfo: """....""" url: url_patterns.ParsedUrl + orig_url: url_patterns.ParsedUrl status_code: int headers: IHeaders body: bytes diff --git a/src/hydrilla/proxy/policies/base.py b/src/hydrilla/proxy/policies/base.py index b7beba3..c02ea0b 100644 --- a/src/hydrilla/proxy/policies/base.py +++ b/src/hydrilla/proxy/policies/base.py @@ -40,6 +40,7 @@ from abc import ABC, abstractmethod from immutables import Map +from ... url_patterns import ParsedUrl from .. import state from .. import http_messages @@ -57,12 +58,18 @@ ProducedMessage = t.Union[ class Policy(ABC): """....""" - process_request: t.ClassVar[bool] = False - process_response: t.ClassVar[bool] = False - anticache: t.ClassVar[bool] = True + _process_request: t.ClassVar[bool] = False + _process_response: t.ClassVar[bool] = False + anticache: t.ClassVar[bool] = True priority: t.ClassVar[PolicyPriority] + def should_process_request(self, parsed_url: ParsedUrl) -> bool: + return self._process_request + + def should_process_response(self, parsed_url: ParsedUrl) -> bool: + return self._process_response + def consume_request(self, request_info: http_messages.RequestInfo) \ -> t.Optional[ProducedMessage]: raise NotImplementedError( diff --git a/src/hydrilla/proxy/policies/js_templates/page_init_script.js.jinja b/src/hydrilla/proxy/policies/js_templates/page_init_script.js.jinja new file mode 100644 index 0000000..3a8382c --- /dev/null +++ b/src/hydrilla/proxy/policies/js_templates/page_init_script.js.jinja @@ -0,0 +1,145 @@ +{# +SPDX-License-Identifier: GPL-3.0-or-later + +Haketilo page APIs code template. + +This file is part of Hydrilla&Haketilo. + +Copyright (C) 2021,2022 Wojtek Kosior + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +As additional permission under GNU GPL version 3 section 7, you +may distribute forms of that code without the copy of the GNU +GPL normally required by section 4, provided you include this +license notice and, in case of non-source distribution, a URL +through which recipients can access the Corresponding Source. +If you modify file(s) with this exception, you may extend this +exception to your version of the file(s), but you are not +obligated to do so. If you do not wish to do so, delete this +exception statement from your version. + +As a special exception to the GPL, any HTML file which merely +makes function calls to this code, and for that purpose +includes it by reference shall be deemed a separate work for +copyright law purposes. If you modify this code, you may extend +this exception to your version of the code, but you are not +obligated to do so. If you do not wish to do so, delete this +exception statement from your version. + +You should have received a copy of the GNU General Public License +along with this program. If not, see <https://www.gnu.org/licenses/>. + + +I, Wojtek Kosior, thereby promise not to sue for violation of this +file's license. Although I request that you do not make use of this +code in a proprietary program, I am not going to enforce this in court. +-#} + +(function(){ + /* + * Snapshot some variables that other code could theoretically redefine + * later. We're not making the effort to protect from redefinition of + * prototype properties right now. + */ + const console = window.console; + const fetch = window.fetch; + const JSON = window.JSON; + const URL = window.URL; + const Array = window.Array; + const Uint8Array = window.Uint8Array; + const CustomEvent = window.CustomEvent; + const window_dispatchEvent = window.dispatchEvent; + + /* Get values from the proxy. */ + function decode_jinja(str) { + return decodeURIComponent(atob(str)); + } + const unique_token = decode_jinja("{{ unique_token_encoded }}"); + const assets_base_url = decode_jinja("{{ assets_base_url_encoded }}"); + + /* Make it possible to serialize an Error object. */ + function error_data_jsonifiable(error) { + const jsonifiable = {}; + for (const property of ["name", "message", "fileName", "lineNumber"]) + jsonifiable[property] = error[property]; + + return jsonifiable; + } + + /* Make it possible to serialize a Uint8Array. */ + function uint8_to_hex(array) { + return [...array].map(b => ("0" + b.toString(16)).slice(-2)).join(""); + } + + async function on_unrestricted_http_request(event) { + const name = "haketilo_CORS_bypass"; + + if (typeof event.detail !== "object" || + event.detail === null || + typeof event.detail.id !== "string" || + typeof event.detail.data !== "string") { + console.error(`Unrestricted HTTP: Invalid detail.`, event.detail); + return; + } + + try { + const data = JSON.parse(event.detail.data); + + const params = new URLSearchParams({ + target_url: data.url, + extra_headers: JSON.stringify(data.headers || []) + }); + const replacement_url = assets_base_url + "api/unrestricted_http"; + const replacement_url_obj = new URL(replacement_url); + replacement_url_obj.search = params; + + const response = await fetch(replacement_url_obj.href, data.init); + const response_buffer = await response.arrayBuffer(); + + const true_headers_serialized = + response.headers.get("x-haketilo-true-headers"); + + if (true_headers_serialized === null) + throw new Error("Unrestricted HTTP: The 'X-Haketilo-True-Headers' HTTP response header is missing. Are we connected to Haketilo proxy?") + + const true_headers = JSON.parse( + decodeURIComponent(true_headers_serialized) + ); + + const bad_format_error_msg = + "Unrestricted HTTP: The 'X-Haketilo-True-Headers' HTTP response header has invalid format."; + + if (!Array.isArray(true_headers)) + throw new Error(bad_format_error_msg); + + for (const [header, value] of true_headers) { + if (typeof header !== "string" || typeof value !== "string") + throw new Error(bad_format_error_msg); + } + + var result = { + status: response.status, + statusText: response.statusText, + headers: true_headers, + body: uint8_to_hex(new Uint8Array(response_buffer)) + }; + } catch(e) { + var result = {error: error_data_jsonifiable(e)}; + } + + const response_name = `${name}-${event.detail.id}`; + const detail = JSON.stringify(result); + window_dispatchEvent(new CustomEvent(response_name, {detail})); +} + +window.addEventListener("haketilo_CORS_bypass", on_unrestricted_http_request); +})(); diff --git a/src/hydrilla/proxy/policies/payload.py b/src/hydrilla/proxy/policies/payload.py index c50bdef..7eef184 100644 --- a/src/hydrilla/proxy/policies/payload.py +++ b/src/hydrilla/proxy/policies/payload.py @@ -157,7 +157,7 @@ def block_attr(element: bs4.PageElement, attr_name: str) -> None: @dc.dataclass(frozen=True) class PayloadInjectPolicy(PayloadAwarePolicy): """....""" - process_response: t.ClassVar[bool] = True + _process_response: t.ClassVar[bool] = True priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._TWO @@ -192,6 +192,8 @@ class PayloadInjectPolicy(PayloadAwarePolicy): base_url = self.assets_base_url(url) payload_ref = self.payload_data.ref + yield base_url + 'api/page_init_script.js' + for path in payload_ref.get_script_paths(): yield base_url + '/'.join(('static', *path)) @@ -323,7 +325,7 @@ class AutoPayloadInjectPolicy(PayloadInjectPolicy): @dc.dataclass(frozen=True) class PayloadSuggestPolicy(PayloadAwarePolicy): """....""" - process_request: t.ClassVar[bool] = True + _process_request: t.ClassVar[bool] = True priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE diff --git a/src/hydrilla/proxy/policies/payload_resource.py b/src/hydrilla/proxy/policies/payload_resource.py index 30b28f2..cda19ba 100644 --- a/src/hydrilla/proxy/policies/payload_resource.py +++ b/src/hydrilla/proxy/policies/payload_resource.py @@ -59,21 +59,195 @@ from __future__ import annotations import dataclasses as dc import typing as t +import json + +from threading import Lock +from base64 import b64encode +from urllib.parse import quote, parse_qs, urlparse, urlencode, urljoin + +import jinja2 from ...translations import smart_gettext as _ +from ...url_patterns import ParsedUrl from .. import state from .. import http_messages from . import base from .payload import PayloadAwarePolicy, PayloadAwarePolicyFactory +loader = jinja2.PackageLoader(__package__, package_path='js_templates') +jinja_env = jinja2.Environment( + loader = loader, + lstrip_blocks = True, + autoescape = False +) +jinja_lock = Lock() + + +def encode_string_for_js(string: str) -> str: + return b64encode(quote(string).encode()).decode() + + +AnyValue = t.TypeVar('AnyValue', bound=object) + +def header_keys(headers: t.Iterable[tuple[str, AnyValue]]) -> frozenset[str]: + return frozenset(header.lower() for header, _ in headers) + +def _merge_headers( + standard_headers: t.Iterable[tuple[str, t.Optional[str]]], + overridable_headers_keys: frozenset[str], + native_headers: http_messages.IHeaders, + extra_headers: t.Iterable[tuple[str, str]] +) -> t.Iterable[tuple[str, str]]: + standard_keys = header_keys(standard_headers) + standard_iterator = iter(standard_headers) + native_keys = header_keys(native_headers.items()) + + selected_base: list[tuple[str, str]] = [] + processed: set[str] = set() + + for header, _ in native_headers.items(): + header_l = header.lower() + + if header_l in processed or header_l not in standard_keys: + continue + + for standard_header_l, chosen_value in standard_iterator: + if standard_header_l not in native_keys: + if chosen_value is not None: + selected_base.append((standard_header_l, chosen_value)) + elif standard_header_l == header_l: + processed.add(header_l) + + if header_l in overridable_headers_keys: + chosen_value = native_headers.get(header_l, chosen_value) + + if chosen_value is not None: + selected_base.append((header, chosen_value)) + + break + + for standard_header_l, standard_value in standard_iterator: + if standard_value is not None: + selected_base.append((standard_header_l, standard_value)) + + extra_keys = header_keys(extra_headers) + extra_iterator = iter(extra_headers) + + result: list[tuple[str, str]] = [] + processed = set() + + for header, value in selected_base: + header_l = header.lower() + + if header_l in processed: + continue + + if header_l in extra_keys: + for extra_header, extra_value in extra_iterator: + extra_header_l = extra_header.lower() + + processed.add(extra_header_l) + + result.append((extra_header, extra_value)) + + if extra_header_l == header_l: + break + else: + result.append((header, value)) + + result.extend(extra_iterator) + + return result + +request_standard_headers: t.Iterable[tuple[str, t.Optional[str]]] = ( + ('user-agent', None), + ('accept', 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8'), + ('accept-language', 'en-US,en;q=0.5'), + ('accept-encoding', None), + ('dnt', '1'), + ('connection', None), + ('upgrade-insecure-requests', '1'), + ('sec-fetch-dest', 'document'), + ('sec-fetch-mode', 'navigate'), + ('sec-fetch-site', 'none'), + ('sec-fetch-user', '?1'), + ('te', 'trailers') +) + +auto_overridable_request_headers = frozenset(( + 'user-agent', + 'accept-language', + 'accept-encoding', + 'dnt' +)) + +def merge_request_headers( + native_headers: http_messages.IHeaders, + extra_headers: t.Iterable[tuple[str, str]] +) -> t.Iterable[tuple[str, str]]: + return _merge_headers( + standard_headers = request_standard_headers, + overridable_headers_keys = auto_overridable_request_headers, + native_headers = native_headers, + extra_headers = extra_headers + ) + +response_standard_headers: t.Iterable[tuple[str, t.Optional[str]]] = ( + ('cache-control', 'max-age=0, private, must-revalidate'), + ('connection', None), + ('content-length', None), + ('content-type', None), + ('date', None), + ('keep-alive', None), + ('server', None) +) + +auto_overridable_response_headers = frozenset( + header.lower() + for header, value in response_standard_headers + if value is None +) + +def merge_response_headers( + native_headers: http_messages.IHeaders, + extra_headers: t.Iterable[tuple[str, str]] +) -> t.Iterable[tuple[str, str]]: + return _merge_headers( + standard_headers = response_standard_headers, + overridable_headers_keys = auto_overridable_response_headers, + native_headers = native_headers, + extra_headers = extra_headers + ) + + +ProducedAny = t.Union[ + http_messages.ProducedResponse, + http_messages.ProducedRequest +] + @dc.dataclass(frozen=True) class PayloadResourcePolicy(PayloadAwarePolicy): """....""" - process_request: t.ClassVar[bool] = True + _process_request: t.ClassVar[bool] = True priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._THREE + def extract_resource_path(self, request_url: ParsedUrl) -> tuple[str, ...]: + # Payload resource pattern has path of the form: + # "/some/arbitrary/segments/<per-session_token>/***" + # + # Corresponding requests shall have path of the form: + # "/some/arbitrary/segments/<per-session_token>/actual/resource/path" + # + # Here we need to extract the "/actual/resource/path" part. + segments_to_drop = len(self.payload_data.pattern_path_segments) + 1 + return request_url.path_segments[segments_to_drop:] + + def should_process_response(self, request_url: ParsedUrl) -> bool: + return self.extract_resource_path(request_url) \ + == ('api', 'unrestricted_http') + def _make_file_resource_response(self, path: tuple[str, ...]) \ -> http_messages.ProducedResponse: """....""" @@ -95,29 +269,114 @@ class PayloadResourcePolicy(PayloadAwarePolicy): file_data.contents ) + def _make_api_response( + self, + path: tuple[str, ...], + request_info: http_messages.RequestInfo + ) -> ProducedAny: + if path[0] == 'page_init_script.js': + with jinja_lock: + template = jinja_env.get_template('page_init_script.js.jinja') + token = self.payload_data.unique_token + base_url = self.assets_base_url(request_info.url) + js = template.render( + unique_token_encoded = encode_string_for_js(token), + assets_base_url_encoded = encode_string_for_js(base_url) + ) + + return http_messages.ProducedResponse( + 200, + ((b'Content-Type', b'application/javascript'),), + js.encode() + ) + + if path[0] == 'unrestricted_http': + try: + assert self.payload_data.cors_bypass_allowed + + params = parse_qs(request_info.url.query) + target_url, = params['target_url'] + extra_headers_str, = params['extra_headers'] + + assert urlparse(target_url).scheme in ('http', 'https') + + extra_headers = json.loads(extra_headers_str) + assert isinstance(extra_headers, list) + for header, value in extra_headers: + assert isinstance(header, str) + assert isinstance(value, str) + + result_headers = merge_request_headers( + native_headers = request_info.headers, + extra_headers = extra_headers + ) + + result_headers_bytes = \ + [(h.encode(), v.encode()) for h, v in result_headers] + + return http_messages.ProducedRequest( + url = target_url, + method = request_info.method, + headers = result_headers_bytes, + body = request_info.body + ) + except: + return resource_blocked_response + else: + return resource_blocked_response + def consume_request(self, request_info: http_messages.RequestInfo) \ - -> http_messages.ProducedResponse: - """....""" - # Payload resource pattern has path of the form: - # "/some/arbitrary/segments/<per-session_token>/***" - # - # Corresponding requests shall have path of the form: - # "/some/arbitrary/segments/<per-session_token>/actual/resource/path" - # - # Here we need to extract the "/actual/resource/path" part. - segments_to_drop = len(self.payload_data.pattern_path_segments) + 1 - resource_path = request_info.url.path_segments[segments_to_drop:] + -> ProducedAny: + resource_path = self.extract_resource_path(request_info.url) if resource_path == (): return resource_blocked_response elif resource_path[0] == 'static': return self._make_file_resource_response(resource_path[1:]) elif resource_path[0] == 'api': - # TODO: implement Haketilo APIs - return resource_blocked_response + return self._make_api_response(resource_path[1:], request_info) else: return resource_blocked_response + def consume_response(self, response_info: http_messages.ResponseInfo) \ + -> http_messages.ProducedResponse: + """ + This method shall only be called for responses to unrestricted HTTP API + requests. Its purpose is to sanitize response headers and smuggle their + original data using an additional header. + """ + serialized = json.dumps([*response_info.headers.items()]) + extra_headers = [('X-Haketilo-True-Headers', quote(serialized)),] + + if (300 <= response_info.status_code < 400): + location = response_info.headers.get('location') + if location is not None: + orig_params = parse_qs(response_info.orig_url.query) + orig_extra_headers_str, = orig_params['extra_headers'] + + new_query = urlencode({ + 'target_url': location, + 'extra_headers': orig_extra_headers_str + }) + + new_url = urljoin( + response_info.orig_url.orig_url, + '?' + new_query + ) + + extra_headers.append(('location', new_url)) + + merged_headers = merge_response_headers( + native_headers = response_info.headers, + extra_headers = extra_headers + ) + + return http_messages.ProducedResponse( + status_code = response_info.status_code, + headers = [(h.encode(), v.encode()) for h, v in merged_headers], + body = response_info.body, + ) + resource_blocked_response = http_messages.ProducedResponse( 403, @@ -128,7 +387,7 @@ resource_blocked_response = http_messages.ProducedResponse( @dc.dataclass(frozen=True) class BlockedResponsePolicy(base.Policy): """....""" - process_request: t.ClassVar[bool] = True + _process_request: t.ClassVar[bool] = True priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._THREE diff --git a/src/hydrilla/proxy/policies/rule.py b/src/hydrilla/proxy/policies/rule.py index 6482e84..833d287 100644 --- a/src/hydrilla/proxy/policies/rule.py +++ b/src/hydrilla/proxy/policies/rule.py @@ -48,7 +48,7 @@ class AllowPolicy(base.Policy): class BlockPolicy(base.Policy): """....""" - process_response: t.ClassVar[bool] = True + _process_response: t.ClassVar[bool] = True priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._TWO diff --git a/src/hydrilla/proxy/policies/web_ui.py b/src/hydrilla/proxy/policies/web_ui.py index 2b1ae02..9f6c0f5 100644 --- a/src/hydrilla/proxy/policies/web_ui.py +++ b/src/hydrilla/proxy/policies/web_ui.py @@ -45,7 +45,7 @@ from . import base @dc.dataclass(frozen=True) class WebUIPolicy(base.Policy): """....""" - process_request: t.ClassVar[bool] = True + _process_request: t.ClassVar[bool] = True priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._THREE |