From 4c4bd153be2de7febeb0930651a147447296f081 Mon Sep 17 00:00:00 2001 From: Joe Kaufeld Date: Sat, 31 Aug 2024 22:40:54 -0400 Subject: [PATCH] :sparkles: make headers case-insensitive --- spiderweb/main.py | 22 +++++++++++++++++----- spiderweb/request.py | 18 ++++++++++-------- spiderweb/response.py | 23 +++++++++++++---------- spiderweb/tests/test_responses.py | 6 +++--- spiderweb/utils.py | 15 +++++++++++++++ 5 files changed, 58 insertions(+), 26 deletions(-) diff --git a/spiderweb/main.py b/spiderweb/main.py index 17ac9ca..ae1ad69 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -42,6 +42,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi *, addr: str = None, port: int = None, + allowed_hosts=None, + cors_allowed_origins=None, + cors_allow_all_origins=False, db: Optional[Database] = None, templates_dirs: list[str] = None, middleware: list[str] = None, @@ -49,7 +52,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi staticfiles_dirs: list[str] = None, routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None, error_routes: dict[int, Callable] = None, - allowed_hosts=None, secret_key: str = None, session_max_age=60 * 60 * 24 * 14, # 2 weeks session_cookie_name="swsession", @@ -75,6 +77,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi self.secret_key = secret_key if secret_key else self.generate_key() self.allowed_hosts = allowed_hosts or ["*"] + self.cors_allowed_origins = cors_allowed_origins or [] + self.cors_allow_all_origins = cors_allow_all_origins + self.extra_data = kwargs # session middleware @@ -136,12 +141,19 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi try: status = get_http_status_by_code(resp.status_code) cookies = [] - if "Set-Cookie" in resp.headers: - cookies = resp.headers["Set-Cookie"] - del resp.headers["Set-Cookie"] + varies = [] + if "set-cookie" in resp.headers: + cookies = resp.headers["set-cookie"] + del resp.headers["set-cookie"] + if "vary" in resp.headers: + varies = resp.headers["vary"] + del resp.headers["vary"] headers = list(resp.headers.items()) for c in cookies: headers.append(("Set-Cookie", c)) + for v in varies: + headers.append(("Vary", v)) + start_response(status, headers) @@ -182,7 +194,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi ): try: status = get_http_status_by_code(500) - headers = [("Content-type", "text/plain; charset=utf-8")] + headers = [("Content-Type", "text/plain; charset=utf-8")] start_response(status, headers) diff --git a/spiderweb/request.py b/spiderweb/request.py index a617e99..90e0160 100644 --- a/spiderweb/request.py +++ b/spiderweb/request.py @@ -2,7 +2,7 @@ import json from urllib.parse import urlparse from spiderweb.constants import DEFAULT_ENCODING -from spiderweb.utils import get_client_address +from spiderweb.utils import get_client_address, Headers class Request: @@ -38,20 +38,22 @@ class Request: self.populate_meta() self.populate_cookies() - content_length = int(self.headers.get("CONTENT_LENGTH") or 0) + content_length = int(self.headers.get("content_length") or 0) if content_length: self.content = ( self.environ["wsgi.input"].read(content_length).decode(DEFAULT_ENCODING) ) def populate_headers(self) -> None: - self.headers |= { - "CONTENT_TYPE": self.environ.get("CONTENT_TYPE"), - "CONTENT_LENGTH": self.environ.get("CONTENT_LENGTH"), + data = self.headers + data |= { + "content_type": self.environ.get("CONTENT_TYPE"), + "content_length": self.environ.get("CONTENT_LENGTH"), } for k, v in self.environ.items(): if k.startswith("HTTP_"): - self.headers[k] = v + data[k] = v + self.headers = Headers(**{k.lower(): v for k, v in data.items()}) def populate_meta(self) -> None: # all caps fields are from WSGI, lowercase names @@ -89,6 +91,6 @@ class Request: def is_form_request(self) -> bool: return ( - "CONTENT_TYPE" in self.headers - and self.headers["CONTENT_TYPE"] == "application/x-www-form-urlencoded" + "content_type" in self.headers + and self.headers["content_type"] == "application/x-www-form-urlencoded" ) diff --git a/spiderweb/response.py b/spiderweb/response.py index 0de90ab..0e79648 100644 --- a/spiderweb/response.py +++ b/spiderweb/response.py @@ -10,6 +10,8 @@ from wsgiref.util import FileWrapper from spiderweb.constants import REGEX_COOKIE_NAME from spiderweb.exceptions import GeneralException from spiderweb.request import Request +from spiderweb.utils import Headers + mimetypes.init() @@ -28,10 +30,11 @@ class HttpResponse: self.context = context if context else {} self.status_code = status_code self.headers = headers if headers else {} - if not self.headers.get("Content-Type"): - self.headers["Content-Type"] = "text/html; charset=utf-8" - self.headers["Server"] = "Spiderweb" - self.headers["Date"] = datetime.datetime.now(tz=datetime.UTC).strftime( + self.headers = Headers(**{k.lower(): v for k, v in self.headers.items()}) + if not self.headers.get("content-type"): + self.headers["content-type"] = "text/html; charset=utf-8" + self.headers["server"] = "Spiderweb" + self.headers["date"] = datetime.datetime.now(tz=datetime.UTC).strftime( "%a, %d %b %Y %H:%M:%S GMT" ) @@ -89,10 +92,10 @@ class HttpResponse: attrs = [urllib.parse.quote_plus(value)] + attrs cookie = f"{name}={'; '.join(attrs)}" - if "Set-Cookie" in self.headers: - self.headers["Set-Cookie"].append(cookie) + if "set-cookie" in self.headers: + self.headers["set-cookie"].append(cookie) else: - self.headers["Set-Cookie"] = [cookie] + self.headers["set-cookie"] = [cookie] def render(self) -> str: return str(self.body) @@ -103,7 +106,7 @@ class FileResponse(HttpResponse): super().__init__(*args, **kwargs) self.filename = filename self.content_type = mimetypes.guess_type(self.filename)[0] - self.headers["Content-Type"] = self.content_type + self.headers["content-type"] = self.content_type def render(self) -> list[bytes]: with open(self.filename, "rb") as f: @@ -114,7 +117,7 @@ class FileResponse(HttpResponse): class JsonResponse(HttpResponse): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.headers["Content-Type"] = "application/json" + self.headers["content-type"] = "application/json" def render(self) -> str: return json.dumps(self.data) @@ -124,7 +127,7 @@ class RedirectResponse(HttpResponse): def __init__(self, location: str, *args, **kwargs): super().__init__(*args, **kwargs) self.status_code = 302 - self.headers["Location"] = location + self.headers["location"] = location class TemplateResponse(HttpResponse): diff --git a/spiderweb/tests/test_responses.py b/spiderweb/tests/test_responses.py index 63c0011..df989bf 100644 --- a/spiderweb/tests/test_responses.py +++ b/spiderweb/tests/test_responses.py @@ -71,7 +71,7 @@ def test_redirect_response(): return RedirectResponse(location="/redirected") assert app(environ, start_response) == [b"None"] - assert start_response.get_headers()["Location"] == "/redirected" + assert start_response.get_headers()["location"] == "/redirected" def test_add_route_at_server_start(): @@ -91,7 +91,7 @@ def test_add_route_at_server_start(): ) assert app(environ, start_response) == [b"None"] - assert start_response.get_headers()["Location"] == "/redirected" + assert start_response.get_headers()["location"] == "/redirected" def test_redirect_on_append_slash(): @@ -104,7 +104,7 @@ def test_redirect_on_append_slash(): environ["PATH_INFO"] = f"/hello" assert app(environ, start_response) == [b"None"] - assert start_response.get_headers()["Location"] == "/hello/" + assert start_response.get_headers()["location"] == "/hello/" @given(st.text()) diff --git a/spiderweb/utils.py b/spiderweb/utils.py index 42baf35..d24ef04 100644 --- a/spiderweb/utils.py +++ b/spiderweb/utils.py @@ -63,3 +63,18 @@ def is_jsonable(data: str) -> bool: return True except (TypeError, OverflowError): return False + + +class Headers(dict): + # special dict that forces lowercase for all keys + def __getitem__(self, key): + return super().__getitem__(key.lower()) + + def __setitem__(self, key, value): + return super().__setitem__(key.lower(), value) + + def get(self, key, default=None): + return super().get(key.lower(), default) + + def setdefault(self, key, default = None): + return super().setdefault(key.lower(), default) \ No newline at end of file