diff options
Diffstat (limited to 'src/hydrilla/proxy/policies/payload_resource.py')
-rw-r--r-- | src/hydrilla/proxy/policies/payload_resource.py | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/src/hydrilla/proxy/policies/payload_resource.py b/src/hydrilla/proxy/policies/payload_resource.py index d53f1f7..ae6a490 100644 --- a/src/hydrilla/proxy/policies/payload_resource.py +++ b/src/hydrilla/proxy/policies/payload_resource.py @@ -242,8 +242,12 @@ class PayloadResourcePolicy(PayloadAwarePolicy): segments_to_drop = len(self.payload_data.pattern_path_segments) + 1 return request_url.path_segments[segments_to_drop:] - def should_process_response(self, request_url: ParsedUrl) -> bool: - return self.extract_resource_path(request_url) \ + def should_process_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.BodylessResponseInfo + ) -> bool: + return self.extract_resource_path(request_info.url) \ == ('api', 'unrestricted_http') def _make_file_resource_response(self, path: tuple[str, ...]) \ @@ -338,8 +342,11 @@ class PayloadResourcePolicy(PayloadAwarePolicy): else: return resource_blocked_response - def consume_response(self, response_info: http_messages.ResponseInfo) \ - -> http_messages.ProducedResponse: + def consume_response( + self, + request_info: http_messages.RequestInfo, + response_info: http_messages.ResponseInfo + ) -> http_messages.ProducedResponse: """ This method shall only be called for responses to unrestricted HTTP API requests. Its purpose is to sanitize response headers and smuggle their @@ -351,7 +358,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy): if (300 <= response_info.status_code < 400): location = response_info.headers.get('location') if location is not None: - orig_params = parse_qs(response_info.orig_url.query) + orig_params = parse_qs(request_info.url.query) orig_extra_headers_str, = orig_params['extra_headers'] new_query = urlencode({ @@ -359,10 +366,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy): 'extra_headers': orig_extra_headers_str }) - new_url = urljoin( - response_info.orig_url.orig_url, - '?' + new_query - ) + new_url = urljoin(request_info.url.orig_url, '?' + new_query) extra_headers.append(('location', new_url)) |