aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/policies/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/hydrilla/proxy/policies/base.py')
-rw-r--r--src/hydrilla/proxy/policies/base.py114
1 files changed, 106 insertions, 8 deletions
diff --git a/src/hydrilla/proxy/policies/base.py b/src/hydrilla/proxy/policies/base.py
index 8ea792f..7ce8663 100644
--- a/src/hydrilla/proxy/policies/base.py
+++ b/src/hydrilla/proxy/policies/base.py
@@ -40,6 +40,7 @@ from immutables import Map
from ... url_patterns import ParsedUrl
from .. import state
from .. import http_messages
+from .. import csp
class PolicyPriority(int, enum.Enum):
@@ -53,6 +54,15 @@ MessageInfo = t.Union[
http_messages.ResponseInfo
]
+
+UTF8_BOM = b'\xEF\xBB\xBF'
+BOMs = (
+ (UTF8_BOM, 'utf-8'),
+ (b'\xFE\xFF', 'utf-16be'),
+ (b'\xFF\xFE', 'utf-16le')
+)
+
+
class Policy(ABC):
"""...."""
_process_request: t.ClassVar[bool] = False
@@ -70,23 +80,111 @@ class Policy(ABC):
def should_process_response(
self,
request_info: http_messages.RequestInfo,
- response_info: http_messages.BodylessResponseInfo
+ response_info: http_messages.AnyResponseInfo
) -> bool:
return self._process_response
+ def _csp_to_clear(self, http_info: http_messages.FullHTTPInfo) \
+ -> t.Union[t.Sequence[str], t.Literal['all']]:
+ return ()
+
+ def _csp_to_add(self, http_info: http_messages.FullHTTPInfo) \
+ -> t.Mapping[str, t.Sequence[str]]:
+ return Map()
+
+ def _csp_to_extend(self, http_info: http_messages.FullHTTPInfo) \
+ -> t.Mapping[str, t.Sequence[str]]:
+ return Map()
+
+ def _modify_response_headers(self, http_info: http_messages.FullHTTPInfo) \
+ -> http_messages.IHeaders:
+ csp_to_clear = self._csp_to_clear(http_info)
+ csp_to_add = self._csp_to_add(http_info)
+ csp_to_extend = self._csp_to_extend(http_info)
+
+ if len(csp_to_clear) + len(csp_to_extend) + len(csp_to_add) == 0:
+ return http_info.response_info.headers
+
+ return csp.modify(
+ headers = http_info.response_info.headers,
+ clear = csp_to_clear,
+ add = csp_to_add,
+ extend = csp_to_extend
+ )
+
+ def _modify_response_document(
+ self,
+ http_info: http_messages.FullHTTPInfo,
+ encoding: t.Optional[str]
+ ) -> t.Union[str, bytes]:
+ return http_info.response_info.body
+
+ def _modify_response_body(self, http_info: http_messages.FullHTTPInfo) \
+ -> bytes:
+ if not http_messages.is_likely_a_page(
+ request_info = http_info.request_info,
+ response_info = http_info.response_info
+ ):
+ return http_info.response_info.body
+
+ data = http_info.response_info.body
+
+ _, encoding = http_info.response_info.deduce_content_type()
+
+ # A UTF BOM overrides the encoding specified by the header.
+ for bom, encoding_name in BOMs:
+ if data.startswith(bom):
+ encoding = encoding_name
+
+ new_data = self._modify_response_document(http_info, encoding)
+
+ if isinstance(new_data, str):
+ # Prepending a three-byte Byte Order Mark (BOM) will force the
+ # browser to decode this as UTF-8 regardless of the 'Content-Type'
+ # header. See
+ # https://www.w3.org/International/tests/repository/html5/the-input-byte-stream/results-basics#precedence
+ new_data = UTF8_BOM + new_data.encode()
+
+ return new_data
+
def consume_request(self, request_info: http_messages.RequestInfo) \
-> t.Optional[MessageInfo]:
+ # We're not using @abstractmethod because not every Policy needs it and
+ # we don't want to force child classes into implementing dummy methods.
raise NotImplementedError(
'This kind of policy does not consume requests.'
)
- def consume_response(
- self,
- request_info: http_messages.RequestInfo,
- response_info: http_messages.ResponseInfo
- ) -> t.Optional[http_messages.ResponseInfo]:
- raise NotImplementedError(
- 'This kind of policy does not consume responses.'
+ def consume_response(self, http_info: http_messages.FullHTTPInfo) \
+ -> t.Optional[http_messages.ResponseInfo]:
+ try:
+ new_headers = self._modify_response_headers(http_info)
+ new_body = self._modify_response_body(http_info)
+ except Exception as e:
+ # In the future we might want to properly describe possible errors.
+ # For now, we just return the stack trace in the response body.
+ import traceback
+
+ error_info_list = traceback.format_exception(
+ type(e),
+ e,
+ e.__traceback__
+ )
+
+ return http_messages.ResponseInfo.make(
+ status_code = 500,
+ headers = (('Content-Type', 'text/plain; charset=utf-8'),),
+ body = '\n'.join(error_info_list).encode()
+ )
+
+ if (new_headers is http_info.response_info.headers and
+ new_body is http_info.response_info.body):
+ return None
+
+ return dc.replace(
+ http_info.response_info,
+ headers = new_headers,
+ body = new_body
)