✨ make headers case-insensitive
This commit is contained in:
parent
d98d61e4b1
commit
4c4bd153be
@ -42,6 +42,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
*,
|
||||
addr: str = None,
|
||||
port: int = None,
|
||||
allowed_hosts=None,
|
||||
cors_allowed_origins=None,
|
||||
cors_allow_all_origins=False,
|
||||
db: Optional[Database] = None,
|
||||
templates_dirs: list[str] = None,
|
||||
middleware: list[str] = None,
|
||||
@ -49,7 +52,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
staticfiles_dirs: list[str] = None,
|
||||
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
|
||||
error_routes: dict[int, Callable] = None,
|
||||
allowed_hosts=None,
|
||||
secret_key: str = None,
|
||||
session_max_age=60 * 60 * 24 * 14, # 2 weeks
|
||||
session_cookie_name="swsession",
|
||||
@ -75,6 +77,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
self.secret_key = secret_key if secret_key else self.generate_key()
|
||||
self.allowed_hosts = allowed_hosts or ["*"]
|
||||
|
||||
self.cors_allowed_origins = cors_allowed_origins or []
|
||||
self.cors_allow_all_origins = cors_allow_all_origins
|
||||
|
||||
self.extra_data = kwargs
|
||||
|
||||
# session middleware
|
||||
@ -136,12 +141,19 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
try:
|
||||
status = get_http_status_by_code(resp.status_code)
|
||||
cookies = []
|
||||
if "Set-Cookie" in resp.headers:
|
||||
cookies = resp.headers["Set-Cookie"]
|
||||
del resp.headers["Set-Cookie"]
|
||||
varies = []
|
||||
if "set-cookie" in resp.headers:
|
||||
cookies = resp.headers["set-cookie"]
|
||||
del resp.headers["set-cookie"]
|
||||
if "vary" in resp.headers:
|
||||
varies = resp.headers["vary"]
|
||||
del resp.headers["vary"]
|
||||
headers = list(resp.headers.items())
|
||||
for c in cookies:
|
||||
headers.append(("Set-Cookie", c))
|
||||
for v in varies:
|
||||
headers.append(("Vary", v))
|
||||
|
||||
|
||||
start_response(status, headers)
|
||||
|
||||
@ -182,7 +194,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
):
|
||||
try:
|
||||
status = get_http_status_by_code(500)
|
||||
headers = [("Content-type", "text/plain; charset=utf-8")]
|
||||
headers = [("Content-Type", "text/plain; charset=utf-8")]
|
||||
|
||||
start_response(status, headers)
|
||||
|
||||
|
@ -2,7 +2,7 @@ import json
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from spiderweb.constants import DEFAULT_ENCODING
|
||||
from spiderweb.utils import get_client_address
|
||||
from spiderweb.utils import get_client_address, Headers
|
||||
|
||||
|
||||
class Request:
|
||||
@ -38,20 +38,22 @@ class Request:
|
||||
self.populate_meta()
|
||||
self.populate_cookies()
|
||||
|
||||
content_length = int(self.headers.get("CONTENT_LENGTH") or 0)
|
||||
content_length = int(self.headers.get("content_length") or 0)
|
||||
if content_length:
|
||||
self.content = (
|
||||
self.environ["wsgi.input"].read(content_length).decode(DEFAULT_ENCODING)
|
||||
)
|
||||
|
||||
def populate_headers(self) -> None:
|
||||
self.headers |= {
|
||||
"CONTENT_TYPE": self.environ.get("CONTENT_TYPE"),
|
||||
"CONTENT_LENGTH": self.environ.get("CONTENT_LENGTH"),
|
||||
data = self.headers
|
||||
data |= {
|
||||
"content_type": self.environ.get("CONTENT_TYPE"),
|
||||
"content_length": self.environ.get("CONTENT_LENGTH"),
|
||||
}
|
||||
for k, v in self.environ.items():
|
||||
if k.startswith("HTTP_"):
|
||||
self.headers[k] = v
|
||||
data[k] = v
|
||||
self.headers = Headers(**{k.lower(): v for k, v in data.items()})
|
||||
|
||||
def populate_meta(self) -> None:
|
||||
# all caps fields are from WSGI, lowercase names
|
||||
@ -89,6 +91,6 @@ class Request:
|
||||
|
||||
def is_form_request(self) -> bool:
|
||||
return (
|
||||
"CONTENT_TYPE" in self.headers
|
||||
and self.headers["CONTENT_TYPE"] == "application/x-www-form-urlencoded"
|
||||
"content_type" in self.headers
|
||||
and self.headers["content_type"] == "application/x-www-form-urlencoded"
|
||||
)
|
||||
|
@ -10,6 +10,8 @@ from wsgiref.util import FileWrapper
|
||||
from spiderweb.constants import REGEX_COOKIE_NAME
|
||||
from spiderweb.exceptions import GeneralException
|
||||
from spiderweb.request import Request
|
||||
from spiderweb.utils import Headers
|
||||
|
||||
|
||||
mimetypes.init()
|
||||
|
||||
@ -28,10 +30,11 @@ class HttpResponse:
|
||||
self.context = context if context else {}
|
||||
self.status_code = status_code
|
||||
self.headers = headers if headers else {}
|
||||
if not self.headers.get("Content-Type"):
|
||||
self.headers["Content-Type"] = "text/html; charset=utf-8"
|
||||
self.headers["Server"] = "Spiderweb"
|
||||
self.headers["Date"] = datetime.datetime.now(tz=datetime.UTC).strftime(
|
||||
self.headers = Headers(**{k.lower(): v for k, v in self.headers.items()})
|
||||
if not self.headers.get("content-type"):
|
||||
self.headers["content-type"] = "text/html; charset=utf-8"
|
||||
self.headers["server"] = "Spiderweb"
|
||||
self.headers["date"] = datetime.datetime.now(tz=datetime.UTC).strftime(
|
||||
"%a, %d %b %Y %H:%M:%S GMT"
|
||||
)
|
||||
|
||||
@ -89,10 +92,10 @@ class HttpResponse:
|
||||
attrs = [urllib.parse.quote_plus(value)] + attrs
|
||||
cookie = f"{name}={'; '.join(attrs)}"
|
||||
|
||||
if "Set-Cookie" in self.headers:
|
||||
self.headers["Set-Cookie"].append(cookie)
|
||||
if "set-cookie" in self.headers:
|
||||
self.headers["set-cookie"].append(cookie)
|
||||
else:
|
||||
self.headers["Set-Cookie"] = [cookie]
|
||||
self.headers["set-cookie"] = [cookie]
|
||||
|
||||
def render(self) -> str:
|
||||
return str(self.body)
|
||||
@ -103,7 +106,7 @@ class FileResponse(HttpResponse):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.filename = filename
|
||||
self.content_type = mimetypes.guess_type(self.filename)[0]
|
||||
self.headers["Content-Type"] = self.content_type
|
||||
self.headers["content-type"] = self.content_type
|
||||
|
||||
def render(self) -> list[bytes]:
|
||||
with open(self.filename, "rb") as f:
|
||||
@ -114,7 +117,7 @@ class FileResponse(HttpResponse):
|
||||
class JsonResponse(HttpResponse):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.headers["Content-Type"] = "application/json"
|
||||
self.headers["content-type"] = "application/json"
|
||||
|
||||
def render(self) -> str:
|
||||
return json.dumps(self.data)
|
||||
@ -124,7 +127,7 @@ class RedirectResponse(HttpResponse):
|
||||
def __init__(self, location: str, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.status_code = 302
|
||||
self.headers["Location"] = location
|
||||
self.headers["location"] = location
|
||||
|
||||
|
||||
class TemplateResponse(HttpResponse):
|
||||
|
@ -71,7 +71,7 @@ def test_redirect_response():
|
||||
return RedirectResponse(location="/redirected")
|
||||
|
||||
assert app(environ, start_response) == [b"None"]
|
||||
assert start_response.get_headers()["Location"] == "/redirected"
|
||||
assert start_response.get_headers()["location"] == "/redirected"
|
||||
|
||||
|
||||
def test_add_route_at_server_start():
|
||||
@ -91,7 +91,7 @@ def test_add_route_at_server_start():
|
||||
)
|
||||
|
||||
assert app(environ, start_response) == [b"None"]
|
||||
assert start_response.get_headers()["Location"] == "/redirected"
|
||||
assert start_response.get_headers()["location"] == "/redirected"
|
||||
|
||||
|
||||
def test_redirect_on_append_slash():
|
||||
@ -104,7 +104,7 @@ def test_redirect_on_append_slash():
|
||||
|
||||
environ["PATH_INFO"] = f"/hello"
|
||||
assert app(environ, start_response) == [b"None"]
|
||||
assert start_response.get_headers()["Location"] == "/hello/"
|
||||
assert start_response.get_headers()["location"] == "/hello/"
|
||||
|
||||
|
||||
@given(st.text())
|
||||
|
@ -63,3 +63,18 @@ def is_jsonable(data: str) -> bool:
|
||||
return True
|
||||
except (TypeError, OverflowError):
|
||||
return False
|
||||
|
||||
|
||||
class Headers(dict):
|
||||
# special dict that forces lowercase for all keys
|
||||
def __getitem__(self, key):
|
||||
return super().__getitem__(key.lower())
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
return super().__setitem__(key.lower(), value)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return super().get(key.lower(), default)
|
||||
|
||||
def setdefault(self, key, default = None):
|
||||
return super().setdefault(key.lower(), default)
|
Loading…
Reference in New Issue
Block a user