diff options
Diffstat (limited to 'src/hydrilla/proxy/policies/base.py')
-rw-r--r-- | src/hydrilla/proxy/policies/base.py | 114 |
1 files changed, 106 insertions, 8 deletions
diff --git a/src/hydrilla/proxy/policies/base.py b/src/hydrilla/proxy/policies/base.py index 8ea792f..7ce8663 100644 --- a/src/hydrilla/proxy/policies/base.py +++ b/src/hydrilla/proxy/policies/base.py @@ -40,6 +40,7 @@ from immutables import Map from ... url_patterns import ParsedUrl from .. import state from .. import http_messages +from .. import csp class PolicyPriority(int, enum.Enum): @@ -53,6 +54,15 @@ MessageInfo = t.Union[ http_messages.ResponseInfo ] + +UTF8_BOM = b'\xEF\xBB\xBF' +BOMs = ( + (UTF8_BOM, 'utf-8'), + (b'\xFE\xFF', 'utf-16be'), + (b'\xFF\xFE', 'utf-16le') +) + + class Policy(ABC): """....""" _process_request: t.ClassVar[bool] = False @@ -70,23 +80,111 @@ class Policy(ABC): def should_process_response( self, request_info: http_messages.RequestInfo, - response_info: http_messages.BodylessResponseInfo + response_info: http_messages.AnyResponseInfo ) -> bool: return self._process_response + def _csp_to_clear(self, http_info: http_messages.FullHTTPInfo) \ + -> t.Union[t.Sequence[str], t.Literal['all']]: + return () + + def _csp_to_add(self, http_info: http_messages.FullHTTPInfo) \ + -> t.Mapping[str, t.Sequence[str]]: + return Map() + + def _csp_to_extend(self, http_info: http_messages.FullHTTPInfo) \ + -> t.Mapping[str, t.Sequence[str]]: + return Map() + + def _modify_response_headers(self, http_info: http_messages.FullHTTPInfo) \ + -> http_messages.IHeaders: + csp_to_clear = self._csp_to_clear(http_info) + csp_to_add = self._csp_to_add(http_info) + csp_to_extend = self._csp_to_extend(http_info) + + if len(csp_to_clear) + len(csp_to_extend) + len(csp_to_add) == 0: + return http_info.response_info.headers + + return csp.modify( + headers = http_info.response_info.headers, + clear = csp_to_clear, + add = csp_to_add, + extend = csp_to_extend + ) + + def _modify_response_document( + self, + http_info: http_messages.FullHTTPInfo, + encoding: t.Optional[str] + ) -> t.Union[str, bytes]: + return http_info.response_info.body + + def _modify_response_body(self, http_info: http_messages.FullHTTPInfo) \ + -> bytes: + if not http_messages.is_likely_a_page( + request_info = http_info.request_info, + response_info = http_info.response_info + ): + return http_info.response_info.body + + data = http_info.response_info.body + + _, encoding = http_info.response_info.deduce_content_type() + + # A UTF BOM overrides encoding specified by the header. + for bom, encoding_name in BOMs: + if data.startswith(bom): + encoding = encoding_name + + new_data = self._modify_response_document(http_info, encoding) + + if isinstance(new_data, str): + # Appending a three-byte Byte Order Mark (BOM) will force the + # browser to decode this as UTF-8 regardless of the 'Content-Type' + # header. See + # https://www.w3.org/International/tests/repository/html5/the-input-byte-stream/results-basics#precedence + new_data = UTF8_BOM + new_data.encode() + + return new_data + def consume_request(self, request_info: http_messages.RequestInfo) \ -> t.Optional[MessageInfo]: + # We're not using @abstractmethod because not every Policy needs it and + # we don't want to force child classes into implementing dummy methods. raise NotImplementedError( 'This kind of policy does not consume requests.' ) - def consume_response( - self, - request_info: http_messages.RequestInfo, - response_info: http_messages.ResponseInfo - ) -> t.Optional[http_messages.ResponseInfo]: - raise NotImplementedError( - 'This kind of policy does not consume responses.' + def consume_response(self, http_info: http_messages.FullHTTPInfo) \ + -> t.Optional[http_messages.ResponseInfo]: + try: + new_headers = self._modify_response_headers(http_info) + new_body = self._modify_response_body(http_info) + except Exception as e: + # In the future we might want to actually describe eventual errors. + # For now, we're just printing the stack trace. + import traceback + + error_info_list = traceback.format_exception( + type(e), + e, + e.__traceback__ + ) + + return http_messages.ResponseInfo.make( + status_code = 500, + headers = (('Content-Type', 'text/plain; charset=utf-8'),), + body = '\n'.join(error_info_list).encode() + ) + + if (new_headers is http_info.response_info.headers and + new_body is http_info.response_info.body): + return None + + return dc.replace( + http_info.response_info, + headers = new_headers, + body = new_body ) |