aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/csp.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/hydrilla/proxy/csp.py')
-rw-r--r--src/hydrilla/proxy/csp.py146
1 files changed, 109 insertions, 37 deletions
diff --git a/src/hydrilla/proxy/csp.py b/src/hydrilla/proxy/csp.py
index 8eb914f..df2f65b 100644
--- a/src/hydrilla/proxy/csp.py
+++ b/src/hydrilla/proxy/csp.py
@@ -38,33 +38,56 @@ from immutables import Map, MapMutation
from . import http_messages
-header_names_and_dispositions = (
- ('content-security-policy', 'enforce'),
- ('content-security-policy-report-only', 'report'),
- ('x-content-security-policy', 'enforce'),
- ('x-content-security-policy', 'report'),
- ('x-webkit-csp', 'enforce'),
- ('x-webkit-csp', 'report')
+enforce_header_names = (
+ 'content-security-policy',
+ 'x-content-security-policy',
+ 'x-webkit-csp'
)
-enforce_header_names_set = {
- name for name, disposition in header_names_and_dispositions
- if disposition == 'enforce'
-}
+header_names = (*enforce_header_names, 'content-security-policy-report-only')
@dc.dataclass
class ContentSecurityPolicy:
directives: Map[str, t.Sequence[str]]
- header_name: str
- disposition: str
+ header_name: str = 'Content-Security-Policy'
+ disposition: str = 'enforce'
- def serialize(self) -> str:
- """...."""
+ def remove(self, directives: t.Sequence[str]) -> 'ContentSecurityPolicy':
+ mutation = self.directives.mutate()
+
+ for name in directives:
+ mutation.pop(name, None)
+
+ return dc.replace(self, directives = mutation.finish())
+
+ def extend(self, directives: t.Mapping[str, t.Sequence[str]]) \
+ -> 'ContentSecurityPolicy':
+ mutation = self.directives.mutate()
+
+ for name, extras in directives.items():
+ if name in mutation:
+ mutation[name] = (*mutation[name], *extras)
+
+ return dc.replace(self, directives = mutation.finish())
+
+ def serialize(self) -> tuple[str, str]:
+ """
+ Produces (name, value) pair suitable for use as an HTTP header.
+
+ If a deserialized policy is being reserialized, the resulting value is
+ not guaranteed to be the same as the original one. It shall be merely
+ semantically equivalent.
+ """
serialized_directives = []
- for name, value_list in self.directives.items():
- serialized_directives.append(f'{name} {" ".join(value_list)}')
+ for name, value_seq in self.directives.items():
+ if all(val == "'none'" for val in value_seq):
+ value_seq = ["'none'"]
+ else:
+ value_seq = [val for val in value_seq if val != "'none'"]
+
+ serialized_directives.append(f'{name} {" ".join(value_seq)}')
- return ';'.join(serialized_directives)
+ return (self.header_name, ';'.join(serialized_directives))
@staticmethod
def deserialize(
@@ -72,7 +95,13 @@ class ContentSecurityPolicy:
header_name: str,
disposition: str = 'enforce'
) -> 'ContentSecurityPolicy':
- """...."""
+ """
+ Parses the policy as required by W3C Working Draft.
+
+ Extra whitespace information, invalid/empty directives and the order of
+ directives are not preserved, only the semantically-relevant information
+ is.
+ """
# For more info, see:
# https://www.w3.org/TR/CSP3/#parse-serialized-policy
empty_directives: Map[str, t.Sequence[str]] = Map()
@@ -104,21 +133,64 @@ class ContentSecurityPolicy:
disposition = disposition
)
-def extract(headers: http_messages.IHeaders) \
- -> tuple[ContentSecurityPolicy, ...]:
- """...."""
- csp_policies = []
-
- for header_name, disposition in header_names_and_dispositions:
- for serialized_list in headers.get_all(header_name):
- for serialized in serialized_list.split(','):
- policy = ContentSecurityPolicy.deserialize(
- serialized,
- header_name,
- disposition
- )
-
- if policy.directives != Map():
- csp_policies.append(policy)
-
- return tuple(csp_policies)
+# def extract(headers: http_messages.IHeaders) \
+# -> tuple[ContentSecurityPolicy, ...]:
+# """...."""
+# csp_policies = []
+
+# for header_name, disposition in header_names_and_dispositions:
+# for serialized_list in headers.get_all(header_name):
+# for serialized in serialized_list.split(','):
+# policy = ContentSecurityPolicy.deserialize(
+# serialized,
+# header_name,
+# disposition
+# )
+
+# if policy.directives != Map():
+# csp_policies.append(policy)
+
+# return tuple(csp_policies)
+
+def modify(
+ headers: http_messages.IHeaders,
+ clear: t.Union[t.Sequence[str], t.Literal['all']] = (),
+ extend: t.Mapping[str, t.Sequence[str]] = Map(),
+ add: t.Mapping[str, t.Sequence[str]] = Map(),
+) -> http_messages.IHeaders:
+ """
+ This function modifies the CSP Headers. The following actions are performed
+ *in order*
+ 1. report-only CSP Headers are removed,
+ 2. directives with names in `clear` are removed,
+ 3. directives that could cause CSP reports to be sent are removed,
+ 4. directives from `add` are added in a separate Content-Security-Policy,
+ header.
+ 5. directives from `extend` are merged into the existing directives,
+ effectively loosening them,
+
+ No measures are yet implemented to prevent fingerprinting when serving HTTP
+ responses with headers modified by this function. Please use wisely, you
+ have been warned.
+ """
+ headers_list = [
+ (key, val)
+ for key, val in headers.items()
+ if key.lower() not in header_names
+ ]
+
+ if clear != 'all':
+ for name in header_names:
+ for serialized_list in headers.get_all(name):
+ for serialized in serialized_list.split(','):
+ policy = ContentSecurityPolicy.deserialize(serialized, name)
+ policy = policy.remove((*clear, 'report-to', 'report-uri'))
+ policy = policy.extend(extend)
+ if policy.directives != Map():
+ headers_list.append(policy.serialize())
+
+ if add != Map():
+ csp_to_add = ContentSecurityPolicy(Map(add)).extend(extend)
+ headers_list.append(csp_to_add.serialize())
+
+ return http_messages.make_headers(headers_list)