CORS! #1
@ -42,6 +42,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
*,
|
*,
|
||||||
addr: str = None,
|
addr: str = None,
|
||||||
port: int = None,
|
port: int = None,
|
||||||
|
allowed_hosts=None,
|
||||||
|
cors_allowed_origins=None,
|
||||||
|
cors_allow_all_origins=False,
|
||||||
db: Optional[Database] = None,
|
db: Optional[Database] = None,
|
||||||
templates_dirs: list[str] = None,
|
templates_dirs: list[str] = None,
|
||||||
middleware: list[str] = None,
|
middleware: list[str] = None,
|
||||||
@ -49,7 +52,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
staticfiles_dirs: list[str] = None,
|
staticfiles_dirs: list[str] = None,
|
||||||
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
|
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
|
||||||
error_routes: dict[int, Callable] = None,
|
error_routes: dict[int, Callable] = None,
|
||||||
allowed_hosts=None,
|
|
||||||
secret_key: str = None,
|
secret_key: str = None,
|
||||||
session_max_age=60 * 60 * 24 * 14, # 2 weeks
|
session_max_age=60 * 60 * 24 * 14, # 2 weeks
|
||||||
session_cookie_name="swsession",
|
session_cookie_name="swsession",
|
||||||
@ -75,6 +77,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
self.secret_key = secret_key if secret_key else self.generate_key()
|
self.secret_key = secret_key if secret_key else self.generate_key()
|
||||||
self.allowed_hosts = allowed_hosts or ["*"]
|
self.allowed_hosts = allowed_hosts or ["*"]
|
||||||
|
|
||||||
|
self.cors_allowed_origins = cors_allowed_origins or []
|
||||||
|
self.cors_allow_all_origins = cors_allow_all_origins
|
||||||
|
|
||||||
self.extra_data = kwargs
|
self.extra_data = kwargs
|
||||||
|
|
||||||
# session middleware
|
# session middleware
|
||||||
@ -136,12 +141,19 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
try:
|
try:
|
||||||
status = get_http_status_by_code(resp.status_code)
|
status = get_http_status_by_code(resp.status_code)
|
||||||
cookies = []
|
cookies = []
|
||||||
if "Set-Cookie" in resp.headers:
|
varies = []
|
||||||
cookies = resp.headers["Set-Cookie"]
|
if "set-cookie" in resp.headers:
|
||||||
del resp.headers["Set-Cookie"]
|
cookies = resp.headers["set-cookie"]
|
||||||
|
del resp.headers["set-cookie"]
|
||||||
|
if "vary" in resp.headers:
|
||||||
|
varies = resp.headers["vary"]
|
||||||
|
del resp.headers["vary"]
|
||||||
headers = list(resp.headers.items())
|
headers = list(resp.headers.items())
|
||||||
for c in cookies:
|
for c in cookies:
|
||||||
headers.append(("Set-Cookie", c))
|
headers.append(("Set-Cookie", c))
|
||||||
|
for v in varies:
|
||||||
|
headers.append(("Vary", v))
|
||||||
|
|
||||||
|
|
||||||
start_response(status, headers)
|
start_response(status, headers)
|
||||||
|
|
||||||
@ -182,7 +194,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
status = get_http_status_by_code(500)
|
status = get_http_status_by_code(500)
|
||||||
headers = [("Content-type", "text/plain; charset=utf-8")]
|
headers = [("Content-Type", "text/plain; charset=utf-8")]
|
||||||
|
|
||||||
start_response(status, headers)
|
start_response(status, headers)
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import json
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from spiderweb.constants import DEFAULT_ENCODING
|
from spiderweb.constants import DEFAULT_ENCODING
|
||||||
from spiderweb.utils import get_client_address
|
from spiderweb.utils import get_client_address, Headers
|
||||||
|
|
||||||
|
|
||||||
class Request:
|
class Request:
|
||||||
@ -38,20 +38,22 @@ class Request:
|
|||||||
self.populate_meta()
|
self.populate_meta()
|
||||||
self.populate_cookies()
|
self.populate_cookies()
|
||||||
|
|
||||||
content_length = int(self.headers.get("CONTENT_LENGTH") or 0)
|
content_length = int(self.headers.get("content_length") or 0)
|
||||||
if content_length:
|
if content_length:
|
||||||
self.content = (
|
self.content = (
|
||||||
self.environ["wsgi.input"].read(content_length).decode(DEFAULT_ENCODING)
|
self.environ["wsgi.input"].read(content_length).decode(DEFAULT_ENCODING)
|
||||||
)
|
)
|
||||||
|
|
||||||
def populate_headers(self) -> None:
|
def populate_headers(self) -> None:
|
||||||
self.headers |= {
|
data = self.headers
|
||||||
"CONTENT_TYPE": self.environ.get("CONTENT_TYPE"),
|
data |= {
|
||||||
"CONTENT_LENGTH": self.environ.get("CONTENT_LENGTH"),
|
"content_type": self.environ.get("CONTENT_TYPE"),
|
||||||
|
"content_length": self.environ.get("CONTENT_LENGTH"),
|
||||||
}
|
}
|
||||||
for k, v in self.environ.items():
|
for k, v in self.environ.items():
|
||||||
if k.startswith("HTTP_"):
|
if k.startswith("HTTP_"):
|
||||||
self.headers[k] = v
|
data[k] = v
|
||||||
|
self.headers = Headers(**{k.lower(): v for k, v in data.items()})
|
||||||
|
|
||||||
def populate_meta(self) -> None:
|
def populate_meta(self) -> None:
|
||||||
# all caps fields are from WSGI, lowercase names
|
# all caps fields are from WSGI, lowercase names
|
||||||
@ -89,6 +91,6 @@ class Request:
|
|||||||
|
|
||||||
def is_form_request(self) -> bool:
|
def is_form_request(self) -> bool:
|
||||||
return (
|
return (
|
||||||
"CONTENT_TYPE" in self.headers
|
"content_type" in self.headers
|
||||||
and self.headers["CONTENT_TYPE"] == "application/x-www-form-urlencoded"
|
and self.headers["content_type"] == "application/x-www-form-urlencoded"
|
||||||
)
|
)
|
||||||
|
@ -10,6 +10,8 @@ from wsgiref.util import FileWrapper
|
|||||||
from spiderweb.constants import REGEX_COOKIE_NAME
|
from spiderweb.constants import REGEX_COOKIE_NAME
|
||||||
from spiderweb.exceptions import GeneralException
|
from spiderweb.exceptions import GeneralException
|
||||||
from spiderweb.request import Request
|
from spiderweb.request import Request
|
||||||
|
from spiderweb.utils import Headers
|
||||||
|
|
||||||
|
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
|
|
||||||
@ -28,10 +30,11 @@ class HttpResponse:
|
|||||||
self.context = context if context else {}
|
self.context = context if context else {}
|
||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
self.headers = headers if headers else {}
|
self.headers = headers if headers else {}
|
||||||
if not self.headers.get("Content-Type"):
|
self.headers = Headers(**{k.lower(): v for k, v in self.headers.items()})
|
||||||
self.headers["Content-Type"] = "text/html; charset=utf-8"
|
if not self.headers.get("content-type"):
|
||||||
self.headers["Server"] = "Spiderweb"
|
self.headers["content-type"] = "text/html; charset=utf-8"
|
||||||
self.headers["Date"] = datetime.datetime.now(tz=datetime.UTC).strftime(
|
self.headers["server"] = "Spiderweb"
|
||||||
|
self.headers["date"] = datetime.datetime.now(tz=datetime.UTC).strftime(
|
||||||
"%a, %d %b %Y %H:%M:%S GMT"
|
"%a, %d %b %Y %H:%M:%S GMT"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -89,10 +92,10 @@ class HttpResponse:
|
|||||||
attrs = [urllib.parse.quote_plus(value)] + attrs
|
attrs = [urllib.parse.quote_plus(value)] + attrs
|
||||||
cookie = f"{name}={'; '.join(attrs)}"
|
cookie = f"{name}={'; '.join(attrs)}"
|
||||||
|
|
||||||
if "Set-Cookie" in self.headers:
|
if "set-cookie" in self.headers:
|
||||||
self.headers["Set-Cookie"].append(cookie)
|
self.headers["set-cookie"].append(cookie)
|
||||||
else:
|
else:
|
||||||
self.headers["Set-Cookie"] = [cookie]
|
self.headers["set-cookie"] = [cookie]
|
||||||
|
|
||||||
def render(self) -> str:
|
def render(self) -> str:
|
||||||
return str(self.body)
|
return str(self.body)
|
||||||
@ -103,7 +106,7 @@ class FileResponse(HttpResponse):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
self.content_type = mimetypes.guess_type(self.filename)[0]
|
self.content_type = mimetypes.guess_type(self.filename)[0]
|
||||||
self.headers["Content-Type"] = self.content_type
|
self.headers["content-type"] = self.content_type
|
||||||
|
|
||||||
def render(self) -> list[bytes]:
|
def render(self) -> list[bytes]:
|
||||||
with open(self.filename, "rb") as f:
|
with open(self.filename, "rb") as f:
|
||||||
@ -114,7 +117,7 @@ class FileResponse(HttpResponse):
|
|||||||
class JsonResponse(HttpResponse):
|
class JsonResponse(HttpResponse):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.headers["Content-Type"] = "application/json"
|
self.headers["content-type"] = "application/json"
|
||||||
|
|
||||||
def render(self) -> str:
|
def render(self) -> str:
|
||||||
return json.dumps(self.data)
|
return json.dumps(self.data)
|
||||||
@ -124,7 +127,7 @@ class RedirectResponse(HttpResponse):
|
|||||||
def __init__(self, location: str, *args, **kwargs):
|
def __init__(self, location: str, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.status_code = 302
|
self.status_code = 302
|
||||||
self.headers["Location"] = location
|
self.headers["location"] = location
|
||||||
|
|
||||||
|
|
||||||
class TemplateResponse(HttpResponse):
|
class TemplateResponse(HttpResponse):
|
||||||
|
@ -71,7 +71,7 @@ def test_redirect_response():
|
|||||||
return RedirectResponse(location="/redirected")
|
return RedirectResponse(location="/redirected")
|
||||||
|
|
||||||
assert app(environ, start_response) == [b"None"]
|
assert app(environ, start_response) == [b"None"]
|
||||||
assert start_response.get_headers()["Location"] == "/redirected"
|
assert start_response.get_headers()["location"] == "/redirected"
|
||||||
|
|
||||||
|
|
||||||
def test_add_route_at_server_start():
|
def test_add_route_at_server_start():
|
||||||
@ -91,7 +91,7 @@ def test_add_route_at_server_start():
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert app(environ, start_response) == [b"None"]
|
assert app(environ, start_response) == [b"None"]
|
||||||
assert start_response.get_headers()["Location"] == "/redirected"
|
assert start_response.get_headers()["location"] == "/redirected"
|
||||||
|
|
||||||
|
|
||||||
def test_redirect_on_append_slash():
|
def test_redirect_on_append_slash():
|
||||||
@ -104,7 +104,7 @@ def test_redirect_on_append_slash():
|
|||||||
|
|
||||||
environ["PATH_INFO"] = f"/hello"
|
environ["PATH_INFO"] = f"/hello"
|
||||||
assert app(environ, start_response) == [b"None"]
|
assert app(environ, start_response) == [b"None"]
|
||||||
assert start_response.get_headers()["Location"] == "/hello/"
|
assert start_response.get_headers()["location"] == "/hello/"
|
||||||
|
|
||||||
|
|
||||||
@given(st.text())
|
@given(st.text())
|
||||||
|
@ -63,3 +63,18 @@ def is_jsonable(data: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
except (TypeError, OverflowError):
|
except (TypeError, OverflowError):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class Headers(dict):
|
||||||
|
# special dict that forces lowercase for all keys
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return super().__getitem__(key.lower())
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
return super().__setitem__(key.lower(), value)
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
return super().get(key.lower(), default)
|
||||||
|
|
||||||
|
def setdefault(self, key, default = None):
|
||||||
|
return super().setdefault(key.lower(), default)
|
Loading…
Reference in New Issue
Block a user