aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/policies/payload.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/hydrilla/proxy/policies/payload.py')
-rw-r--r--src/hydrilla/proxy/policies/payload.py94
1 files changed, 26 insertions, 68 deletions
diff --git a/src/hydrilla/proxy/policies/payload.py b/src/hydrilla/proxy/policies/payload.py
index 5b71af7..b89a1c1 100644
--- a/src/hydrilla/proxy/policies/payload.py
+++ b/src/hydrilla/proxy/policies/payload.py
@@ -31,7 +31,6 @@
import dataclasses as dc
import typing as t
-import re
from urllib.parse import urlencode
@@ -91,45 +90,6 @@ class PayloadAwarePolicyFactory(base.PolicyFactory):
return super().__lt__(other)
-# For details of 'Content-Type' header's structure, see:
-# https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.1
-content_type_reg = re.compile(r'''
-^
-(?P<mime>[\w-]+/[\w-]+)
-\s*
-(?:
- ;
- (?:[^;]*;)* # match possible parameter other than "charset"
-)
-\s*
-charset= # no whitespace allowed in parameter as per RFC
-(?P<encoding>
- [\w-]+
- |
- "[\w-]+" # quotes are optional per RFC
-)
-(?:;[^;]+)* # match possible parameter other than "charset"
-$ # forbid possible dangling characters after closing '"'
-''', re.VERBOSE | re.IGNORECASE)
-
-def deduce_content_type(headers: http_messages.IHeaders) \
- -> tuple[t.Optional[str], t.Optional[str]]:
- """...."""
- content_type = headers.get('content-type')
- if content_type is None:
- return (None, None)
-
- match = content_type_reg.match(content_type)
- if match is None:
- return (None, None)
-
- mime, encoding = match.group('mime'), match.group('encoding')
-
- if encoding is not None:
- encoding = encoding.lower()
-
- return mime, encoding
-
UTF8_BOM = b'\xEF\xBB\xBF'
BOMs = (
(UTF8_BOM, 'utf-8'),
@@ -174,15 +134,17 @@ class PayloadInjectPolicy(PayloadAwarePolicy):
))
def _modify_headers(self, response_info: http_messages.ResponseInfo) \
- -> t.Iterable[tuple[bytes, bytes]]:
- """...."""
- for header_name, header_value in response_info.headers.items():
- if header_name.lower() not in csp.header_names_and_dispositions:
- yield header_name.encode(), header_value.encode()
+ -> http_messages.IHeaders:
+ new_headers = []
+
+ for key, val in response_info.headers.items():
+ if key.lower() not in csp.header_names_and_dispositions:
+ new_headers.append((key, val))
new_csp = self._new_csp(response_info.url)
+ new_headers.append(('Content-Security-Policy', new_csp))
- yield b'Content-Security-Policy', new_csp.encode()
+ return http_messages.make_headers(new_headers)
def _script_urls(self, url: ParsedUrl) -> t.Iterable[str]:
"""...."""
@@ -231,22 +193,18 @@ class PayloadInjectPolicy(PayloadAwarePolicy):
def _consume_response_unsafe(
self,
+ request_info: http_messages.RequestInfo,
response_info: http_messages.ResponseInfo
- ) -> http_messages.ProducedResponse:
- """...."""
- new_response = response_info.make_produced_response()
-
+ ) -> http_messages.ResponseInfo:
new_headers = self._modify_headers(response_info)
+ new_response = dc.replace(response_info, headers=new_headers)
- new_response = dc.replace(new_response, headers=new_headers)
-
- mime, encoding = deduce_content_type(response_info.headers)
- if mime is None or 'html' not in mime.lower():
+ if not http_messages.is_likely_a_page(request_info, response_info):
return new_response
data = response_info.body
- if data is None:
- data = b''
+
+ _, encoding = response_info.deduce_content_type()
# A UTF BOM overrides encoding specified by the header.
for bom, encoding_name in BOMs:
@@ -261,9 +219,9 @@ class PayloadInjectPolicy(PayloadAwarePolicy):
self,
request_info: http_messages.RequestInfo,
response_info: http_messages.ResponseInfo
- ) -> http_messages.ProducedResponse:
+ ) -> http_messages.ResponseInfo:
try:
- return self._consume_response_unsafe(response_info)
+ return self._consume_response_unsafe(request_info, response_info)
except Exception as e:
# TODO: actually describe the errors
import traceback
@@ -274,10 +232,10 @@ class PayloadInjectPolicy(PayloadAwarePolicy):
e.__traceback__
)
- return http_messages.ProducedResponse(
- 500,
- ((b'Content-Type', b'text/plain; charset=utf-8'),),
- '\n'.join(error_info_list).encode()
+ return http_messages.ResponseInfo.make(
+ status_code = 500,
+ headers = (('Content-Type', 'text/plain; charset=utf-8'),),
+ body = '\n'.join(error_info_list).encode()
)
@@ -292,7 +250,7 @@ class AutoPayloadInjectPolicy(PayloadInjectPolicy):
self,
request_info: http_messages.RequestInfo,
response_info: http_messages.ResponseInfo
- ) -> http_messages.ProducedResponse:
+ ) -> http_messages.ResponseInfo:
try:
if self.payload_data.ref.has_problems():
raise _PayloadHasProblemsError()
@@ -317,9 +275,9 @@ class AutoPayloadInjectPolicy(PayloadInjectPolicy):
redirect_url = 'https://hkt.mitm.it/auto_install_error?' + query
msg = 'Error occured when installing payload. Redirecting.'
- return http_messages.ProducedResponse(
+ return http_messages.ResponseInfo.make(
status_code = 303,
- headers = [(b'Location', redirect_url.encode())],
+ headers = [('Location', redirect_url)],
body = msg.encode()
)
@@ -332,7 +290,7 @@ class PayloadSuggestPolicy(PayloadAwarePolicy):
priority: t.ClassVar[base.PolicyPriority] = base.PolicyPriority._ONE
def consume_request(self, request_info: http_messages.RequestInfo) \
- -> http_messages.ProducedResponse:
+ -> http_messages.ResponseInfo:
query = self._payload_details_to_signed_query_string(
_salt = 'package_suggestion',
next_url = request_info.url.orig_url
@@ -341,9 +299,9 @@ class PayloadSuggestPolicy(PayloadAwarePolicy):
redirect_url = 'https://hkt.mitm.it/package_suggestion?' + query
msg = 'A package was found that could be used on this site. Redirecting.'
- return http_messages.ProducedResponse(
+ return http_messages.ResponseInfo.make(
status_code = 303,
- headers = [(b'Location', redirect_url.encode())],
+ headers = [('Location', redirect_url)],
body = msg.encode()
)