make headers case-insensitive

This commit is contained in:
Joe Kaufeld 2024-08-31 22:40:54 -04:00
parent d98d61e4b1
commit 4c4bd153be
5 changed files with 58 additions and 26 deletions

View File

@ -42,6 +42,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
*, *,
addr: str = None, addr: str = None,
port: int = None, port: int = None,
allowed_hosts=None,
cors_allowed_origins=None,
cors_allow_all_origins=False,
db: Optional[Database] = None, db: Optional[Database] = None,
templates_dirs: list[str] = None, templates_dirs: list[str] = None,
middleware: list[str] = None, middleware: list[str] = None,
@ -49,7 +52,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
staticfiles_dirs: list[str] = None, staticfiles_dirs: list[str] = None,
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None, routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
error_routes: dict[int, Callable] = None, error_routes: dict[int, Callable] = None,
allowed_hosts=None,
secret_key: str = None, secret_key: str = None,
session_max_age=60 * 60 * 24 * 14, # 2 weeks session_max_age=60 * 60 * 24 * 14, # 2 weeks
session_cookie_name="swsession", session_cookie_name="swsession",
@ -75,6 +77,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.secret_key = secret_key if secret_key else self.generate_key() self.secret_key = secret_key if secret_key else self.generate_key()
self.allowed_hosts = allowed_hosts or ["*"] self.allowed_hosts = allowed_hosts or ["*"]
self.cors_allowed_origins = cors_allowed_origins or []
self.cors_allow_all_origins = cors_allow_all_origins
self.extra_data = kwargs self.extra_data = kwargs
# session middleware # session middleware
@ -136,12 +141,19 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
try: try:
status = get_http_status_by_code(resp.status_code) status = get_http_status_by_code(resp.status_code)
cookies = [] cookies = []
if "Set-Cookie" in resp.headers: varies = []
cookies = resp.headers["Set-Cookie"] if "set-cookie" in resp.headers:
del resp.headers["Set-Cookie"] cookies = resp.headers["set-cookie"]
del resp.headers["set-cookie"]
if "vary" in resp.headers:
varies = resp.headers["vary"]
del resp.headers["vary"]
headers = list(resp.headers.items()) headers = list(resp.headers.items())
for c in cookies: for c in cookies:
headers.append(("Set-Cookie", c)) headers.append(("Set-Cookie", c))
for v in varies:
headers.append(("Vary", v))
start_response(status, headers) start_response(status, headers)
@ -182,7 +194,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
): ):
try: try:
status = get_http_status_by_code(500) status = get_http_status_by_code(500)
headers = [("Content-type", "text/plain; charset=utf-8")] headers = [("Content-Type", "text/plain; charset=utf-8")]
start_response(status, headers) start_response(status, headers)

View File

@ -2,7 +2,7 @@ import json
from urllib.parse import urlparse from urllib.parse import urlparse
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.utils import get_client_address from spiderweb.utils import get_client_address, Headers
class Request: class Request:
@ -38,20 +38,22 @@ class Request:
self.populate_meta() self.populate_meta()
self.populate_cookies() self.populate_cookies()
content_length = int(self.headers.get("CONTENT_LENGTH") or 0) content_length = int(self.headers.get("content_length") or 0)
if content_length: if content_length:
self.content = ( self.content = (
self.environ["wsgi.input"].read(content_length).decode(DEFAULT_ENCODING) self.environ["wsgi.input"].read(content_length).decode(DEFAULT_ENCODING)
) )
def populate_headers(self) -> None: def populate_headers(self) -> None:
self.headers |= { data = self.headers
"CONTENT_TYPE": self.environ.get("CONTENT_TYPE"), data |= {
"CONTENT_LENGTH": self.environ.get("CONTENT_LENGTH"), "content_type": self.environ.get("CONTENT_TYPE"),
"content_length": self.environ.get("CONTENT_LENGTH"),
} }
for k, v in self.environ.items(): for k, v in self.environ.items():
if k.startswith("HTTP_"): if k.startswith("HTTP_"):
self.headers[k] = v data[k] = v
self.headers = Headers(**{k.lower(): v for k, v in data.items()})
def populate_meta(self) -> None: def populate_meta(self) -> None:
# all caps fields are from WSGI, lowercase names # all caps fields are from WSGI, lowercase names
@ -89,6 +91,6 @@ class Request:
def is_form_request(self) -> bool: def is_form_request(self) -> bool:
return ( return (
"CONTENT_TYPE" in self.headers "content_type" in self.headers
and self.headers["CONTENT_TYPE"] == "application/x-www-form-urlencoded" and self.headers["content_type"] == "application/x-www-form-urlencoded"
) )

View File

