From d98d61e4b15723f0f7da8fce0846083cf39b00ea Mon Sep 17 00:00:00 2001
From: Joe Kaufeld
Date: Fri, 30 Aug 2024 20:34:43 -0400
Subject: [PATCH 1/9] :construction: add groundwork for origins
---
pyproject.toml | 2 +-
spiderweb/constants.py | 2 +-
spiderweb/main.py | 4 +++-
spiderweb/middleware/cors.py | 1 +
spiderweb/request.py | 3 +++
templates/test.html | 3 +++
6 files changed, 12 insertions(+), 3 deletions(-)
create mode 100644 spiderweb/middleware/cors.py
diff --git a/pyproject.toml b/pyproject.toml
index 5f2dc91..7878d3b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "spiderweb-framework"
-version = "0.11.0"
+version = "0.12.0"
description = "A small web framework, just big enough for a spider."
authors = ["Joe Kaufeld "]
readme = "README.md"
diff --git a/spiderweb/constants.py b/spiderweb/constants.py
index cf8734d..9ebc2ad 100644
--- a/spiderweb/constants.py
+++ b/spiderweb/constants.py
@@ -2,7 +2,7 @@ from peewee import DatabaseProxy
DEFAULT_ALLOWED_METHODS = ["GET"]
DEFAULT_ENCODING = "UTF-8"
-__version__ = "0.11.0"
+__version__ = "0.12.0"
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
diff --git a/spiderweb/main.py b/spiderweb/main.py
index b0333dc..17ac9ca 100644
--- a/spiderweb/main.py
+++ b/spiderweb/main.py
@@ -49,6 +49,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
staticfiles_dirs: list[str] = None,
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
error_routes: dict[int, Callable] = None,
+ allowed_hosts=None,
secret_key: str = None,
session_max_age=60 * 60 * 24 * 14, # 2 weeks
session_cookie_name="swsession",
@@ -69,9 +70,10 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.append_slash = append_slash
self.templates_dirs = templates_dirs
self.staticfiles_dirs = staticfiles_dirs
- self._middleware: list[str] = middleware if middleware else []
+ self._middleware: list[str] = middleware or []
self.middleware: list[Callable] = []
self.secret_key = secret_key if secret_key else self.generate_key()
+ self.allowed_hosts = allowed_hosts or ["*"]
self.extra_data = kwargs
diff --git a/spiderweb/middleware/cors.py b/spiderweb/middleware/cors.py
new file mode 100644
index 0000000..37de52b
--- /dev/null
+++ b/spiderweb/middleware/cors.py
@@ -0,0 +1 @@
+# https://gist.github.com/FND/204ba41bf6ae485965ef
diff --git a/spiderweb/request.py b/spiderweb/request.py
index 6f95cde..a617e99 100644
--- a/spiderweb/request.py
+++ b/spiderweb/request.py
@@ -72,6 +72,9 @@ class Request:
]
for f in fields:
self.META[f] = self.environ.get(f)
+ for f in self.environ.keys():
+ if f.startswith("HTTP_"):
+ self.META[f] = self.environ[f]
self.META["client_address"] = get_client_address(self.environ)
def populate_cookies(self) -> None:
diff --git a/templates/test.html b/templates/test.html
index db08db0..6fa59b6 100644
--- a/templates/test.html
+++ b/templates/test.html
@@ -15,4 +15,7 @@
+
+ {{ request.META }}
+
{% endblock %}
From 4c4bd153be2de7febeb0930651a147447296f081 Mon Sep 17 00:00:00 2001
From: Joe Kaufeld
Date: Sat, 31 Aug 2024 22:40:54 -0400
Subject: [PATCH 2/9] :sparkles: make headers case-insensitive
---
spiderweb/main.py | 22 +++++++++++++++++-----
spiderweb/request.py | 18 ++++++++++--------
spiderweb/response.py | 23 +++++++++++++----------
spiderweb/tests/test_responses.py | 6 +++---
spiderweb/utils.py | 15 +++++++++++++++
5 files changed, 58 insertions(+), 26 deletions(-)
diff --git a/spiderweb/main.py b/spiderweb/main.py
index 17ac9ca..ae1ad69 100644
--- a/spiderweb/main.py
+++ b/spiderweb/main.py
@@ -42,6 +42,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
*,
addr: str = None,
port: int = None,
+ allowed_hosts=None,
+ cors_allowed_origins=None,
+ cors_allow_all_origins=False,
db: Optional[Database] = None,
templates_dirs: list[str] = None,
middleware: list[str] = None,
@@ -49,7 +52,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
staticfiles_dirs: list[str] = None,
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
error_routes: dict[int, Callable] = None,
- allowed_hosts=None,
secret_key: str = None,
session_max_age=60 * 60 * 24 * 14, # 2 weeks
session_cookie_name="swsession",
@@ -75,6 +77,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.secret_key = secret_key if secret_key else self.generate_key()
self.allowed_hosts = allowed_hosts or ["*"]
+ self.cors_allowed_origins = cors_allowed_origins or []
+ self.cors_allow_all_origins = cors_allow_all_origins
+
self.extra_data = kwargs
# session middleware
@@ -136,12 +141,19 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
try:
status = get_http_status_by_code(resp.status_code)
cookies = []
- if "Set-Cookie" in resp.headers:
- cookies = resp.headers["Set-Cookie"]
- del resp.headers["Set-Cookie"]
+ varies = []
+ if "set-cookie" in resp.headers:
+ cookies = resp.headers["set-cookie"]
+ del resp.headers["set-cookie"]
+ if "vary" in resp.headers:
+ varies = resp.headers["vary"]
+ del resp.headers["vary"]
headers = list(resp.headers.items())
for c in cookies:
headers.append(("Set-Cookie", c))
+ for v in varies:
+ headers.append(("Vary", v))
+
start_response(status, headers)
@@ -182,7 +194,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
):
try:
status = get_http_status_by_code(500)
- headers = [("Content-type", "text/plain; charset=utf-8")]
+ headers = [("Content-Type", "text/plain; charset=utf-8")]
start_response(status, headers)
diff --git a/spiderweb/request.py b/spiderweb/request.py
index a617e99..90e0160 100644
--- a/spiderweb/request.py
+++ b/spiderweb/request.py
@@ -2,7 +2,7 @@ import json
from urllib.parse import urlparse
from spiderweb.constants import DEFAULT_ENCODING
-from spiderweb.utils import get_client_address
+from spiderweb.utils import get_client_address, Headers
class Request:
@@ -38,20 +38,22 @@ class Request:
self.populate_meta()
self.populate_cookies()
- content_length = int(self.headers.get("CONTENT_LENGTH") or 0)
+ content_length = int(self.headers.get("content_length") or 0)
if content_length:
self.content = (
self.environ["wsgi.input"].read(content_length).decode(DEFAULT_ENCODING)
)
def populate_headers(self) -> None:
- self.headers |= {
- "CONTENT_TYPE": self.environ.get("CONTENT_TYPE"),
- "CONTENT_LENGTH": self.environ.get("CONTENT_LENGTH"),
+ data = self.headers
+ data |= {
+ "content_type": self.environ.get("CONTENT_TYPE"),
+ "content_length": self.environ.get("CONTENT_LENGTH"),
}
for k, v in self.environ.items():
if k.startswith("HTTP_"):
- self.headers[k] = v
+ data[k] = v
+ self.headers = Headers(**{k.lower(): v for k, v in data.items()})
def populate_meta(self) -> None:
# all caps fields are from WSGI, lowercase names
@@ -89,6 +91,6 @@ class Request:
def is_form_request(self) -> bool:
return (
- "CONTENT_TYPE" in self.headers
- and self.headers["CONTENT_TYPE"] == "application/x-www-form-urlencoded"
+ "content_type" in self.headers
+ and self.headers["content_type"] == "application/x-www-form-urlencoded"
)
diff --git a/spiderweb/response.py b/spiderweb/response.py
index 0de90ab..0e79648 100644
--- a/spiderweb/response.py
+++ b/spiderweb/response.py
@@ -10,6 +10,8 @@ from wsgiref.util import FileWrapper
from spiderweb.constants import REGEX_COOKIE_NAME
from spiderweb.exceptions import GeneralException
from spiderweb.request import Request
+from spiderweb.utils import Headers
+
mimetypes.init()
@@ -28,10 +30,11 @@ class HttpResponse:
self.context = context if context else {}
self.status_code = status_code
self.headers = headers if headers else {}
- if not self.headers.get("Content-Type"):
- self.headers["Content-Type"] = "text/html; charset=utf-8"
- self.headers["Server"] = "Spiderweb"
- self.headers["Date"] = datetime.datetime.now(tz=datetime.UTC).strftime(
+ self.headers = Headers(**{k.lower(): v for k, v in self.headers.items()})
+ if not self.headers.get("content-type"):
+ self.headers["content-type"] = "text/html; charset=utf-8"
+ self.headers["server"] = "Spiderweb"
+ self.headers["date"] = datetime.datetime.now(tz=datetime.UTC).strftime(
"%a, %d %b %Y %H:%M:%S GMT"
)
@@ -89,10 +92,10 @@ class HttpResponse:
attrs = [urllib.parse.quote_plus(value)] + attrs
cookie = f"{name}={'; '.join(attrs)}"
- if "Set-Cookie" in self.headers:
- self.headers["Set-Cookie"].append(cookie)
+ if "set-cookie" in self.headers:
+ self.headers["set-cookie"].append(cookie)
else:
- self.headers["Set-Cookie"] = [cookie]
+ self.headers["set-cookie"] = [cookie]
def render(self) -> str:
return str(self.body)
@@ -103,7 +106,7 @@ class FileResponse(HttpResponse):
super().__init__(*args, **kwargs)
self.filename = filename
self.content_type = mimetypes.guess_type(self.filename)[0]
- self.headers["Content-Type"] = self.content_type
+ self.headers["content-type"] = self.content_type
def render(self) -> list[bytes]:
with open(self.filename, "rb") as f:
@@ -114,7 +117,7 @@ class FileResponse(HttpResponse):
class JsonResponse(HttpResponse):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.headers["Content-Type"] = "application/json"
+ self.headers["content-type"] = "application/json"
def render(self) -> str:
return json.dumps(self.data)
@@ -124,7 +127,7 @@ class RedirectResponse(HttpResponse):
def __init__(self, location: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.status_code = 302
- self.headers["Location"] = location
+ self.headers["location"] = location
class TemplateResponse(HttpResponse):
diff --git a/spiderweb/tests/test_responses.py b/spiderweb/tests/test_responses.py
index 63c0011..df989bf 100644
--- a/spiderweb/tests/test_responses.py
+++ b/spiderweb/tests/test_responses.py
@@ -71,7 +71,7 @@ def test_redirect_response():
return RedirectResponse(location="/redirected")
assert app(environ, start_response) == [b"None"]
- assert start_response.get_headers()["Location"] == "/redirected"
+ assert start_response.get_headers()["location"] == "/redirected"
def test_add_route_at_server_start():
@@ -91,7 +91,7 @@ def test_add_route_at_server_start():
)
assert app(environ, start_response) == [b"None"]
- assert start_response.get_headers()["Location"] == "/redirected"
+ assert start_response.get_headers()["location"] == "/redirected"
def test_redirect_on_append_slash():
@@ -104,7 +104,7 @@ def test_redirect_on_append_slash():
environ["PATH_INFO"] = f"/hello"
assert app(environ, start_response) == [b"None"]
- assert start_response.get_headers()["Location"] == "/hello/"
+ assert start_response.get_headers()["location"] == "/hello/"
@given(st.text())
diff --git a/spiderweb/utils.py b/spiderweb/utils.py
index 42baf35..d24ef04 100644
--- a/spiderweb/utils.py
+++ b/spiderweb/utils.py
@@ -63,3 +63,18 @@ def is_jsonable(data: str) -> bool:
return True
except (TypeError, OverflowError):
return False
+
+
+class Headers(dict):
+ # special dict that forces lowercase for all keys
+ def __getitem__(self, key):
+ return super().__getitem__(key.lower())
+
+ def __setitem__(self, key, value):
+ return super().__setitem__(key.lower(), value)
+
+ def get(self, key, default=None):
+ return super().get(key.lower(), default)
+
+ def setdefault(self, key, default = None):
+ return super().setdefault(key.lower(), default)
\ No newline at end of file
From 678190ae480a592e6146cf40280fde63ec7fdf53 Mon Sep 17 00:00:00 2001
From: Joe Kaufeld
Date: Sun, 1 Sep 2024 18:16:28 -0400
Subject: [PATCH 3/9] :memo: add badges to readme
---
README.md | 13 +++++++++++++
1 file changed, 13 insertions(+)
diff --git a/README.md b/README.md
index 31b9039..85ffb09 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,18 @@
# spiderweb
+
+
+
+
+
+
+
As a professional web developer focusing on arcane uses of Django for arcane purposes, it occurred to me a little while ago that I didn't actually know how a web framework _worked_.
So I built one.
From 572675b07610b4eabb8e79980e3b4066f25d7595 Mon Sep 17 00:00:00 2001
From: Joe Kaufeld
Date: Sun, 1 Sep 2024 19:12:51 -0400
Subject: [PATCH 4/9] :memo: add black code style icon
---
README.md | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/README.md b/README.md
index 85ffb09..5c3944d 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,10 @@
alt="Gitmoji"
/>
+
As a professional web developer focusing on arcane uses of Django for arcane purposes, it occurred to me a little while ago that I didn't actually know how a web framework _worked_.
From 9330918009daf01e1ba32f041bda797c25483e2a Mon Sep 17 00:00:00 2001
From: Joe Kaufeld
Date: Sun, 1 Sep 2024 21:05:24 -0400
Subject: [PATCH 5/9] :lock: fix issues with CSRF middleware
---
docs/middleware/csrf.md | 3 --
spiderweb/main.py | 21 +++++++++--
spiderweb/middleware/csrf.py | 67 ++++++++++++++++++++++++++++--------
3 files changed, 72 insertions(+), 19 deletions(-)
diff --git a/docs/middleware/csrf.md b/docs/middleware/csrf.md
index 8458dd3..b1f9c7d 100644
--- a/docs/middleware/csrf.md
+++ b/docs/middleware/csrf.md
@@ -11,9 +11,6 @@ app = SpiderwebRouter(
)
```
-> [!DANGER]
-> The CSRFMiddleware is incomplete at best and dangerous at worst. I am not a security expert, and my implementation is [very susceptible to the thing it is meant to prevent](https://en.wikipedia.org/wiki/Cross-site_request_forgery). While this is an big issue (and moderately hilarious), the middleware is still provided to you in its unfinished state. Be aware.
-
Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more.
> [!TIP]
diff --git a/spiderweb/main.py b/spiderweb/main.py
index ae1ad69..2f5dfc9 100644
--- a/spiderweb/main.py
+++ b/spiderweb/main.py
@@ -45,6 +45,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
allowed_hosts=None,
cors_allowed_origins=None,
cors_allow_all_origins=False,
+ csrf_trusted_origins: Sequence[str] = None,
db: Optional[Database] = None,
templates_dirs: list[str] = None,
middleware: list[str] = None,
@@ -75,10 +76,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self._middleware: list[str] = middleware or []
self.middleware: list[Callable] = []
self.secret_key = secret_key if secret_key else self.generate_key()
- self.allowed_hosts = allowed_hosts or ["*"]
+ self._allowed_hosts = allowed_hosts or ["*"]
+ self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
self.cors_allowed_origins = cors_allowed_origins or []
self.cors_allow_all_origins = cors_allow_all_origins
+ self._csrf_trusted_origins = csrf_trusted_origins or []
+ self.csrf_trusted_origins = [
+ convert_url_to_regex(i) for i in self._csrf_trusted_origins
+ ]
self.extra_data = kwargs
@@ -154,7 +160,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
for v in varies:
headers.append(("Vary", v))
-
start_response(status, headers)
rendered_output = resp.render()
@@ -231,6 +236,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
start_response, request, self.get_error_route(500)(request)
)
+ def check_valid_host(self, request) -> bool:
+ host = request.headers.get("http_host")
+ if not host:
+ return False
+ for option in self.allowed_hosts:
+ if re.match(option, host):
+ return True
+ return False
+
def __call__(self, environ, start_response, *args, **kwargs):
"""Entry point for WSGI apps."""
request = self.get_request(environ)
@@ -247,6 +261,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
# replace the potentially valid handler with the error route
handler = self.get_error_route(405)
+ if not self.check_valid_host(request):
+ handler = self.get_error_route(403)
+
if request.is_form_request():
form_data = urlparse.parse_qs(request.content)
for key, value in form_data.items():
diff --git a/spiderweb/middleware/csrf.py b/spiderweb/middleware/csrf.py
index 5a128e7..3a0ffa9 100644
--- a/spiderweb/middleware/csrf.py
+++ b/spiderweb/middleware/csrf.py
@@ -1,4 +1,7 @@
+import re
+from re import Pattern
from datetime import datetime, timedelta
+from typing import Optional
from spiderweb.exceptions import CSRFError, ConfigError
from spiderweb.middleware import SpiderwebMiddleware
@@ -7,53 +10,89 @@ from spiderweb.response import HttpResponse
from spiderweb.server_checks import ServerCheck
-class SessionCheck(ServerCheck):
-
+class CheckForSessionMiddleware(ServerCheck):
SESSION_MIDDLEWARE_NOT_FOUND = (
"Session middleware is not enabled. It must be listed above"
"CSRFMiddleware in the middleware list."
)
+
+ def check(self) -> Optional[Exception]:
+ if (
+ "spiderweb.middleware.sessions.SessionMiddleware"
+ not in self.server._middleware
+ ):
+ return ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
+
+
+class VerifyCorrectMiddlewarePlacement(ServerCheck):
SESSION_MIDDLEWARE_BELOW_CSRF = (
"SessionMiddleware is enabled, but it must be listed above"
"CSRFMiddleware in the middleware list."
)
- def check(self):
-
+ def check(self) -> Optional[Exception]:
if (
"spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware
):
- raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
+ # this is handled by CheckForSessionMiddleware
+ return
if self.server._middleware.index(
"spiderweb.middleware.sessions.SessionMiddleware"
- ) > self.server._middleware.index(
- "spiderweb.middleware.csrf.CSRFMiddleware"
- ):
- raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
+ ) > self.server._middleware.index("spiderweb.middleware.csrf.CSRFMiddleware"):
+ return ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
+
+
+class VerifyCorrectFormatForTrustedOrigins(ServerCheck):
+ CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR = (
+ "The csrf_trusted_origins setting must be a list of strings."
+ )
+
+ def check(self) -> Optional[Exception]:
+ if not isinstance(self.server.csrf_trusted_origins, list):
+ return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
+
+ for item in self.server.csrf_trusted_origins:
+ if not isinstance(item, Pattern):
+ # It's a pattern here because we've already manipulated it
+ # by the time this check runs
+ return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
class CSRFMiddleware(SpiderwebMiddleware):
- checks = [SessionCheck]
+ checks = [
+ CheckForSessionMiddleware,
+ VerifyCorrectMiddlewarePlacement,
+ VerifyCorrectFormatForTrustedOrigins,
+ ]
CSRF_EXPIRY = 60 * 60 # 1 hour
def process_request(self, request: Request) -> HttpResponse | None:
if request.method == "POST":
+ trusted_origin = False
if hasattr(request.handler, "csrf_exempt"):
if request.handler.csrf_exempt is True:
return
+ if origin := request.headers.get("http_origin"):
+
+ for re_origin in self.server.csrf_trusted_origins:
+ if re.match(re_origin, origin):
+ trusted_origin = True
+
csrf_token = (
request.headers.get("X-CSRF-TOKEN")
or request.GET.get("csrf_token")
or request.POST.get("csrf_token")
)
- if self.is_csrf_valid(request, csrf_token):
- return None
- else:
- raise CSRFError()
+
+ if not trusted_origin:
+ if self.is_csrf_valid(request, csrf_token):
+ return None
+ else:
+ raise CSRFError()
return None
def process_response(self, request: Request, response: HttpResponse) -> None:
From 15a94b9879e34eb50cdab635e1cf03ea96a45c2c Mon Sep 17 00:00:00 2001
From: Joe Kaufeld
Date: Sun, 1 Sep 2024 21:05:43 -0400
Subject: [PATCH 6/9] :sparkles: CORS middleware!
---
example.py | 1 +
spiderweb/constants.py | 17 ++++
spiderweb/exceptions.py | 4 +
spiderweb/main.py | 53 +++++++----
spiderweb/middleware/__init__.py | 15 +++-
spiderweb/middleware/cors.py | 138 ++++++++++++++++++++++++++++-
spiderweb/routes.py | 4 +-
spiderweb/tests/test_middleware.py | 68 ++++++++++++--
spiderweb/utils.py | 13 ++-
9 files changed, 282 insertions(+), 31 deletions(-)
diff --git a/example.py b/example.py
index c2543fe..f763eca 100644
--- a/example.py
+++ b/example.py
@@ -15,6 +15,7 @@ from spiderweb.response import (
app = SpiderwebRouter(
templates_dirs=["templates"],
middleware=[
+ "spiderweb.middleware.cors.CorsMiddleware",
"spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware",
"example_middleware.TestMiddleware",
diff --git a/spiderweb/constants.py b/spiderweb/constants.py
index 9ebc2ad..cb46532 100644
--- a/spiderweb/constants.py
+++ b/spiderweb/constants.py
@@ -8,3 +8,20 @@ __version__ = "0.12.0"
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
DATABASE_PROXY = DatabaseProxy()
+
+DEFAULT_CORS_ALLOW_METHODS = (
+ "DELETE",
+ "GET",
+ "OPTIONS",
+ "PATCH",
+ "POST",
+ "PUT",
+)
+DEFAULT_CORS_ALLOW_HEADERS = (
+ "accept",
+ "authorization",
+ "content-type",
+ "user-agent",
+ "x-csrftoken",
+ "x-requested-with",
+)
diff --git a/spiderweb/exceptions.py b/spiderweb/exceptions.py
index bdba675..f784c23 100644
--- a/spiderweb/exceptions.py
+++ b/spiderweb/exceptions.py
@@ -86,3 +86,7 @@ class UnusedMiddleware(SpiderwebException):
class NoResponseError(SpiderwebException):
pass
+
+
+class StartupErrors(ExceptionGroup):
+ pass
diff --git a/spiderweb/main.py b/spiderweb/main.py
index 2f5dfc9..eb0e44d 100644
--- a/spiderweb/main.py
+++ b/spiderweb/main.py
@@ -1,16 +1,22 @@
import inspect
import logging
import pathlib
+import re
import traceback
import urllib.parse as urlparse
+from logging import Logger
from threading import Thread
-from typing import Optional, Callable
+from typing import Optional, Callable, Sequence, LiteralString, Literal
from wsgiref.simple_server import WSGIServer
from jinja2 import BaseLoader, Environment, FileSystemLoader
from peewee import Database, SqliteDatabase
from spiderweb.middleware import MiddlewareMixin
+from spiderweb.constants import (
+ DEFAULT_CORS_ALLOW_METHODS,
+ DEFAULT_CORS_ALLOW_HEADERS,
+)
from spiderweb.constants import (
DATABASE_PROXY,
DEFAULT_ENCODING,
@@ -30,7 +36,7 @@ from spiderweb.request import Request
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
from spiderweb.routes import RoutesMixin
from spiderweb.secrets import FernetMixin
-from spiderweb.utils import get_http_status_by_code
+from spiderweb.utils import get_http_status_by_code, convert_url_to_regex
console_logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@@ -42,25 +48,32 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
*,
addr: str = None,
port: int = None,
- allowed_hosts=None,
- cors_allowed_origins=None,
- cors_allow_all_origins=False,
+ allowed_hosts: Sequence[str | re.Pattern] = None,
+ cors_allowed_origins: Sequence[str] = None,
+ cors_allowed_origins_regexes: Sequence[str] = None,
+ cors_allow_all_origins: bool = False,
+ cors_urls_regex: str | re.Pattern[str] = r"^.*$",
+ cors_allow_methods: Sequence[str] = None,
+ cors_allow_headers: Sequence[str] = None,
+ cors_expose_headers: Sequence[str] = None,
+ cors_preflight_max_age: int = 86400,
+ cors_allow_credentials: bool = False,
csrf_trusted_origins: Sequence[str] = None,
db: Optional[Database] = None,
- templates_dirs: list[str] = None,
- middleware: list[str] = None,
+ templates_dirs: Sequence[str] = None,
+ middleware: Sequence[str] = None,
append_slash: bool = False,
- staticfiles_dirs: list[str] = None,
- routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
+ staticfiles_dirs: Sequence[str] = None,
+ routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
error_routes: dict[int, Callable] = None,
secret_key: str = None,
- session_max_age=60 * 60 * 24 * 14, # 2 weeks
- session_cookie_name="swsession",
- session_cookie_secure=False, # should be true if serving over HTTPS
- session_cookie_http_only=True,
- session_cookie_same_site="lax",
- session_cookie_path="/",
- log=None,
+ session_max_age: int = 60 * 60 * 24 * 14, # 2 weeks
+ session_cookie_name: str = "swsession",
+ session_cookie_secure: bool = False, # should be true if serving over HTTPS
+ session_cookie_http_only: bool = True,
+ session_cookie_same_site: Literal["strict", "lax", "none"] = "lax",
+ session_cookie_path: str = "/",
+ log: Logger = None,
**kwargs,
):
self._routes = {}
@@ -80,7 +93,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
self.cors_allowed_origins = cors_allowed_origins or []
+ self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or []
self.cors_allow_all_origins = cors_allow_all_origins
+ self.cors_urls_regex = cors_urls_regex
+ self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
+ self.cors_allow_headers = cors_allow_headers or DEFAULT_CORS_ALLOW_HEADERS
+ self.cors_expose_headers = cors_expose_headers or []
+ self.cors_preflight_max_age = cors_preflight_max_age
+ self.cors_allow_credentials = cors_allow_credentials
+
self._csrf_trusted_origins = csrf_trusted_origins or []
self.csrf_trusted_origins = [
convert_url_to_regex(i) for i in self._csrf_trusted_origins
diff --git a/spiderweb/middleware/__init__.py b/spiderweb/middleware/__init__.py
index 3ffeb8e..265f2a5 100644
--- a/spiderweb/middleware/__init__.py
+++ b/spiderweb/middleware/__init__.py
@@ -1,9 +1,11 @@
from typing import Callable, ClassVar
+import sys
from .base import SpiderwebMiddleware as SpiderwebMiddleware
+from .cors import CorsMiddleware as CorsMiddleware
from .csrf import CSRFMiddleware as CSRFMiddleware
from .sessions import SessionMiddleware as SessionMiddleware
-from ..exceptions import ConfigError, UnusedMiddleware
+from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors
from ..request import Request
from ..response import HttpResponse
from ..utils import import_by_string
@@ -27,10 +29,19 @@ class MiddlewareMixin:
self.middleware = middleware_by_reference
def run_middleware_checks(self):
+ errors = []
for middleware in self.middleware:
if hasattr(middleware, "checks"):
for check in middleware.checks:
- check(server=self).check()
+ if issue := check(server=self).check():
+ errors.append(issue)
+
+ if errors:
+ # just show the messages
+ sys.tracebacklimit = 0
+ raise StartupErrors(
+ "Problems were identified during startup — cannot continue.", errors
+ )
def process_request_middleware(self, request: Request) -> None | bool:
for middleware in self.middleware:
diff --git a/spiderweb/middleware/cors.py b/spiderweb/middleware/cors.py
index 37de52b..9a1bcc1 100644
--- a/spiderweb/middleware/cors.py
+++ b/spiderweb/middleware/cors.py
@@ -1 +1,137 @@
-# https://gist.github.com/FND/204ba41bf6ae485965ef
+import re
+from urllib.parse import urlsplit, SplitResult
+
+from spiderweb.request import Request
+from spiderweb.response import HttpResponse
+from spiderweb.middleware import SpiderwebMiddleware
+
+ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
+ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
+ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
+ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
+ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
+ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
+ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
+ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
+
+
+class CorsMiddleware(SpiderwebMiddleware):
+ # heavily 'based' on https://github.com/adamchainz/django-cors-headers,
+ # which is provided under the MIT license. This is essentially a direct
+ # port, since django-cors-headers is battle-tested code that has been
+ # around for a long time and it works well. Shoutouts to Otto, Adam, and
+ # crew for helping make this a complete non-issue in Django for a very long
+ # time.
+
+ def is_enabled(self, request: Request):
+ return bool(re.match(self.server.cors_urls_regex, request.path))
+
+ def add_response_headers(self, request: Request, response: HttpResponse):
+ enabled = getattr(request, "_cors_enabled", None)
+ if enabled is None:
+ enabled = self.is_enabled(request)
+
+ if not enabled:
+ return response
+
+ if "vary" in response.headers:
+ response.headers["vary"].append("origin")
+ else:
+ response.headers["vary"] = ["origin"]
+
+ origin = request.headers.get("origin")
+ if not origin:
+ return response
+
+ try:
+ url = urlsplit(origin)
+ except ValueError:
+ return response
+
+ if (
+ not self.server.cors_allow_all_origins
+ and not self.origin_found_in_allow_lists(origin, url)
+ ):
+ return response
+
+ if (
+ self.server.cors_allow_all_origins
+ and not self.server.cors_allow_credentials
+ ):
+ response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
+ else:
+ response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
+
+ if self.server.cors_allow_credentials:
+ response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
+
+ if len(self.server.cors_expose_headers):
+ response.headers[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
+ self.server.cors_expose_headers
+ )
+
+ if request.method == "OPTIONS":
+ response.headers[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(
+ self.server.cors_allow_headers
+ )
+ response.headers[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(
+ self.server.cors_allow_methods
+ )
+ if self.server.cors_preflight_max_age:
+ response.headers[ACCESS_CONTROL_MAX_AGE] = str(
+ self.server.cors_preflight_max_age
+ )
+
+ if (
+ self.server.cors_allow_private_network
+ and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
+ ):
+ response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
+
+ return response
+
+ def origin_found_in_allow_lists(self, origin: str, url: SplitResult) -> bool:
+ return (
+ (origin == "null" and origin in self.server.cors_allowed_origins)
+ or self._url_in_allowlist(url)
+ or self.regex_domain_match(origin)
+ )
+
+ def _url_in_allowlist(self, url: SplitResult) -> bool:
+ origins = [urlsplit(o) for o in self.server.cors_allowed_origins]
+ return any(
+ origin.scheme == url.scheme and origin.netloc == url.netloc
+ for origin in origins
+ )
+
+ def regex_domain_match(self, origin: str) -> bool:
+ return any(
+ re.match(domain_pattern, origin)
+ for domain_pattern in self.server.cors_allowed_origin_regexes
+ )
+
+ def process_request(self, request: Request) -> HttpResponse | None:
+ # Identify and handle a preflight request
+ # origin = request.META.get("HTTP_ORIGIN")
+ request._cors_enabled = self.is_enabled(request)
+ if (
+ request._cors_enabled
+ and request.method == "OPTIONS"
+ and "access-control-request-method" in request.headers
+ ):
+ # this should be 204, but according to mozilla, not all browsers
+ # parse that correctly. See [204] comment below.
+ resp = HttpResponse(
+ "",
+ status_code=200,
+ headers={"content-type": "text/plain", "content-length": 0},
+ )
+ self.add_response_headers(request, resp)
+ return resp
+
+ def process_response(
+ self, request: Request, response: HttpResponse
+ ) -> None:
+ self.add_response_headers(request, response)
+
+# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
diff --git a/spiderweb/routes.py b/spiderweb/routes.py
index 4b26448..3622d8b 100644
--- a/spiderweb/routes.py
+++ b/spiderweb/routes.py
@@ -1,5 +1,5 @@
import re
-from typing import Callable, Any, Optional
+from typing import Callable, Any, Optional, Sequence
from spiderweb.constants import DEFAULT_ALLOWED_METHODS
from spiderweb.converters import * # noqa: F403
@@ -30,7 +30,7 @@ class RoutesMixin:
# ones that start with underscores are the compiled versions, non-underscores
# are the user-supplied versions
_routes: dict
- routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,)
+ routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]]
_error_routes: dict
error_routes: dict[int, Callable]
append_slash: bool
diff --git a/spiderweb/tests/test_middleware.py b/spiderweb/tests/test_middleware.py
index e727b07..f785875 100644
--- a/spiderweb/tests/test_middleware.py
+++ b/spiderweb/tests/test_middleware.py
@@ -4,12 +4,16 @@ from datetime import timedelta
import pytest
from peewee import SqliteDatabase
-from spiderweb import SpiderwebRouter, HttpResponse, ConfigError
+from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors
from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf
from spiderweb.tests.helpers import setup
-from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf
+from spiderweb.tests.views_for_tests import (
+ form_view_with_csrf,
+ form_csrf_exempt,
+ form_view_without_csrf,
+)
# app = SpiderwebRouter(
@@ -99,18 +103,21 @@ def test_exploding_middleware():
def test_csrf_middleware_without_session_middleware():
_, environ, start_response = setup()
- with pytest.raises(ConfigError) as e:
+ with pytest.raises(StartupErrors) as e:
SpiderwebRouter(
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
db=SqliteDatabase("spiderweb-tests.db"),
)
-
- assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND
+ exceptiongroup = e.value.args[1]
+ assert (
+ exceptiongroup[0].args[0]
+ == csrf.CheckForSessionMiddleware.SESSION_MIDDLEWARE_NOT_FOUND
+ )
def test_csrf_middleware_above_session_middleware():
_, environ, start_response = setup()
- with pytest.raises(ConfigError) as e:
+ with pytest.raises(StartupErrors) as e:
SpiderwebRouter(
middleware=[
"spiderweb.middleware.csrf.CSRFMiddleware",
@@ -118,8 +125,11 @@ def test_csrf_middleware_above_session_middleware():
],
db=SqliteDatabase("spiderweb-tests.db"),
)
-
- assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF
+ exceptiongroup = e.value.args[1]
+ assert (
+ exceptiongroup[0].args[0]
+ == csrf.VerifyCorrectMiddlewarePlacement.SESSION_MIDDLEWARE_BELOW_CSRF
+ )
def test_csrf_middleware():
@@ -211,6 +221,7 @@ def test_csrf_expired_token():
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
)
environ["REQUEST_METHOD"] = "POST"
+ environ["HTTP_ORIGIN"] = "example.com"
environ["HTTP_X_CSRF_TOKEN"] = token
environ["CONTENT_LENGTH"] = len(formdata)
@@ -254,3 +265,44 @@ def test_csrf_exempt():
environ["PATH_INFO"] = "/2"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp2
+
+
+def test_csrf_trusted_origins():
+ _, environ, start_response = setup()
+ app = SpiderwebRouter(
+ middleware=[
+ "spiderweb.middleware.sessions.SessionMiddleware",
+ "spiderweb.middleware.csrf.CSRFMiddleware",
+ ],
+ csrf_trusted_origins=[
+ "example.com",
+ ],
+ db=SqliteDatabase("spiderweb-tests.db"),
+ )
+
+ app.add_route("/", form_view_without_csrf, ["GET", "POST"])
+
+ environ["HTTP_USER_AGENT"] = "hi"
+ environ["REMOTE_ADDR"] = "1.1.1.1"
+ environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
+ environ["REQUEST_METHOD"] = "POST"
+
+ formdata = "name=bob"
+ environ["CONTENT_LENGTH"] = len(formdata)
+ b_handle = BytesIO()
+ b_handle.write(formdata.encode(DEFAULT_ENCODING))
+ b_handle.seek(0)
+ environ["wsgi.input"] = BufferedReader(b_handle)
+
+ environ["HTTP_ORIGIN"] = "notvalid.com"
+ resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
+ assert "CSRF token is invalid" in resp
+
+ b_handle = BytesIO()
+ b_handle.write(formdata.encode(DEFAULT_ENCODING))
+ b_handle.seek(0)
+ environ["wsgi.input"] = BufferedReader(b_handle)
+
+ environ["HTTP_ORIGIN"] = "example.com"
+ resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
+ assert resp2 == '{"name": "bob"}'
diff --git a/spiderweb/utils.py b/spiderweb/utils.py
index d24ef04..e00bcb7 100644
--- a/spiderweb/utils.py
+++ b/spiderweb/utils.py
@@ -1,4 +1,5 @@
import json
+import re
import secrets
import string
from http import HTTPStatus
@@ -76,5 +77,13 @@ class Headers(dict):
def get(self, key, default=None):
return super().get(key.lower(), default)
- def setdefault(self, key, default = None):
- return super().setdefault(key.lower(), default)
\ No newline at end of file
+ def setdefault(self, key, default=None):
+ return super().setdefault(key.lower(), default)
+
+
+def convert_url_to_regex(url: str | re.Pattern) -> re.Pattern:
+ if isinstance(url, re.Pattern):
+ return url
+ url = url.replace(".", "\\.")
+ url = url.replace("*", ".+")
+ return re.compile(url)
From e6f477fa57c513c7eb9e40987b2f28055a31d2f4 Mon Sep 17 00:00:00 2001
From: Joe Kaufeld
Date: Sun, 1 Sep 2024 23:16:01 -0400
Subject: [PATCH 7/9] :construction_worker: add plugins for docsify
---
docs/index.html | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/docs/index.html b/docs/index.html
index 1464fc9..7826a54 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -48,6 +48,7 @@
+
@@ -57,5 +58,8 @@
+
+
+