aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/policies/payload_resource.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/hydrilla/proxy/policies/payload_resource.py')
-rw-r--r--src/hydrilla/proxy/policies/payload_resource.py32
1 files changed, 14 insertions, 18 deletions
diff --git a/src/hydrilla/proxy/policies/payload_resource.py b/src/hydrilla/proxy/policies/payload_resource.py
index 04a148c..6695ce1 100644
--- a/src/hydrilla/proxy/policies/payload_resource.py
+++ b/src/hydrilla/proxy/policies/payload_resource.py
@@ -245,7 +245,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy):
def should_process_response(
self,
request_info: http_messages.RequestInfo,
- response_info: http_messages.BodylessResponseInfo
+ response_info: http_messages.AnyResponseInfo
) -> bool:
return self.extract_resource_path(request_info.url) \
== ('api', 'unrestricted_http')
@@ -279,7 +279,7 @@ class PayloadResourcePolicy(PayloadAwarePolicy):
with jinja_lock:
template = jinja_env.get_template('page_init_script.js.jinja')
token = self.payload_data.unique_token
- base_url = self.assets_base_url(request_info.url)
+ base_url = self._assets_base_url(request_info.url)
ver_str = json.dumps(haketilo_version)
js = template.render(
unique_token_encoded = encode_string_for_js(token),
@@ -338,23 +338,22 @@ class PayloadResourcePolicy(PayloadAwarePolicy):
else:
return resource_blocked_response
- def consume_response(
- self,
- request_info: http_messages.RequestInfo,
- response_info: http_messages.ResponseInfo
- ) -> http_messages.ResponseInfo:
+ def consume_response(self, http_info: http_messages.FullHTTPInfo) \
+ -> http_messages.ResponseInfo:
"""
This method shall only be called for responses to unrestricted HTTP API
requests. Its purpose is to sanitize response headers and smuggle their
original data using an additional header.
"""
- serialized = json.dumps([*response_info.headers.items()])
+ serialized = json.dumps([*http_info.response_info.headers.items()])
extra_headers = [('X-Haketilo-True-Headers', quote(serialized)),]
- if (300 <= response_info.status_code < 400):
- location = response_info.headers.get('location')
+ # Greetings, adventurous code dweller! It's amazing you made it that
+ # deep. I hope you're having a good day. If not, read Isaiah 49:15 :)
+ if (300 <= http_info.response_info.status_code < 400):
+ location = http_info.response_info.headers.get('location')
if location is not None:
- orig_params = parse_qs(request_info.url.query)
+ orig_params = parse_qs(http_info.request_info.url.query)
orig_extra_headers_str, = orig_params['extra_headers']
new_query = urlencode({
@@ -362,20 +361,17 @@ class PayloadResourcePolicy(PayloadAwarePolicy):
'extra_headers': orig_extra_headers_str
})
- new_url = urljoin(request_info.url.orig_url, '?' + new_query)
+ orig_url = http_info.request_info.url.orig_url
+ new_url = urljoin(orig_url, '?' + new_query)
extra_headers.append(('location', new_url))
merged_headers = merge_response_headers(
- native_headers = response_info.headers,
+ native_headers = http_info.response_info.headers,
extra_headers = extra_headers
)
- return http_messages.ResponseInfo.make(
- status_code = response_info.status_code,
- headers = merged_headers,
- body = response_info.body,
- )
+ return dc.replace(http_info.response_info, headers=merged_headers)
resource_blocked_response = http_messages.ResponseInfo.make(