aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/policies/payload_resource.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/hydrilla/proxy/policies/payload_resource.py')
-rw-r--r--src/hydrilla/proxy/policies/payload_resource.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/src/hydrilla/proxy/policies/payload_resource.py b/src/hydrilla/proxy/policies/payload_resource.py
index d53f1f7..ae6a490 100644
--- a/src/hydrilla/proxy/policies/payload_resource.py
+++ b/src/hydrilla/proxy/policies/payload_resource.py
@@ -242,8 +242,12 @@ class PayloadResourcePolicy(PayloadAwarePolicy):
segments_to_drop = len(self.payload_data.pattern_path_segments) + 1
return request_url.path_segments[segments_to_drop:]
- def should_process_response(self, request_url: ParsedUrl) -> bool:
- return self.extract_resource_path(request_url) \
+ def should_process_response(
+ self,
+ request_info: http_messages.RequestInfo,
+ response_info: http_messages.BodylessResponseInfo
+ ) -> bool:
+ return self.extract_resource_path(request_info.url) \
== ('api', 'unrestricted_http')
def _make_file_resource_response(self, path: tuple[str, ...]) \
@@ -338,8 +342,11 @@ class PayloadResourcePolicy(PayloadAwarePolicy):
else:
return resource_blocked_response
- def consume_response(self, response_info: http_messages.ResponseInfo) \
- -> http_messages.ProducedResponse:
+ def consume_response(
+ self,
+ request_info: http_messages.RequestInfo,
+ response_info: http_messages.ResponseInfo
+ ) -> http_messages.ProducedResponse:
"""
This method shall only be called for responses to unrestricted HTTP API
requests. Its purpose is to sanitize response headers and smuggle their
@@ -351,7 +358,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy):
if (300 <= response_info.status_code < 400):
location = response_info.headers.get('location')
if location is not None:
- orig_params = parse_qs(response_info.orig_url.query)
+ orig_params = parse_qs(request_info.url.query)
orig_extra_headers_str, = orig_params['extra_headers']
new_query = urlencode({
@@ -359,10 +366,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy):
'extra_headers': orig_extra_headers_str
})
- new_url = urljoin(
- response_info.orig_url.orig_url,
- '?' + new_query
- )
+ new_url = urljoin(request_info.url.orig_url, '?' + new_query)
extra_headers.append(('location', new_url))