Compare commits

...

4 Commits

Author SHA1 Message Date
15a94b9879 CORS middleware! 2024-09-01 21:05:43 -04:00
9330918009 🔒 fix issues with CSRF middleware 2024-09-01 21:05:24 -04:00
572675b076 📝 add black code style icon 2024-09-01 19:12:51 -04:00
678190ae48 📝 add badges to readme 2024-09-01 18:16:28 -04:00
12 changed files with 371 additions and 50 deletions

View File

@ -1,5 +1,22 @@
# spiderweb
<p align="center">
<img
src="https://img.shields.io/pypi/v/spiderweb-framework.svg?style=for-the-badge"
alt="PyPI release version for Spiderweb"
/>
<a href="https://gitmoji.dev">
<img
src="https://img.shields.io/badge/gitmoji-%20😜%20😍-FFDD67.svg?style=for-the-badge"
alt="Gitmoji"
/>
</a>
<img
src="https://img.shields.io/badge/code%20style-black-000000.svg?style=for-the-badge"
alt="Code style: Black"
/>
</p>
As a professional web developer focusing on arcane uses of Django for arcane purposes, it occurred to me a little while ago that I didn't actually know how a web framework _worked_.
So I built one.

View File

@ -11,9 +11,6 @@ app = SpiderwebRouter(
)
```
> [!DANGER]
> The CSRFMiddleware is incomplete at best and dangerous at worst. I am not a security expert, and my implementation is [very susceptible to the thing it is meant to prevent](https://en.wikipedia.org/wiki/Cross-site_request_forgery). While this is an big issue (and moderately hilarious), the middleware is still provided to you in its unfinished state. Be aware.
Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more.
> [!TIP]

View File

@ -15,6 +15,7 @@ from spiderweb.response import (
app = SpiderwebRouter(
templates_dirs=["templates"],
middleware=[
"spiderweb.middleware.cors.CorsMiddleware",
"spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware",
"example_middleware.TestMiddleware",

View File

@ -8,3 +8,20 @@ __version__ = "0.12.0"
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
DATABASE_PROXY = DatabaseProxy()
DEFAULT_CORS_ALLOW_METHODS = (
"DELETE",
"GET",
"OPTIONS",
"PATCH",
"POST",
"PUT",
)
DEFAULT_CORS_ALLOW_HEADERS = (
"accept",
"authorization",
"content-type",
"user-agent",
"x-csrftoken",
"x-requested-with",
)

View File

@ -86,3 +86,7 @@ class UnusedMiddleware(SpiderwebException):
class NoResponseError(SpiderwebException):
pass
class StartupErrors(ExceptionGroup):
pass

View File

@ -1,16 +1,22 @@
import inspect
import logging
import pathlib
import re
import traceback
import urllib.parse as urlparse
from logging import Logger
from threading import Thread
from typing import Optional, Callable
from typing import Optional, Callable, Sequence, LiteralString, Literal
from wsgiref.simple_server import WSGIServer
from jinja2 import BaseLoader, Environment, FileSystemLoader
from peewee import Database, SqliteDatabase
from spiderweb.middleware import MiddlewareMixin
from spiderweb.constants import (
DEFAULT_CORS_ALLOW_METHODS,
DEFAULT_CORS_ALLOW_HEADERS,
)
from spiderweb.constants import (
DATABASE_PROXY,
DEFAULT_ENCODING,
@ -30,7 +36,7 @@ from spiderweb.request import Request
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
from spiderweb.routes import RoutesMixin
from spiderweb.secrets import FernetMixin
from spiderweb.utils import get_http_status_by_code
from spiderweb.utils import get_http_status_by_code, convert_url_to_regex
console_logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@ -42,24 +48,32 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
*,
addr: str = None,
port: int = None,
allowed_hosts=None,
cors_allowed_origins=None,
cors_allow_all_origins=False,
allowed_hosts: Sequence[str | re.Pattern] = None,
cors_allowed_origins: Sequence[str] = None,
cors_allowed_origins_regexes: Sequence[str] = None,
cors_allow_all_origins: bool = False,
cors_urls_regex: str | re.Pattern[str] = r"^.*$",
cors_allow_methods: Sequence[str] = None,
cors_allow_headers: Sequence[str] = None,
cors_expose_headers: Sequence[str] = None,
cors_preflight_max_age: int = 86400,
cors_allow_credentials: bool = False,
csrf_trusted_origins: Sequence[str] = None,
db: Optional[Database] = None,
templates_dirs: list[str] = None,
middleware: list[str] = None,
templates_dirs: Sequence[str] = None,
middleware: Sequence[str] = None,
append_slash: bool = False,
staticfiles_dirs: list[str] = None,
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
staticfiles_dirs: Sequence[str] = None,
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
error_routes: dict[int, Callable] = None,
secret_key: str = None,
session_max_age=60 * 60 * 24 * 14, # 2 weeks
session_cookie_name="swsession",
session_cookie_secure=False, # should be true if serving over HTTPS
session_cookie_http_only=True,
session_cookie_same_site="lax",
session_cookie_path="/",
log=None,
session_max_age: int = 60 * 60 * 24 * 14, # 2 weeks
session_cookie_name: str = "swsession",
session_cookie_secure: bool = False, # should be true if serving over HTTPS
session_cookie_http_only: bool = True,
session_cookie_same_site: Literal["strict", "lax", "none"] = "lax",
session_cookie_path: str = "/",
log: Logger = None,
**kwargs,
):
self._routes = {}
@ -75,10 +89,23 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self._middleware: list[str] = middleware or []
self.middleware: list[Callable] = []
self.secret_key = secret_key if secret_key else self.generate_key()
self.allowed_hosts = allowed_hosts or ["*"]
self._allowed_hosts = allowed_hosts or ["*"]
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
self.cors_allowed_origins = cors_allowed_origins or []
self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or []
self.cors_allow_all_origins = cors_allow_all_origins
self.cors_urls_regex = cors_urls_regex
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
self.cors_allow_headers = cors_allow_headers or DEFAULT_CORS_ALLOW_HEADERS
self.cors_expose_headers = cors_expose_headers or []
self.cors_preflight_max_age = cors_preflight_max_age
self.cors_allow_credentials = cors_allow_credentials
self._csrf_trusted_origins = csrf_trusted_origins or []
self.csrf_trusted_origins = [
convert_url_to_regex(i) for i in self._csrf_trusted_origins
]
self.extra_data = kwargs
@ -154,7 +181,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
for v in varies:
headers.append(("Vary", v))
start_response(status, headers)
rendered_output = resp.render()
@ -231,6 +257,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
start_response, request, self.get_error_route(500)(request)
)
def check_valid_host(self, request) -> bool:
host = request.headers.get("http_host")
if not host:
return False
for option in self.allowed_hosts:
if re.match(option, host):
return True
return False
def __call__(self, environ, start_response, *args, **kwargs):
"""Entry point for WSGI apps."""
request = self.get_request(environ)
@ -247,6 +282,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
# replace the potentially valid handler with the error route
handler = self.get_error_route(405)
if not self.check_valid_host(request):
handler = self.get_error_route(403)
if request.is_form_request():
form_data = urlparse.parse_qs(request.content)
for key, value in form_data.items():

View File

@ -1,9 +1,11 @@
from typing import Callable, ClassVar
import sys
from .base import SpiderwebMiddleware as SpiderwebMiddleware
from .cors import CorsMiddleware as CorsMiddleware
from .csrf import CSRFMiddleware as CSRFMiddleware
from .sessions import SessionMiddleware as SessionMiddleware
from ..exceptions import ConfigError, UnusedMiddleware
from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors
from ..request import Request
from ..response import HttpResponse
from ..utils import import_by_string
@ -27,10 +29,19 @@ class MiddlewareMixin:
self.middleware = middleware_by_reference
def run_middleware_checks(self):
errors = []
for middleware in self.middleware:
if hasattr(middleware, "checks"):
for check in middleware.checks:
check(server=self).check()
if issue := check(server=self).check():
errors.append(issue)
if errors:
# just show the messages
sys.tracebacklimit = 0
raise StartupErrors(
"Problems were identified during startup — cannot continue.", errors
)
def process_request_middleware(self, request: Request) -> None | bool:
for middleware in self.middleware:

View File

@ -1 +1,137 @@
# https://gist.github.com/FND/204ba41bf6ae485965ef
import re
from urllib.parse import urlsplit, SplitResult
from spiderweb.request import Request
from spiderweb.response import HttpResponse
from spiderweb.middleware import SpiderwebMiddleware
ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
class CorsMiddleware(SpiderwebMiddleware):
# heavily 'based' on https://github.com/adamchainz/django-cors-headers,
# which is provided under the MIT license. This is essentially a direct
# port, since django-cors-headers is battle-tested code that has been
# around for a long time and it works well. Shoutouts to Otto, Adam, and
# crew for helping make this a complete non-issue in Django for a very long
# time.
def is_enabled(self, request: Request):
return bool(re.match(self.server.cors_urls_regex, request.path))
def add_response_headers(self, request: Request, response: HttpResponse):
enabled = getattr(request, "_cors_enabled", None)
if enabled is None:
enabled = self.is_enabled(request)
if not enabled:
return response
if "vary" in response.headers:
response.headers["vary"].append("origin")
else:
response.headers["vary"] = ["origin"]
origin = request.headers.get("origin")
if not origin:
return response
try:
url = urlsplit(origin)
except ValueError:
return response
if (
not self.server.cors_allow_all_origins
and not self.origin_found_in_allow_lists(origin, url)
):
return response
if (
self.server.cors_allow_all_origins
and not self.server.cors_allow_credentials
):
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
else:
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
if self.server.cors_allow_credentials:
response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
if len(self.server.cors_expose_headers):
response.headers[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
self.server.cors_expose_headers
)
if request.method == "OPTIONS":
response.headers[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(
self.server.cors_allow_headers
)
response.headers[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(
self.server.cors_allow_methods
)
if self.server.cors_preflight_max_age:
response.headers[ACCESS_CONTROL_MAX_AGE] = str(
self.server.cors_preflight_max_age
)
if (
self.server.cors_allow_private_network
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
):
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
return response
def origin_found_in_allow_lists(self, origin: str, url: SplitResult) -> bool:
return (
(origin == "null" and origin in self.server.cors_allowed_origins)
or self._url_in_allowlist(url)
or self.regex_domain_match(origin)
)
def _url_in_allowlist(self, url: SplitResult) -> bool:
origins = [urlsplit(o) for o in self.server.cors_allowed_origins]
return any(
origin.scheme == url.scheme and origin.netloc == url.netloc
for origin in origins
)
def regex_domain_match(self, origin: str) -> bool:
return any(
re.match(domain_pattern, origin)
for domain_pattern in self.server.cors_allowed_origin_regexes
)
def process_request(self, request: Request) -> HttpResponse | None:
# Identify and handle a preflight request
# origin = request.META.get("HTTP_ORIGIN")
request._cors_enabled = self.is_enabled(request)
if (
request._cors_enabled
and request.method == "OPTIONS"
and "access-control-request-method" in request.headers
):
# this should be 204, but according to mozilla, not all browsers
# parse that correctly. See [204] comment below.
resp = HttpResponse(
"",
status_code=200,
headers={"content-type": "text/plain", "content-length": 0},
)
self.add_response_headers(request, resp)
return resp
def process_response(
self, request: Request, response: HttpResponse
) -> None:
self.add_response_headers(request, response)
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code

View File

@ -1,4 +1,7 @@
import re
from re import Pattern
from datetime import datetime, timedelta
from typing import Optional
from spiderweb.exceptions import CSRFError, ConfigError
from spiderweb.middleware import SpiderwebMiddleware
@ -7,49 +10,85 @@ from spiderweb.response import HttpResponse
from spiderweb.server_checks import ServerCheck
class SessionCheck(ServerCheck):
class CheckForSessionMiddleware(ServerCheck):
SESSION_MIDDLEWARE_NOT_FOUND = (
"Session middleware is not enabled. It must be listed above"
"CSRFMiddleware in the middleware list."
)
def check(self) -> Optional[Exception]:
if (
"spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware
):
return ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
class VerifyCorrectMiddlewarePlacement(ServerCheck):
SESSION_MIDDLEWARE_BELOW_CSRF = (
"SessionMiddleware is enabled, but it must be listed above"
"CSRFMiddleware in the middleware list."
)
def check(self):
def check(self) -> Optional[Exception]:
if (
"spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware
):
raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
# this is handled by CheckForSessionMiddleware
return
if self.server._middleware.index(
"spiderweb.middleware.sessions.SessionMiddleware"
) > self.server._middleware.index(
"spiderweb.middleware.csrf.CSRFMiddleware"
):
raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
) > self.server._middleware.index("spiderweb.middleware.csrf.CSRFMiddleware"):
return ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
class VerifyCorrectFormatForTrustedOrigins(ServerCheck):
CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR = (
"The csrf_trusted_origins setting must be a list of strings."
)
def check(self) -> Optional[Exception]:
if not isinstance(self.server.csrf_trusted_origins, list):
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
for item in self.server.csrf_trusted_origins:
if not isinstance(item, Pattern):
# It's a pattern here because we've already manipulated it
# by the time this check runs
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
class CSRFMiddleware(SpiderwebMiddleware):
checks = [SessionCheck]
checks = [
CheckForSessionMiddleware,
VerifyCorrectMiddlewarePlacement,
VerifyCorrectFormatForTrustedOrigins,
]
CSRF_EXPIRY = 60 * 60 # 1 hour
def process_request(self, request: Request) -> HttpResponse | None:
if request.method == "POST":
trusted_origin = False
if hasattr(request.handler, "csrf_exempt"):
if request.handler.csrf_exempt is True:
return
if origin := request.headers.get("http_origin"):
for re_origin in self.server.csrf_trusted_origins:
if re.match(re_origin, origin):
trusted_origin = True
csrf_token = (
request.headers.get("X-CSRF-TOKEN")
or request.GET.get("csrf_token")
or request.POST.get("csrf_token")
)
if not trusted_origin:
if self.is_csrf_valid(request, csrf_token):
return None
else:

View File

@ -1,5 +1,5 @@
import re
from typing import Callable, Any, Optional
from typing import Callable, Any, Optional, Sequence
from spiderweb.constants import DEFAULT_ALLOWED_METHODS
from spiderweb.converters import * # noqa: F403
@ -30,7 +30,7 @@ class RoutesMixin:
# ones that start with underscores are the compiled versions, non-underscores
# are the user-supplied versions
_routes: dict
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,)
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]]
_error_routes: dict
error_routes: dict[int, Callable]
append_slash: bool

View File

@ -4,12 +4,16 @@ from datetime import timedelta
import pytest
from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors
from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf
from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf
from spiderweb.tests.views_for_tests import (
form_view_with_csrf,
form_csrf_exempt,
form_view_without_csrf,
)
# app = SpiderwebRouter(
@ -99,18 +103,21 @@ def test_exploding_middleware():
def test_csrf_middleware_without_session_middleware():
_, environ, start_response = setup()
with pytest.raises(ConfigError) as e:
with pytest.raises(StartupErrors) as e:
SpiderwebRouter(
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
db=SqliteDatabase("spiderweb-tests.db"),
)
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND
exceptiongroup = e.value.args[1]
assert (
exceptiongroup[0].args[0]
== csrf.CheckForSessionMiddleware.SESSION_MIDDLEWARE_NOT_FOUND
)
def test_csrf_middleware_above_session_middleware():
_, environ, start_response = setup()
with pytest.raises(ConfigError) as e:
with pytest.raises(StartupErrors) as e:
SpiderwebRouter(
middleware=[
"spiderweb.middleware.csrf.CSRFMiddleware",
@ -118,8 +125,11 @@ def test_csrf_middleware_above_session_middleware():
],
db=SqliteDatabase("spiderweb-tests.db"),
)
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF
exceptiongroup = e.value.args[1]
assert (
exceptiongroup[0].args[0]
== csrf.VerifyCorrectMiddlewarePlacement.SESSION_MIDDLEWARE_BELOW_CSRF
)
def test_csrf_middleware():
@ -211,6 +221,7 @@ def test_csrf_expired_token():
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
)
environ["REQUEST_METHOD"] = "POST"
environ["HTTP_ORIGIN"] = "example.com"
environ["HTTP_X_CSRF_TOKEN"] = token
environ["CONTENT_LENGTH"] = len(formdata)
@ -254,3 +265,44 @@ def test_csrf_exempt():
environ["PATH_INFO"] = "/2"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp2
def test_csrf_trusted_origins():
_, environ, start_response = setup()
app = SpiderwebRouter(
middleware=[
"spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware",
],
csrf_trusted_origins=[
"example.com",
],
db=SqliteDatabase("spiderweb-tests.db"),
)
app.add_route("/", form_view_without_csrf, ["GET", "POST"])
environ["HTTP_USER_AGENT"] = "hi"
environ["REMOTE_ADDR"] = "1.1.1.1"
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
environ["REQUEST_METHOD"] = "POST"
formdata = "name=bob"
environ["CONTENT_LENGTH"] = len(formdata)
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
environ["HTTP_ORIGIN"] = "notvalid.com"
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
environ["HTTP_ORIGIN"] = "example.com"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert resp2 == '{"name": "bob"}'

View File

@ -1,4 +1,5 @@
import json
import re
import secrets
import string
from http import HTTPStatus
@ -76,5 +77,13 @@ class Headers(dict):
def get(self, key, default=None):
return super().get(key.lower(), default)
def setdefault(self, key, default = None):
def setdefault(self, key, default=None):
return super().setdefault(key.lower(), default)
def convert_url_to_regex(url: str | re.Pattern) -> re.Pattern:
if isinstance(url, re.Pattern):
return url
url = url.replace(".", "\\.")
url = url.replace("*", ".+")
return re.compile(url)