Compare commits
No commits in common. "15a94b9879e34eb50cdab635e1cf03ea96a45c2c" and "4c4bd153be2de7febeb0930651a147447296f081" have entirely different histories.
15a94b9879
...
4c4bd153be
17
README.md
17
README.md
@ -1,22 +1,5 @@
|
|||||||
# spiderweb
|
# spiderweb
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<img
|
|
||||||
src="https://img.shields.io/pypi/v/spiderweb-framework.svg?style=for-the-badge"
|
|
||||||
alt="PyPI release version for Spiderweb"
|
|
||||||
/>
|
|
||||||
<a href="https://gitmoji.dev">
|
|
||||||
<img
|
|
||||||
src="https://img.shields.io/badge/gitmoji-%20😜%20😍-FFDD67.svg?style=for-the-badge"
|
|
||||||
alt="Gitmoji"
|
|
||||||
/>
|
|
||||||
</a>
|
|
||||||
<img
|
|
||||||
src="https://img.shields.io/badge/code%20style-black-000000.svg?style=for-the-badge"
|
|
||||||
alt="Code style: Black"
|
|
||||||
/>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
As a professional web developer focusing on arcane uses of Django for arcane purposes, it occurred to me a little while ago that I didn't actually know how a web framework _worked_.
|
As a professional web developer focusing on arcane uses of Django for arcane purposes, it occurred to me a little while ago that I didn't actually know how a web framework _worked_.
|
||||||
|
|
||||||
So I built one.
|
So I built one.
|
||||||
|
@ -11,6 +11,9 @@ app = SpiderwebRouter(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> [!DANGER]
|
||||||
|
> The CSRFMiddleware is incomplete at best and dangerous at worst. I am not a security expert, and my implementation is [very susceptible to the thing it is meant to prevent](https://en.wikipedia.org/wiki/Cross-site_request_forgery). While this is an big issue (and moderately hilarious), the middleware is still provided to you in its unfinished state. Be aware.
|
||||||
|
|
||||||
Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more.
|
Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more.
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
|
@ -15,7 +15,6 @@ from spiderweb.response import (
|
|||||||
app = SpiderwebRouter(
|
app = SpiderwebRouter(
|
||||||
templates_dirs=["templates"],
|
templates_dirs=["templates"],
|
||||||
middleware=[
|
middleware=[
|
||||||
"spiderweb.middleware.cors.CorsMiddleware",
|
|
||||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
"example_middleware.TestMiddleware",
|
"example_middleware.TestMiddleware",
|
||||||
|
@ -8,20 +8,3 @@ __version__ = "0.12.0"
|
|||||||
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
|
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
|
||||||
|
|
||||||
DATABASE_PROXY = DatabaseProxy()
|
DATABASE_PROXY = DatabaseProxy()
|
||||||
|
|
||||||
DEFAULT_CORS_ALLOW_METHODS = (
|
|
||||||
"DELETE",
|
|
||||||
"GET",
|
|
||||||
"OPTIONS",
|
|
||||||
"PATCH",
|
|
||||||
"POST",
|
|
||||||
"PUT",
|
|
||||||
)
|
|
||||||
DEFAULT_CORS_ALLOW_HEADERS = (
|
|
||||||
"accept",
|
|
||||||
"authorization",
|
|
||||||
"content-type",
|
|
||||||
"user-agent",
|
|
||||||
"x-csrftoken",
|
|
||||||
"x-requested-with",
|
|
||||||
)
|
|
||||||
|
@ -86,7 +86,3 @@ class UnusedMiddleware(SpiderwebException):
|
|||||||
|
|
||||||
class NoResponseError(SpiderwebException):
|
class NoResponseError(SpiderwebException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class StartupErrors(ExceptionGroup):
|
|
||||||
pass
|
|
||||||
|
@ -1,22 +1,16 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
|
||||||
import traceback
|
import traceback
|
||||||
import urllib.parse as urlparse
|
import urllib.parse as urlparse
|
||||||
from logging import Logger
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Optional, Callable, Sequence, LiteralString, Literal
|
from typing import Optional, Callable
|
||||||
from wsgiref.simple_server import WSGIServer
|
from wsgiref.simple_server import WSGIServer
|
||||||
|
|
||||||
from jinja2 import BaseLoader, Environment, FileSystemLoader
|
from jinja2 import BaseLoader, Environment, FileSystemLoader
|
||||||
from peewee import Database, SqliteDatabase
|
from peewee import Database, SqliteDatabase
|
||||||
|
|
||||||
from spiderweb.middleware import MiddlewareMixin
|
from spiderweb.middleware import MiddlewareMixin
|
||||||
from spiderweb.constants import (
|
|
||||||
DEFAULT_CORS_ALLOW_METHODS,
|
|
||||||
DEFAULT_CORS_ALLOW_HEADERS,
|
|
||||||
)
|
|
||||||
from spiderweb.constants import (
|
from spiderweb.constants import (
|
||||||
DATABASE_PROXY,
|
DATABASE_PROXY,
|
||||||
DEFAULT_ENCODING,
|
DEFAULT_ENCODING,
|
||||||
@ -36,7 +30,7 @@ from spiderweb.request import Request
|
|||||||
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
|
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
|
||||||
from spiderweb.routes import RoutesMixin
|
from spiderweb.routes import RoutesMixin
|
||||||
from spiderweb.secrets import FernetMixin
|
from spiderweb.secrets import FernetMixin
|
||||||
from spiderweb.utils import get_http_status_by_code, convert_url_to_regex
|
from spiderweb.utils import get_http_status_by_code
|
||||||
|
|
||||||
console_logger = logging.getLogger(__name__)
|
console_logger = logging.getLogger(__name__)
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@ -48,32 +42,24 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
*,
|
*,
|
||||||
addr: str = None,
|
addr: str = None,
|
||||||
port: int = None,
|
port: int = None,
|
||||||
allowed_hosts: Sequence[str | re.Pattern] = None,
|
allowed_hosts=None,
|
||||||
cors_allowed_origins: Sequence[str] = None,
|
cors_allowed_origins=None,
|
||||||
cors_allowed_origins_regexes: Sequence[str] = None,
|
cors_allow_all_origins=False,
|
||||||
cors_allow_all_origins: bool = False,
|
|
||||||
cors_urls_regex: str | re.Pattern[str] = r"^.*$",
|
|
||||||
cors_allow_methods: Sequence[str] = None,
|
|
||||||
cors_allow_headers: Sequence[str] = None,
|
|
||||||
cors_expose_headers: Sequence[str] = None,
|
|
||||||
cors_preflight_max_age: int = 86400,
|
|
||||||
cors_allow_credentials: bool = False,
|
|
||||||
csrf_trusted_origins: Sequence[str] = None,
|
|
||||||
db: Optional[Database] = None,
|
db: Optional[Database] = None,
|
||||||
templates_dirs: Sequence[str] = None,
|
templates_dirs: list[str] = None,
|
||||||
middleware: Sequence[str] = None,
|
middleware: list[str] = None,
|
||||||
append_slash: bool = False,
|
append_slash: bool = False,
|
||||||
staticfiles_dirs: Sequence[str] = None,
|
staticfiles_dirs: list[str] = None,
|
||||||
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
|
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
|
||||||
error_routes: dict[int, Callable] = None,
|
error_routes: dict[int, Callable] = None,
|
||||||
secret_key: str = None,
|
secret_key: str = None,
|
||||||
session_max_age: int = 60 * 60 * 24 * 14, # 2 weeks
|
session_max_age=60 * 60 * 24 * 14, # 2 weeks
|
||||||
session_cookie_name: str = "swsession",
|
session_cookie_name="swsession",
|
||||||
session_cookie_secure: bool = False, # should be true if serving over HTTPS
|
session_cookie_secure=False, # should be true if serving over HTTPS
|
||||||
session_cookie_http_only: bool = True,
|
session_cookie_http_only=True,
|
||||||
session_cookie_same_site: Literal["strict", "lax", "none"] = "lax",
|
session_cookie_same_site="lax",
|
||||||
session_cookie_path: str = "/",
|
session_cookie_path="/",
|
||||||
log: Logger = None,
|
log=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self._routes = {}
|
self._routes = {}
|
||||||
@ -89,23 +75,10 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
self._middleware: list[str] = middleware or []
|
self._middleware: list[str] = middleware or []
|
||||||
self.middleware: list[Callable] = []
|
self.middleware: list[Callable] = []
|
||||||
self.secret_key = secret_key if secret_key else self.generate_key()
|
self.secret_key = secret_key if secret_key else self.generate_key()
|
||||||
self._allowed_hosts = allowed_hosts or ["*"]
|
self.allowed_hosts = allowed_hosts or ["*"]
|
||||||
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
|
|
||||||
|
|
||||||
self.cors_allowed_origins = cors_allowed_origins or []
|
self.cors_allowed_origins = cors_allowed_origins or []
|
||||||
self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or []
|
|
||||||
self.cors_allow_all_origins = cors_allow_all_origins
|
self.cors_allow_all_origins = cors_allow_all_origins
|
||||||
self.cors_urls_regex = cors_urls_regex
|
|
||||||
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
|
|
||||||
self.cors_allow_headers = cors_allow_headers or DEFAULT_CORS_ALLOW_HEADERS
|
|
||||||
self.cors_expose_headers = cors_expose_headers or []
|
|
||||||
self.cors_preflight_max_age = cors_preflight_max_age
|
|
||||||
self.cors_allow_credentials = cors_allow_credentials
|
|
||||||
|
|
||||||
self._csrf_trusted_origins = csrf_trusted_origins or []
|
|
||||||
self.csrf_trusted_origins = [
|
|
||||||
convert_url_to_regex(i) for i in self._csrf_trusted_origins
|
|
||||||
]
|
|
||||||
|
|
||||||
self.extra_data = kwargs
|
self.extra_data = kwargs
|
||||||
|
|
||||||
@ -181,6 +154,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
for v in varies:
|
for v in varies:
|
||||||
headers.append(("Vary", v))
|
headers.append(("Vary", v))
|
||||||
|
|
||||||
|
|
||||||
start_response(status, headers)
|
start_response(status, headers)
|
||||||
|
|
||||||
rendered_output = resp.render()
|
rendered_output = resp.render()
|
||||||
@ -257,15 +231,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
start_response, request, self.get_error_route(500)(request)
|
start_response, request, self.get_error_route(500)(request)
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_valid_host(self, request) -> bool:
|
|
||||||
host = request.headers.get("http_host")
|
|
||||||
if not host:
|
|
||||||
return False
|
|
||||||
for option in self.allowed_hosts:
|
|
||||||
if re.match(option, host):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __call__(self, environ, start_response, *args, **kwargs):
|
def __call__(self, environ, start_response, *args, **kwargs):
|
||||||
"""Entry point for WSGI apps."""
|
"""Entry point for WSGI apps."""
|
||||||
request = self.get_request(environ)
|
request = self.get_request(environ)
|
||||||
@ -282,9 +247,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
# replace the potentially valid handler with the error route
|
# replace the potentially valid handler with the error route
|
||||||
handler = self.get_error_route(405)
|
handler = self.get_error_route(405)
|
||||||
|
|
||||||
if not self.check_valid_host(request):
|
|
||||||
handler = self.get_error_route(403)
|
|
||||||
|
|
||||||
if request.is_form_request():
|
if request.is_form_request():
|
||||||
form_data = urlparse.parse_qs(request.content)
|
form_data = urlparse.parse_qs(request.content)
|
||||||
for key, value in form_data.items():
|
for key, value in form_data.items():
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
from typing import Callable, ClassVar
|
from typing import Callable, ClassVar
|
||||||
import sys
|
|
||||||
|
|
||||||
from .base import SpiderwebMiddleware as SpiderwebMiddleware
|
from .base import SpiderwebMiddleware as SpiderwebMiddleware
|
||||||
from .cors import CorsMiddleware as CorsMiddleware
|
|
||||||
from .csrf import CSRFMiddleware as CSRFMiddleware
|
from .csrf import CSRFMiddleware as CSRFMiddleware
|
||||||
from .sessions import SessionMiddleware as SessionMiddleware
|
from .sessions import SessionMiddleware as SessionMiddleware
|
||||||
from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors
|
from ..exceptions import ConfigError, UnusedMiddleware
|
||||||
from ..request import Request
|
from ..request import Request
|
||||||
from ..response import HttpResponse
|
from ..response import HttpResponse
|
||||||
from ..utils import import_by_string
|
from ..utils import import_by_string
|
||||||
@ -29,19 +27,10 @@ class MiddlewareMixin:
|
|||||||
self.middleware = middleware_by_reference
|
self.middleware = middleware_by_reference
|
||||||
|
|
||||||
def run_middleware_checks(self):
|
def run_middleware_checks(self):
|
||||||
errors = []
|
|
||||||
for middleware in self.middleware:
|
for middleware in self.middleware:
|
||||||
if hasattr(middleware, "checks"):
|
if hasattr(middleware, "checks"):
|
||||||
for check in middleware.checks:
|
for check in middleware.checks:
|
||||||
if issue := check(server=self).check():
|
check(server=self).check()
|
||||||
errors.append(issue)
|
|
||||||
|
|
||||||
if errors:
|
|
||||||
# just show the messages
|
|
||||||
sys.tracebacklimit = 0
|
|
||||||
raise StartupErrors(
|
|
||||||
"Problems were identified during startup — cannot continue.", errors
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_request_middleware(self, request: Request) -> None | bool:
|
def process_request_middleware(self, request: Request) -> None | bool:
|
||||||
for middleware in self.middleware:
|
for middleware in self.middleware:
|
||||||
|
@ -1,137 +1 @@
|
|||||||
import re
|
# https://gist.github.com/FND/204ba41bf6ae485965ef
|
||||||
from urllib.parse import urlsplit, SplitResult
|
|
||||||
|
|
||||||
from spiderweb.request import Request
|
|
||||||
from spiderweb.response import HttpResponse
|
|
||||||
from spiderweb.middleware import SpiderwebMiddleware
|
|
||||||
|
|
||||||
ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
|
|
||||||
ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
|
|
||||||
ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
|
|
||||||
ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
|
|
||||||
ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
|
|
||||||
ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
|
|
||||||
ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
|
|
||||||
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
|
|
||||||
|
|
||||||
|
|
||||||
class CorsMiddleware(SpiderwebMiddleware):
|
|
||||||
# heavily 'based' on https://github.com/adamchainz/django-cors-headers,
|
|
||||||
# which is provided under the MIT license. This is essentially a direct
|
|
||||||
# port, since django-cors-headers is battle-tested code that has been
|
|
||||||
# around for a long time and it works well. Shoutouts to Otto, Adam, and
|
|
||||||
# crew for helping make this a complete non-issue in Django for a very long
|
|
||||||
# time.
|
|
||||||
|
|
||||||
def is_enabled(self, request: Request):
|
|
||||||
return bool(re.match(self.server.cors_urls_regex, request.path))
|
|
||||||
|
|
||||||
def add_response_headers(self, request: Request, response: HttpResponse):
|
|
||||||
enabled = getattr(request, "_cors_enabled", None)
|
|
||||||
if enabled is None:
|
|
||||||
enabled = self.is_enabled(request)
|
|
||||||
|
|
||||||
if not enabled:
|
|
||||||
return response
|
|
||||||
|
|
||||||
if "vary" in response.headers:
|
|
||||||
response.headers["vary"].append("origin")
|
|
||||||
else:
|
|
||||||
response.headers["vary"] = ["origin"]
|
|
||||||
|
|
||||||
origin = request.headers.get("origin")
|
|
||||||
if not origin:
|
|
||||||
return response
|
|
||||||
|
|
||||||
try:
|
|
||||||
url = urlsplit(origin)
|
|
||||||
except ValueError:
|
|
||||||
return response
|
|
||||||
|
|
||||||
if (
|
|
||||||
not self.server.cors_allow_all_origins
|
|
||||||
and not self.origin_found_in_allow_lists(origin, url)
|
|
||||||
):
|
|
||||||
return response
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.server.cors_allow_all_origins
|
|
||||||
and not self.server.cors_allow_credentials
|
|
||||||
):
|
|
||||||
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
|
|
||||||
else:
|
|
||||||
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
|
|
||||||
|
|
||||||
if self.server.cors_allow_credentials:
|
|
||||||
response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
|
|
||||||
|
|
||||||
if len(self.server.cors_expose_headers):
|
|
||||||
response.headers[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
|
|
||||||
self.server.cors_expose_headers
|
|
||||||
)
|
|
||||||
|
|
||||||
if request.method == "OPTIONS":
|
|
||||||
response.headers[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(
|
|
||||||
self.server.cors_allow_headers
|
|
||||||
)
|
|
||||||
response.headers[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(
|
|
||||||
self.server.cors_allow_methods
|
|
||||||
)
|
|
||||||
if self.server.cors_preflight_max_age:
|
|
||||||
response.headers[ACCESS_CONTROL_MAX_AGE] = str(
|
|
||||||
self.server.cors_preflight_max_age
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.server.cors_allow_private_network
|
|
||||||
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
|
|
||||||
):
|
|
||||||
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
def origin_found_in_allow_lists(self, origin: str, url: SplitResult) -> bool:
|
|
||||||
return (
|
|
||||||
(origin == "null" and origin in self.server.cors_allowed_origins)
|
|
||||||
or self._url_in_allowlist(url)
|
|
||||||
or self.regex_domain_match(origin)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _url_in_allowlist(self, url: SplitResult) -> bool:
|
|
||||||
origins = [urlsplit(o) for o in self.server.cors_allowed_origins]
|
|
||||||
return any(
|
|
||||||
origin.scheme == url.scheme and origin.netloc == url.netloc
|
|
||||||
for origin in origins
|
|
||||||
)
|
|
||||||
|
|
||||||
def regex_domain_match(self, origin: str) -> bool:
|
|
||||||
return any(
|
|
||||||
re.match(domain_pattern, origin)
|
|
||||||
for domain_pattern in self.server.cors_allowed_origin_regexes
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_request(self, request: Request) -> HttpResponse | None:
|
|
||||||
# Identify and handle a preflight request
|
|
||||||
# origin = request.META.get("HTTP_ORIGIN")
|
|
||||||
request._cors_enabled = self.is_enabled(request)
|
|
||||||
if (
|
|
||||||
request._cors_enabled
|
|
||||||
and request.method == "OPTIONS"
|
|
||||||
and "access-control-request-method" in request.headers
|
|
||||||
):
|
|
||||||
# this should be 204, but according to mozilla, not all browsers
|
|
||||||
# parse that correctly. See [204] comment below.
|
|
||||||
resp = HttpResponse(
|
|
||||||
"",
|
|
||||||
status_code=200,
|
|
||||||
headers={"content-type": "text/plain", "content-length": 0},
|
|
||||||
)
|
|
||||||
self.add_response_headers(request, resp)
|
|
||||||
return resp
|
|
||||||
|
|
||||||
def process_response(
|
|
||||||
self, request: Request, response: HttpResponse
|
|
||||||
) -> None:
|
|
||||||
self.add_response_headers(request, response)
|
|
||||||
|
|
||||||
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
import re
|
|
||||||
from re import Pattern
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from spiderweb.exceptions import CSRFError, ConfigError
|
from spiderweb.exceptions import CSRFError, ConfigError
|
||||||
from spiderweb.middleware import SpiderwebMiddleware
|
from spiderweb.middleware import SpiderwebMiddleware
|
||||||
@ -10,85 +7,49 @@ from spiderweb.response import HttpResponse
|
|||||||
from spiderweb.server_checks import ServerCheck
|
from spiderweb.server_checks import ServerCheck
|
||||||
|
|
||||||
|
|
||||||
class CheckForSessionMiddleware(ServerCheck):
|
class SessionCheck(ServerCheck):
|
||||||
|
|
||||||
SESSION_MIDDLEWARE_NOT_FOUND = (
|
SESSION_MIDDLEWARE_NOT_FOUND = (
|
||||||
"Session middleware is not enabled. It must be listed above"
|
"Session middleware is not enabled. It must be listed above"
|
||||||
"CSRFMiddleware in the middleware list."
|
"CSRFMiddleware in the middleware list."
|
||||||
)
|
)
|
||||||
|
|
||||||
def check(self) -> Optional[Exception]:
|
|
||||||
if (
|
|
||||||
"spiderweb.middleware.sessions.SessionMiddleware"
|
|
||||||
not in self.server._middleware
|
|
||||||
):
|
|
||||||
return ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
|
|
||||||
|
|
||||||
|
|
||||||
class VerifyCorrectMiddlewarePlacement(ServerCheck):
|
|
||||||
SESSION_MIDDLEWARE_BELOW_CSRF = (
|
SESSION_MIDDLEWARE_BELOW_CSRF = (
|
||||||
"SessionMiddleware is enabled, but it must be listed above"
|
"SessionMiddleware is enabled, but it must be listed above"
|
||||||
"CSRFMiddleware in the middleware list."
|
"CSRFMiddleware in the middleware list."
|
||||||
)
|
)
|
||||||
|
|
||||||
def check(self) -> Optional[Exception]:
|
def check(self):
|
||||||
|
|
||||||
if (
|
if (
|
||||||
"spiderweb.middleware.sessions.SessionMiddleware"
|
"spiderweb.middleware.sessions.SessionMiddleware"
|
||||||
not in self.server._middleware
|
not in self.server._middleware
|
||||||
):
|
):
|
||||||
# this is handled by CheckForSessionMiddleware
|
raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
|
||||||
return
|
|
||||||
|
|
||||||
if self.server._middleware.index(
|
if self.server._middleware.index(
|
||||||
"spiderweb.middleware.sessions.SessionMiddleware"
|
"spiderweb.middleware.sessions.SessionMiddleware"
|
||||||
) > self.server._middleware.index("spiderweb.middleware.csrf.CSRFMiddleware"):
|
) > self.server._middleware.index(
|
||||||
return ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
|
"spiderweb.middleware.csrf.CSRFMiddleware"
|
||||||
|
):
|
||||||
|
raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
|
||||||
class VerifyCorrectFormatForTrustedOrigins(ServerCheck):
|
|
||||||
CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR = (
|
|
||||||
"The csrf_trusted_origins setting must be a list of strings."
|
|
||||||
)
|
|
||||||
|
|
||||||
def check(self) -> Optional[Exception]:
|
|
||||||
if not isinstance(self.server.csrf_trusted_origins, list):
|
|
||||||
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
|
|
||||||
|
|
||||||
for item in self.server.csrf_trusted_origins:
|
|
||||||
if not isinstance(item, Pattern):
|
|
||||||
# It's a pattern here because we've already manipulated it
|
|
||||||
# by the time this check runs
|
|
||||||
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
|
|
||||||
|
|
||||||
|
|
||||||
class CSRFMiddleware(SpiderwebMiddleware):
|
class CSRFMiddleware(SpiderwebMiddleware):
|
||||||
|
|
||||||
checks = [
|
checks = [SessionCheck]
|
||||||
CheckForSessionMiddleware,
|
|
||||||
VerifyCorrectMiddlewarePlacement,
|
|
||||||
VerifyCorrectFormatForTrustedOrigins,
|
|
||||||
]
|
|
||||||
|
|
||||||
CSRF_EXPIRY = 60 * 60 # 1 hour
|
CSRF_EXPIRY = 60 * 60 # 1 hour
|
||||||
|
|
||||||
def process_request(self, request: Request) -> HttpResponse | None:
|
def process_request(self, request: Request) -> HttpResponse | None:
|
||||||
if request.method == "POST":
|
if request.method == "POST":
|
||||||
trusted_origin = False
|
|
||||||
if hasattr(request.handler, "csrf_exempt"):
|
if hasattr(request.handler, "csrf_exempt"):
|
||||||
if request.handler.csrf_exempt is True:
|
if request.handler.csrf_exempt is True:
|
||||||
return
|
return
|
||||||
if origin := request.headers.get("http_origin"):
|
|
||||||
|
|
||||||
for re_origin in self.server.csrf_trusted_origins:
|
|
||||||
if re.match(re_origin, origin):
|
|
||||||
trusted_origin = True
|
|
||||||
|
|
||||||
csrf_token = (
|
csrf_token = (
|
||||||
request.headers.get("X-CSRF-TOKEN")
|
request.headers.get("X-CSRF-TOKEN")
|
||||||
or request.GET.get("csrf_token")
|
or request.GET.get("csrf_token")
|
||||||
or request.POST.get("csrf_token")
|
or request.POST.get("csrf_token")
|
||||||
)
|
)
|
||||||
|
|
||||||
if not trusted_origin:
|
|
||||||
if self.is_csrf_valid(request, csrf_token):
|
if self.is_csrf_valid(request, csrf_token):
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Callable, Any, Optional, Sequence
|
from typing import Callable, Any, Optional
|
||||||
|
|
||||||
from spiderweb.constants import DEFAULT_ALLOWED_METHODS
|
from spiderweb.constants import DEFAULT_ALLOWED_METHODS
|
||||||
from spiderweb.converters import * # noqa: F403
|
from spiderweb.converters import * # noqa: F403
|
||||||
@ -30,7 +30,7 @@ class RoutesMixin:
|
|||||||
# ones that start with underscores are the compiled versions, non-underscores
|
# ones that start with underscores are the compiled versions, non-underscores
|
||||||
# are the user-supplied versions
|
# are the user-supplied versions
|
||||||
_routes: dict
|
_routes: dict
|
||||||
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]]
|
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,)
|
||||||
_error_routes: dict
|
_error_routes: dict
|
||||||
error_routes: dict[int, Callable]
|
error_routes: dict[int, Callable]
|
||||||
append_slash: bool
|
append_slash: bool
|
||||||
|
@ -4,16 +4,12 @@ from datetime import timedelta
|
|||||||
import pytest
|
import pytest
|
||||||
from peewee import SqliteDatabase
|
from peewee import SqliteDatabase
|
||||||
|
|
||||||
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors
|
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError
|
||||||
from spiderweb.constants import DEFAULT_ENCODING
|
from spiderweb.constants import DEFAULT_ENCODING
|
||||||
from spiderweb.middleware.sessions import Session
|
from spiderweb.middleware.sessions import Session
|
||||||
from spiderweb.middleware import csrf
|
from spiderweb.middleware import csrf
|
||||||
from spiderweb.tests.helpers import setup
|
from spiderweb.tests.helpers import setup
|
||||||
from spiderweb.tests.views_for_tests import (
|
from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf
|
||||||
form_view_with_csrf,
|
|
||||||
form_csrf_exempt,
|
|
||||||
form_view_without_csrf,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# app = SpiderwebRouter(
|
# app = SpiderwebRouter(
|
||||||
@ -103,21 +99,18 @@ def test_exploding_middleware():
|
|||||||
|
|
||||||
def test_csrf_middleware_without_session_middleware():
|
def test_csrf_middleware_without_session_middleware():
|
||||||
_, environ, start_response = setup()
|
_, environ, start_response = setup()
|
||||||
with pytest.raises(StartupErrors) as e:
|
with pytest.raises(ConfigError) as e:
|
||||||
SpiderwebRouter(
|
SpiderwebRouter(
|
||||||
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
|
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
|
||||||
db=SqliteDatabase("spiderweb-tests.db"),
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
exceptiongroup = e.value.args[1]
|
|
||||||
assert (
|
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND
|
||||||
exceptiongroup[0].args[0]
|
|
||||||
== csrf.CheckForSessionMiddleware.SESSION_MIDDLEWARE_NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_csrf_middleware_above_session_middleware():
|
def test_csrf_middleware_above_session_middleware():
|
||||||
_, environ, start_response = setup()
|
_, environ, start_response = setup()
|
||||||
with pytest.raises(StartupErrors) as e:
|
with pytest.raises(ConfigError) as e:
|
||||||
SpiderwebRouter(
|
SpiderwebRouter(
|
||||||
middleware=[
|
middleware=[
|
||||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
@ -125,11 +118,8 @@ def test_csrf_middleware_above_session_middleware():
|
|||||||
],
|
],
|
||||||
db=SqliteDatabase("spiderweb-tests.db"),
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
exceptiongroup = e.value.args[1]
|
|
||||||
assert (
|
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF
|
||||||
exceptiongroup[0].args[0]
|
|
||||||
== csrf.VerifyCorrectMiddlewarePlacement.SESSION_MIDDLEWARE_BELOW_CSRF
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_csrf_middleware():
|
def test_csrf_middleware():
|
||||||
@ -221,7 +211,6 @@ def test_csrf_expired_token():
|
|||||||
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
|
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
|
||||||
)
|
)
|
||||||
environ["REQUEST_METHOD"] = "POST"
|
environ["REQUEST_METHOD"] = "POST"
|
||||||
environ["HTTP_ORIGIN"] = "example.com"
|
|
||||||
environ["HTTP_X_CSRF_TOKEN"] = token
|
environ["HTTP_X_CSRF_TOKEN"] = token
|
||||||
environ["CONTENT_LENGTH"] = len(formdata)
|
environ["CONTENT_LENGTH"] = len(formdata)
|
||||||
|
|
||||||
@ -265,44 +254,3 @@ def test_csrf_exempt():
|
|||||||
environ["PATH_INFO"] = "/2"
|
environ["PATH_INFO"] = "/2"
|
||||||
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
assert "CSRF token is invalid" in resp2
|
assert "CSRF token is invalid" in resp2
|
||||||
|
|
||||||
|
|
||||||
def test_csrf_trusted_origins():
|
|
||||||
_, environ, start_response = setup()
|
|
||||||
app = SpiderwebRouter(
|
|
||||||
middleware=[
|
|
||||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
|
||||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
|
||||||
],
|
|
||||||
csrf_trusted_origins=[
|
|
||||||
"example.com",
|
|
||||||
],
|
|
||||||
db=SqliteDatabase("spiderweb-tests.db"),
|
|
||||||
)
|
|
||||||
|
|
||||||
app.add_route("/", form_view_without_csrf, ["GET", "POST"])
|
|
||||||
|
|
||||||
environ["HTTP_USER_AGENT"] = "hi"
|
|
||||||
environ["REMOTE_ADDR"] = "1.1.1.1"
|
|
||||||
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
|
|
||||||
environ["REQUEST_METHOD"] = "POST"
|
|
||||||
|
|
||||||
formdata = "name=bob"
|
|
||||||
environ["CONTENT_LENGTH"] = len(formdata)
|
|
||||||
b_handle = BytesIO()
|
|
||||||
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
|
||||||
b_handle.seek(0)
|
|
||||||
environ["wsgi.input"] = BufferedReader(b_handle)
|
|
||||||
|
|
||||||
environ["HTTP_ORIGIN"] = "notvalid.com"
|
|
||||||
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
|
||||||
assert "CSRF token is invalid" in resp
|
|
||||||
|
|
||||||
b_handle = BytesIO()
|
|
||||||
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
|
||||||
b_handle.seek(0)
|
|
||||||
environ["wsgi.input"] = BufferedReader(b_handle)
|
|
||||||
|
|
||||||
environ["HTTP_ORIGIN"] = "example.com"
|
|
||||||
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
|
||||||
assert resp2 == '{"name": "bob"}'
|
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
@ -77,13 +76,5 @@ class Headers(dict):
|
|||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
return super().get(key.lower(), default)
|
return super().get(key.lower(), default)
|
||||||
|
|
||||||
def setdefault(self, key, default=None):
|
def setdefault(self, key, default = None):
|
||||||
return super().setdefault(key.lower(), default)
|
return super().setdefault(key.lower(), default)
|
||||||
|
|
||||||
|
|
||||||
def convert_url_to_regex(url: str | re.Pattern) -> re.Pattern:
|
|
||||||
if isinstance(url, re.Pattern):
|
|
||||||
return url
|
|
||||||
url = url.replace(".", "\\.")
|
|
||||||
url = url.replace("*", ".+")
|
|
||||||
return re.compile(url)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user