diff options
Diffstat (limited to 'src/hydrilla/proxy/policies')
-rw-r--r-- | src/hydrilla/proxy/policies/base.py | 18 | ||||
-rw-r--r-- | src/hydrilla/proxy/policies/misc.py | 7 | ||||
-rw-r--r-- | src/hydrilla/proxy/policies/payload.py | 17 | ||||
-rw-r--r-- | src/hydrilla/proxy/policies/payload_resource.py | 22 | ||||
-rw-r--r-- | src/hydrilla/proxy/policies/rule.py | 8 |
5 files changed, 48 insertions, 24 deletions
diff --git a/src/hydrilla/proxy/policies/base.py b/src/hydrilla/proxy/policies/base.py index b3d3172..fcdbf9d 100644 --- a/src/hydrilla/proxy/policies/base.py +++ b/src/hydrilla/proxy/policies/base.py @@ -61,10 +61,17 @@ class Policy(ABC): priority: t.ClassVar[PolicyPriority] - def should_process_request(self, parsed_url: ParsedUrl) -> bool: + def should_process_request( + self, + request_info: http_messages.BodylessRequestInfo + ) -> bool: return self._process_request - def should_process_response(self, parsed_url: ParsedUrl) -> bool: + def should_process_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.BodylessResponseInfo + ) -> bool: return self._process_response def consume_request(self, request_info: http_messages.RequestInfo) \ @@ -73,8 +80,11 @@ class Policy(ABC): 'This kind of policy does not consume requests.' ) - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> t.Optional[http_messages.ProducedResponse]: + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> t.Optional[http_messages.ProducedResponse]: raise NotImplementedError( 'This kind of policy does not consume responses.' ) diff --git a/src/hydrilla/proxy/policies/misc.py b/src/hydrilla/proxy/policies/misc.py index 6d1e032..71692b3 100644 --- a/src/hydrilla/proxy/policies/misc.py +++ b/src/hydrilla/proxy/policies/misc.py @@ -71,8 +71,11 @@ class DoNothingPolicy(base.Policy): def consume_request(self, request_info: http_messages.RequestInfo) -> None: return None - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> None: + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> None: return None builtin: bool = True diff --git a/src/hydrilla/proxy/policies/payload.py b/src/hydrilla/proxy/policies/payload.py index d3e8e30..5b71af7 100644 --- a/src/hydrilla/proxy/policies/payload.py +++ b/src/hydrilla/proxy/policies/payload.py @@ -257,9 +257,11 @@ class PayloadInjectPolicy(PayloadAwarePolicy): return dc.replace(new_response, body=new_data) - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> http_messages.ProducedResponse: - """....""" + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> http_messages.ProducedResponse: try: return self._consume_response_unsafe(response_info) except Exception as e: @@ -286,15 +288,18 @@ class AutoPayloadInjectPolicy(PayloadInjectPolicy): """....""" priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> http_messages.ProducedResponse: + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> http_messages.ProducedResponse: try: if self.payload_data.ref.has_problems(): raise _PayloadHasProblemsError() self.payload_data.ref.ensure_items_installed() - return super().consume_response(response_info) + return super().consume_response(request_info, response_info) except (state.RepoCommunicationError, state.FileInstallationError, _PayloadHasProblemsError) as ex: extra_params: dict[str, str] = { diff --git a/src/hydrilla/proxy/policies/payload_resource.py b/src/hydrilla/proxy/policies/payload_resource.py index d53f1f7..ae6a490 100644 --- a/src/hydrilla/proxy/policies/payload_resource.py +++ b/src/hydrilla/proxy/policies/payload_resource.py @@ -242,8 +242,12 @@ class PayloadResourcePolicy(PayloadAwarePolicy): segments_to_drop = len(self.payload_data.pattern_path_segments) + 1 return request_url.path_segments[segments_to_drop:] - def should_process_response(self, request_url: ParsedUrl) -> bool: - return self.extract_resource_path(request_url) \ + def should_process_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.BodylessResponseInfo + ) -> bool: + return self.extract_resource_path(request_info.url) \ == ('api', 'unrestricted_http') def _make_file_resource_response(self, path: tuple[str, ...]) \ @@ -338,8 +342,11 @@ class PayloadResourcePolicy(PayloadAwarePolicy): else: return resource_blocked_response - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> http_messages.ProducedResponse: + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> http_messages.ProducedResponse: """ This method shall only be called for responses to unrestricted HTTP API requests. Its purpose is to sanitize response headers and smuggle their @@ -351,7 +358,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy): if (300 <= response_info.status_code < 400): location = response_info.headers.get('location') if location is not None: - orig_params = parse_qs(response_info.orig_url.query) + orig_params = parse_qs(request_info.url.query) orig_extra_headers_str, = orig_params['extra_headers'] new_query = urlencode({ @@ -359,10 +366,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy): 'extra_headers': orig_extra_headers_str }) - new_url = urljoin( - response_info.orig_url.orig_url, - '?' + new_query - ) + new_url = urljoin(request_info.url.orig_url, '?' + new_query) extra_headers.append(('location', new_url)) diff --git a/src/hydrilla/proxy/policies/rule.py b/src/hydrilla/proxy/policies/rule.py index b742a64..2e9443e 100644 --- a/src/hydrilla/proxy/policies/rule.py +++ b/src/hydrilla/proxy/policies/rule.py @@ -79,9 +79,11 @@ class BlockPolicy(base.Policy): yield b'Content-Security-Policy', extra_csp.encode() - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> http_messages.ProducedResponse: - """....""" + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> http_messages.ProducedResponse: new_response = response_info.make_produced_response() new_headers = self._modify_headers(response_info) |