@ -10,6 +10,8 @@ from wsgiref.util import FileWrapper
from spiderweb.constants import REGEX_COOKIE_NAME from spiderweb.constants import REGEX_COOKIE_NAME
from spiderweb.exceptions import GeneralException from spiderweb.exceptions import GeneralException
from spiderweb.request import Request from spiderweb.request import Request
from spiderweb.utils import Headers
mimetypes.init() mimetypes.init()
@ -28,10 +30,11 @@ class HttpResponse:
self.context = context if context else {} self.context = context if context else {}
self.status_code = status_code self.status_code = status_code
self.headers = headers if headers else {} self.headers = headers if headers else {}
if not self.headers.get("Content-Type"): self.headers = Headers(**{k.lower(): v for k, v in self.headers.items()})
self.headers["Content-Type"] = "text/html; charset=utf-8" if not self.headers.get("content-type"):
self.headers["Server"] = "Spiderweb" self.headers["content-type"] = "text/html; charset=utf-8"
self.headers["Date"] = datetime.datetime.now(tz=datetime.UTC).strftime( self.headers["server"] = "Spiderweb"
self.headers["date"] = datetime.datetime.now(tz=datetime.UTC).strftime(
"%a, %d %b %Y %H:%M:%S GMT" "%a, %d %b %Y %H:%M:%S GMT"
) )
@ -89,10 +92,10 @@ class HttpResponse:
attrs = [urllib.parse.quote_plus(value)] + attrs attrs = [urllib.parse.quote_plus(value)] + attrs
cookie = f"{name}={'; '.join(attrs)}" cookie = f"{name}={'; '.join(attrs)}"
if "Set-Cookie" in self.headers: if "set-cookie" in self.headers:
self.headers["Set-Cookie"].append(cookie) self.headers["set-cookie"].append(cookie)
else: else:
self.headers["Set-Cookie"] = [cookie] self.headers["set-cookie"] = [cookie]
def render(self) -> str: def render(self) -> str:
return str(self.body) return str(self.body)
@ -103,7 +106,7 @@ class FileResponse(HttpResponse):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.filename = filename self.filename = filename
self.content_type = mimetypes.guess_type(self.filename)[0] self.content_type = mimetypes.guess_type(self.filename)[0]
self.headers["Content-Type"] = self.content_type self.headers["content-type"] = self.content_type
def render(self) -> list[bytes]: def render(self) -> list[bytes]:
with open(self.filename, "rb") as f: with open(self.filename, "rb") as f:
@ -114,7 +117,7 @@ class FileResponse(HttpResponse):
class JsonResponse(HttpResponse): class JsonResponse(HttpResponse):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.headers["Content-Type"] = "application/json" self.headers["content-type"] = "application/json"
def render(self) -> str: def render(self) -> str:
return json.dumps(self.data) return json.dumps(self.data)
@ -124,7 +127,7 @@ class RedirectResponse(HttpResponse):
def __init__(self, location: str, *args, **kwargs): def __init__(self, location: str, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.status_code = 302 self.status_code = 302
self.headers["Location"] = location self.headers["location"] = location
class TemplateResponse(HttpResponse): class TemplateResponse(HttpResponse):

View File

@ -71,7 +71,7 @@ def test_redirect_response():
return RedirectResponse(location="/redirected") return RedirectResponse(location="/redirected")
assert app(environ, start_response) == [b"None"] assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/redirected" assert start_response.get_headers()["location"] == "/redirected"
def test_add_route_at_server_start(): def test_add_route_at_server_start():
@ -91,7 +91,7 @@ def test_add_route_at_server_start():
) )
assert app(environ, start_response) == [b"None"] assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/redirected" assert start_response.get_headers()["location"] == "/redirected"
def test_redirect_on_append_slash(): def test_redirect_on_append_slash():
@ -104,7 +104,7 @@ def test_redirect_on_append_slash():
environ["PATH_INFO"] = f"/hello" environ["PATH_INFO"] = f"/hello"
assert app(environ, start_response) == [b"None"] assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/hello/" assert start_response.get_headers()["location"] == "/hello/"
@given(st.text()) @given(st.text())

View File

@ -63,3 +63,18 @@ def is_jsonable(data: str) -> bool:
return True return True
except (TypeError, OverflowError): except (TypeError, OverflowError):
return False return False
class Headers(dict):
# special dict that forces lowercase for all keys
def __getitem__(self, key):
return super().__getitem__(key.lower())
def __setitem__(self, key, value):
return super().__setitem__(key.lower(), value)
def get(self, key, default=None):
return super().get(key.lower(), default)
def setdefault(self, key, default = None):
return super().setdefault(key.lower(), default)