From 3beab050c339c51c484af9bcd8248ba8ebbbf4d4 Mon Sep 17 00:00:00 2001 From: Wojtek Kosior Date: Wed, 19 Oct 2022 18:55:08 +0200 Subject: [proxy] pass all available flow information to relevant methods of Policy type --- src/hydrilla/proxy/addon.py | 142 ++++++++++++++---------- src/hydrilla/proxy/http_messages.py | 24 ++-- src/hydrilla/proxy/policies/base.py | 18 ++- src/hydrilla/proxy/policies/misc.py | 7 +- src/hydrilla/proxy/policies/payload.py | 17 ++- src/hydrilla/proxy/policies/payload_resource.py | 22 ++-- src/hydrilla/proxy/policies/rule.py | 8 +- 7 files changed, 148 insertions(+), 90 deletions(-) diff --git a/src/hydrilla/proxy/addon.py b/src/hydrilla/proxy/addon.py index cca7924..ae03ecc 100644 --- a/src/hydrilla/proxy/addon.py +++ b/src/hydrilla/proxy/addon.py @@ -89,10 +89,57 @@ class LoggerToMitmproxy(state.Logger): ctx.log.warn(f'Haketilo: {msg}') -@dc.dataclass(frozen=True) -class FlowHandlingData: - request_url: ParsedUrl - policy: policies.Policy +@dc.dataclass +class FlowHandling: + flow: http.HTTPFlow + policy: policies.Policy + _bl_request_info: http_messages.BodylessRequestInfo + _request_info: t.Optional[http_messages.RequestInfo] = None + _bl_response_info: t.Optional[http_messages.BodylessResponseInfo] = None + + @property + def bl_request_info(self) -> http_messages.BodylessRequestInfo: + return self._bl_request_info + + @property + def request_info(self) -> http_messages.RequestInfo: + if self._request_info is None: + body = self.flow.request.get_content(strict=False) or b'' + self._request_info = self._bl_request_info.with_body(body) + + return self._request_info + + @property + def bl_response_info(self) -> http_messages.BodylessResponseInfo: + if self._bl_response_info is None: + assert self.flow.response is not None + + headers = self.flow.response.headers + self._bl_response_info = http_messages.BodylessResponseInfo( + url = parse_url(self.flow.request.url), + status_code = self.flow.response.status_code, + headers = MitmproxyHeadersWrapper(headers) + ) + + return self._bl_response_info + + @property + def response_info(self) -> http_messages.ResponseInfo: + assert self.flow.response is not None + + body = self.flow.response.get_content(strict=False) or b'' + return self.bl_response_info.with_body(body) + + @staticmethod + def make(flow: http.HTTPFlow, policy: policies.Policy, url: ParsedUrl) \ + -> 'FlowHandling': + bl_request_info = http_messages.BodylessRequestInfo( + url = url, + method = flow.request.method, + headers = MitmproxyHeadersWrapper(flow.request.headers) + ) + + return FlowHandling(flow, policy, bl_request_info) @dc.dataclass @@ -120,8 +167,8 @@ class HaketiloAddon: configured: bool = False configured_lock: Lock = dc.field(default_factory=Lock) - flows_data: dict[int, FlowHandlingData] = dc.field(default_factory=dict) - flows_data_lock: Lock = dc.field(default_factory=Lock) + handling_dict: dict[int, FlowHandling] = dc.field(default_factory=dict) + handling_dict_lock: Lock = dc.field(default_factory=Lock) logger: LoggerToMitmproxy = dc.field(default_factory=LoggerToMitmproxy) @@ -201,15 +248,15 @@ class HaketiloAddon: if not self.state.launch_browser(): self.logger.warn(_('warn.proxy.couldnt_launch_browser')) - def get_handling_data(self, flow: http.HTTPFlow) -> FlowHandlingData: + def get_flow_handling(self, flow: http.HTTPFlow) -> FlowHandling: policy: policies.Policy assert self.state is not None - with self.flows_data_lock: - handling_data = self.flows_data.get(id(flow)) + with self.handling_dict_lock: + handling = self.handling_dict.get(id(flow)) - if handling_data is None: + if handling is None: parsed_url = dummy_url if magical_mitm_it_url_reg.match(flow.request.url): @@ -221,17 +268,16 @@ class HaketiloAddon: except HaketiloException as e: policy = policies.ErrorBlockPolicy(builtin=True, error=e) - handling_data = FlowHandlingData(parsed_url, policy) + handling = FlowHandling.make(flow, policy, parsed_url) - with self.flows_data_lock: - self.flows_data[id(flow)] = handling_data + with self.handling_dict_lock: + self.handling_dict[id(flow)] = handling - return handling_data + return handling - def forget_handling_data(self, flow: http.HTTPFlow) -> None: - """....""" - with self.flows_data_lock: - self.flows_data.pop(id(flow), None) + def forget_flow_handling(self, flow: http.HTTPFlow) -> None: + with self.handling_dict_lock: + self.handling_dict.pop(id(flow), None) @contextmanager def http_safe_event_handling(self, flow: http.HTTPFlow) -> t.Iterator: @@ -252,7 +298,7 @@ class HaketiloAddon: headers = [(b'Content-Type', b'text/plain; charset=utf-8')] ) - self.forget_handling_data(flow) + self.forget_flow_handling(flow) @concurrent def requestheaders(self, flow: http.HTTPFlow) -> None: @@ -265,33 +311,23 @@ class HaketiloAddon: # visited before. flow.request.headers.pop('referer', None) - handling_data = self.get_handling_data(flow) - policy = handling_data.policy + handling = self.get_flow_handling(flow) + policy = handling.policy - if not policy.should_process_request(handling_data.request_url): + if not policy.should_process_request(handling.bl_request_info): flow.request.stream = True if policy.anticache: flow.request.anticache() @concurrent def request(self, flow: http.HTTPFlow) -> None: - """ - .... - """ if flow.request.stream: return with self.http_safe_event_handling(flow): - handling_data = self.get_handling_data(flow) + handling = self.get_flow_handling(flow) - request_info = http_messages.RequestInfo( - url = handling_data.request_url, - method = flow.request.method, - headers = MitmproxyHeadersWrapper(flow.request.headers), - body = flow.request.get_content(strict=False) or b'' - ) - - result = handling_data.policy.consume_request(request_info) + result = handling.policy.consume_request(handling.request_info) if result is not None: if isinstance(result, http_messages.ProducedRequest): @@ -303,51 +339,42 @@ class HaketiloAddon: # isinstance(result, http_messages.ProducedResponse) flow.response = http.Response.make( status_code = result.status_code, - headers = http.Headers(result.headers), - content = result.body + headers = http.Headers(result.headers), + content = result.body ) def responseheaders(self, flow: http.HTTPFlow) -> None: - """ - ...... - """ assert flow.response is not None with self.http_safe_event_handling(flow): - handling_data = self.get_handling_data(flow) - policy = handling_data.policy + handling = self.get_flow_handling(flow) - if not policy.should_process_response(handling_data.request_url): + if not handling.policy.should_process_response( + request_info = handling.request_info, + response_info = handling.bl_response_info + ): flow.response.stream = True @concurrent def response(self, flow: http.HTTPFlow) -> None: - """ - ...... - """ assert flow.response is not None if flow.response.stream: return with self.http_safe_event_handling(flow): - handling_data = self.get_handling_data(flow) - - response_info = http_messages.ResponseInfo( - url = parse_url(flow.request.url), - orig_url = handling_data.request_url, - status_code = flow.response.status_code, - headers = MitmproxyHeadersWrapper(flow.response.headers), - body = flow.response.get_content(strict=False) or b'' - ) + handling = self.get_flow_handling(flow) - result = handling_data.policy.consume_response(response_info) + result = handling.policy.consume_response( + request_info = handling.request_info, + response_info = handling.response_info + ) if result is not None: flow.response.status_code = result.status_code flow.response.headers = http.Headers(result.headers) flow.response.set_content(result.body) - self.forget_handling_data(flow) + self.forget_flow_handling(flow) def tls_clienthello(self, data: tls.ClientHelloData): if data.context.server.address is None: @@ -361,5 +388,4 @@ class HaketiloAddon: data.establish_server_tls_first = True def error(self, flow: http.HTTPFlow) -> None: - """....""" - self.forget_handling_data(flow) + self.forget_flow_handling(flow) diff --git a/src/hydrilla/proxy/http_messages.py b/src/hydrilla/proxy/http_messages.py index 78baf81..dbf2c63 100644 --- a/src/hydrilla/proxy/http_messages.py +++ b/src/hydrilla/proxy/http_messages.py @@ -74,15 +74,19 @@ class ProducedRequest: body: bytes @dc.dataclass(frozen=True) -class RequestInfo: - """....""" +class BodylessRequestInfo: url: url_patterns.ParsedUrl method: str headers: IHeaders - body: bytes + + def with_body(self, body: bytes) -> 'RequestInfo': + return RequestInfo(self.url, self.method, self.headers, body) + +@dc.dataclass(frozen=True) +class RequestInfo(BodylessRequestInfo): + body: bytes def make_produced_request(self) -> ProducedRequest: - """....""" return ProducedRequest( url = self.url.orig_url, method = self.method, @@ -98,16 +102,20 @@ class ProducedResponse: body: bytes @dc.dataclass(frozen=True) -class ResponseInfo: +class BodylessResponseInfo: """....""" url: url_patterns.ParsedUrl - orig_url: url_patterns.ParsedUrl status_code: int headers: IHeaders - body: bytes + + def with_body(self, body: bytes) -> 'ResponseInfo': + return ResponseInfo(self.url, self.status_code, self.headers, body) + +@dc.dataclass(frozen=True) +class ResponseInfo(BodylessResponseInfo): + body: bytes def make_produced_response(self) -> ProducedResponse: - """....""" return ProducedResponse( status_code = self.status_code, headers = encode_headers_items(self.headers.items()), diff --git a/src/hydrilla/proxy/policies/base.py b/src/hydrilla/proxy/policies/base.py index b3d3172..fcdbf9d 100644 --- a/src/hydrilla/proxy/policies/base.py +++ b/src/hydrilla/proxy/policies/base.py @@ -61,10 +61,17 @@ class Policy(ABC): priority: t.ClassVar[PolicyPriority] - def should_process_request(self, parsed_url: ParsedUrl) -> bool: + def should_process_request( + self, + request_info: http_messages.BodylessRequestInfo + ) -> bool: return self._process_request - def should_process_response(self, parsed_url: ParsedUrl) -> bool: + def should_process_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.BodylessResponseInfo + ) -> bool: return self._process_response def consume_request(self, request_info: http_messages.RequestInfo) \ @@ -73,8 +80,11 @@ class Policy(ABC): 'This kind of policy does not consume requests.' ) - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> t.Optional[http_messages.ProducedResponse]: + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> t.Optional[http_messages.ProducedResponse]: raise NotImplementedError( 'This kind of policy does not consume responses.' ) diff --git a/src/hydrilla/proxy/policies/misc.py b/src/hydrilla/proxy/policies/misc.py index 6d1e032..71692b3 100644 --- a/src/hydrilla/proxy/policies/misc.py +++ b/src/hydrilla/proxy/policies/misc.py @@ -71,8 +71,11 @@ class DoNothingPolicy(base.Policy): def consume_request(self, request_info: http_messages.RequestInfo) -> None: return None - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> None: + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> None: return None builtin: bool = True diff --git a/src/hydrilla/proxy/policies/payload.py b/src/hydrilla/proxy/policies/payload.py index d3e8e30..5b71af7 100644 --- a/src/hydrilla/proxy/policies/payload.py +++ b/src/hydrilla/proxy/policies/payload.py @@ -257,9 +257,11 @@ class PayloadInjectPolicy(PayloadAwarePolicy): return dc.replace(new_response, body=new_data) - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> http_messages.ProducedResponse: - """....""" + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> http_messages.ProducedResponse: try: return self._consume_response_unsafe(response_info) except Exception as e: @@ -286,15 +288,18 @@ class AutoPayloadInjectPolicy(PayloadInjectPolicy): """....""" priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> http_messages.ProducedResponse: + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> http_messages.ProducedResponse: try: if self.payload_data.ref.has_problems(): raise _PayloadHasProblemsError() self.payload_data.ref.ensure_items_installed() - return super().consume_response(response_info) + return super().consume_response(request_info, response_info) except (state.RepoCommunicationError, state.FileInstallationError, _PayloadHasProblemsError) as ex: extra_params: dict[str, str] = { diff --git a/src/hydrilla/proxy/policies/payload_resource.py b/src/hydrilla/proxy/policies/payload_resource.py index d53f1f7..ae6a490 100644 --- a/src/hydrilla/proxy/policies/payload_resource.py +++ b/src/hydrilla/proxy/policies/payload_resource.py @@ -242,8 +242,12 @@ class PayloadResourcePolicy(PayloadAwarePolicy): segments_to_drop = len(self.payload_data.pattern_path_segments) + 1 return request_url.path_segments[segments_to_drop:] - def should_process_response(self, request_url: ParsedUrl) -> bool: - return self.extract_resource_path(request_url) \ + def should_process_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.BodylessResponseInfo + ) -> bool: + return self.extract_resource_path(request_info.url) \ == ('api', 'unrestricted_http') def _make_file_resource_response(self, path: tuple[str, ...]) \ @@ -338,8 +342,11 @@ class PayloadResourcePolicy(PayloadAwarePolicy): else: return resource_blocked_response - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> http_messages.ProducedResponse: + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> http_messages.ProducedResponse: """ This method shall only be called for responses to unrestricted HTTP API requests. Its purpose is to sanitize response headers and smuggle their @@ -351,7 +358,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy): if (300 <= response_info.status_code < 400): location = response_info.headers.get('location') if location is not None: - orig_params = parse_qs(response_info.orig_url.query) + orig_params = parse_qs(request_info.url.query) orig_extra_headers_str, = orig_params['extra_headers'] new_query = urlencode({ @@ -359,10 +366,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy): 'extra_headers': orig_extra_headers_str }) - new_url = urljoin( - response_info.orig_url.orig_url, - '?' + new_query - ) + new_url = urljoin(request_info.url.orig_url, '?' + new_query) extra_headers.append(('location', new_url)) diff --git a/src/hydrilla/proxy/policies/rule.py b/src/hydrilla/proxy/policies/rule.py index b742a64..2e9443e 100644 --- a/src/hydrilla/proxy/policies/rule.py +++ b/src/hydrilla/proxy/policies/rule.py @@ -79,9 +79,11 @@ class BlockPolicy(base.Policy): yield b'Content-Security-Policy', extra_csp.encode() - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> http_messages.ProducedResponse: - """....""" + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> http_messages.ProducedResponse: new_response = response_info.make_produced_response() new_headers = self._modify_headers(response_info) -- cgit v1.2.3