add server checks & fix csrf middleware

This commit is contained in:
Joe Kaufeld 2024-08-29 17:29:28 -04:00
parent aabe20cff7
commit c451aff1e2
16 changed files with 360 additions and 74 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "spiderweb-framework"
version = "0.10.0"
version = "0.11.0"
description = "A small web framework, just big enough for a spider."
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
readme = "README.md"

View File

@ -2,7 +2,7 @@ from peewee import DatabaseProxy
DEFAULT_ALLOWED_METHODS = ["GET"]
DEFAULT_ENCODING = "UTF-8"
__version__ = "0.10.0"
__version__ = "0.11.0"
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"

View File

@ -5,7 +5,7 @@ class SpiderwebException(Exception):
msg = self.args[0] if len(self.args) > 0 else ""
if msg:
return f"{name}() - {msg}"
return f"{self.__class__.__name__}()"
return f"{name}()"
class SpiderwebNetworkException(SpiderwebException):

View File

@ -32,7 +32,7 @@ from spiderweb.routes import RoutesMixin
from spiderweb.secrets import FernetMixin
from spiderweb.utils import get_http_status_by_code
file_logger = logging.getLogger(__name__)
console_logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@ -57,7 +57,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
session_cookie_same_site="lax",
session_cookie_path="/",
log=None,
**kwargs
**kwargs,
):
self._routes = {}
self.routes = routes
@ -69,7 +69,8 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.append_slash = append_slash
self.templates_dirs = templates_dirs
self.staticfiles_dirs = staticfiles_dirs
self.middleware = middleware if middleware else []
self._middleware: list[str] = middleware if middleware else []
self.middleware: list[Callable] = []
self.secret_key = secret_key if secret_key else self.generate_key()
self.extra_data = kwargs
@ -84,7 +85,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.DEFAULT_ENCODING = DEFAULT_ENCODING
self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS
self.log: logging.Logger = log if log else file_logger
self.log: logging.Logger = log if log else console_logger
# for using .start() and .stop()
self._thread: Optional[Thread] = None
@ -108,7 +109,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.add_error_routes()
if self.templates_dirs:
self.template_loader = Environment(loader=FileSystemLoader(self.templates_dirs))
self.template_loader = Environment(
loader=FileSystemLoader(self.templates_dirs)
)
else:
self.template_loader = None
self.string_loader = Environment(loader=BaseLoader())
@ -117,12 +120,16 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
for static_dir in self.staticfiles_dirs:
static_dir = pathlib.Path(static_dir)
if not pathlib.Path(self.BASE_DIR / static_dir).exists():
log.error(
self.log.error(
f"Static files directory '{str(static_dir)}' does not exist."
)
raise ConfigError
self.add_route(r"/static/<str:filename>", send_file) # noqa: F405
# finally, run the startup checks to verify everything is correct and happy.
self.log.info("Run startup checks...")
self.run_middleware_checks()
def fire_response(self, start_response, request: Request, resp: HttpResponse):
try:
status = get_http_status_by_code(resp.status_code)
@ -190,7 +197,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
def prepare_and_fire_response(self, start_response, request, resp) -> list[bytes]:
try:
if isinstance(resp, dict):
return self.fire_response(start_response, request, JsonResponse(data=resp))
return self.fire_response(
start_response, request, JsonResponse(data=resp)
)
if isinstance(resp, TemplateResponse):
resp.set_template_loader(self.template_loader)
resp.set_string_loader(self.string_loader)

View File

@ -12,19 +12,26 @@ from ..utils import import_by_string
class MiddlewareMixin:
"""Cannot be called on its own. Requires context of SpiderwebRouter."""
_middleware: list[str]
middleware: list[ClassVar]
fire_response: Callable
def init_middleware(self):
if self.middleware:
if self._middleware:
middleware_by_reference = []
for m in self.middleware:
for m in self._middleware:
try:
middleware_by_reference.append(import_by_string(m)(server=self))
except ImportError:
raise ConfigError(f"Middleware '{m}' not found.")
self.middleware = middleware_by_reference
def run_middleware_checks(self):
    """
    Run the startup checks declared by the loaded middleware.

    Every entry in ``self.middleware`` that exposes a ``checks`` attribute
    has each listed check class instantiated with this server and its
    ``check()`` method called. Entries without ``checks`` are skipped.
    Nothing is caught here, so whatever a check raises reaches the caller.
    """
    for loaded in self.middleware:
        # Middleware that declares no checks contributes an empty tuple.
        for check_cls in getattr(loaded, "checks", ()):
            check_cls(server=self).check()
def process_request_middleware(self, request: Request) -> None | bool:
for middleware in self.middleware:
try:

View File

@ -1,25 +1,42 @@
from datetime import datetime, timedelta
from spiderweb.exceptions import CSRFError
from spiderweb.exceptions import CSRFError, ConfigError
from spiderweb.middleware import SpiderwebMiddleware
from spiderweb.request import Request
from spiderweb.response import HttpResponse
from spiderweb.server_checks import ServerCheck
class SessionCheck(ServerCheck):
    """
    Startup check: SessionMiddleware must be enabled and listed before
    CSRFMiddleware.

    The check reads the configured dotted middleware paths from
    ``self.server._middleware``.

    Raises:
        ConfigError: with ``SESSION_MIDDLEWARE_NOT_FOUND`` when the session
            middleware is not configured at all, or with
            ``SESSION_MIDDLEWARE_BELOW_CSRF`` when it is listed after the
            CSRF middleware.
    """

    # The trailing space after "above" matters: adjacent string literals are
    # concatenated as-is, so without it the message reads
    # "aboveCSRFMiddleware".
    SESSION_MIDDLEWARE_NOT_FOUND = (
        "Session middleware is not enabled. It must be listed above "
        "CSRFMiddleware in the middleware list."
    )
    SESSION_MIDDLEWARE_BELOW_CSRF = (
        "SessionMiddleware is enabled, but it must be listed above "
        "CSRFMiddleware in the middleware list."
    )

    def check(self):
        """Raise ConfigError unless SessionMiddleware precedes CSRFMiddleware."""
        session_path = "spiderweb.middleware.sessions.SessionMiddleware"
        csrf_path = "spiderweb.middleware.csrf.CSRFMiddleware"
        configured = self.server._middleware

        if session_path not in configured:
            raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
        # A larger index means "further down the list", i.e. below CSRF.
        if configured.index(session_path) > configured.index(csrf_path):
            raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
class CSRFMiddleware(SpiderwebMiddleware):
"""
tl;dr: this is a naive implementation going off just what I could think of
at the time. It is very vulnerable to CSRF Forgery and should be updated.
Eventually I'll probably just pull everything out of Django and use their
implementation, as it's written by people who know a lot more about these
things than I do, but in the meantime, this is still here until I get
around to making it more solid.
todo: fix
https://en.wikipedia.org/wiki/Cross-site_request_forgery
"""
checks = [SessionCheck]
CSRF_EXPIRY = 60 * 60 # 1 hour
@ -33,14 +50,14 @@ class CSRFMiddleware(SpiderwebMiddleware):
or request.GET.get("csrf_token")
or request.POST.get("csrf_token")
)
if self.is_csrf_valid(csrf_token):
if self.is_csrf_valid(request, csrf_token):
return None
else:
raise CSRFError()
return None
def process_response(self, request: Request, response: HttpResponse) -> None:
token = self.get_csrf_token()
token = self.get_csrf_token(request)
# do we need it in both places?
response.headers["X-CSRF-TOKEN"] = token
response.context |= {
@ -48,17 +65,22 @@ class CSRFMiddleware(SpiderwebMiddleware):
"raw_csrf_token": token, # in case they want to format it themselves
}
def get_csrf_token(self):
return self.server.encrypt(str(datetime.now().isoformat())).decode(
self.server.DEFAULT_ENCODING
)
def get_csrf_token(self, request):
# the session key should be here because we've processed the session first
session_key = request._session["id"]
return self.server.encrypt(
f"{str(datetime.now().isoformat())}::{session_key}"
).decode(self.server.DEFAULT_ENCODING)
def is_csrf_valid(self, key):
def is_csrf_valid(self, request, key):
try:
decoded = self.server.decrypt(key)
timestamp, session_key = decoded.split("::")
if session_key != request._session["id"]:
return False
if datetime.now() - timedelta(
seconds=self.CSRF_EXPIRY
) > datetime.fromisoformat(decoded):
) > datetime.fromisoformat(timestamp):
return False
return True
except Exception:

View File

@ -30,7 +30,7 @@ class RoutesMixin:
# ones that start with underscores are the compiled versions, non-underscores
# are the user-supplied versions
_routes: dict
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,)
_error_routes: dict
error_routes: dict[int, Callable]
append_slash: bool

View File

@ -0,0 +1,18 @@
class ServerCheck:
    """
    Base class for a single startup check.

    While the server starts, middleware may ask for checks to be run against
    the server's current state. They are mostly meant for validating
    configuration, but may also do setup work such as preparing a database.

    A check fails by raising whichever exception fits the problem; that
    halts startup so the problem can be corrected.
    """

    def __init__(self, server):
        """Keep a reference to the *server* this check inspects."""
        self.server = server

    def check(self):
        """Do nothing by default; subclasses override this with real checks."""
        return None

View File

@ -1 +1,4 @@
from spiderweb.tests.middleware import ExplodingResponseMiddleware, ExplodingRequestMiddleware
from spiderweb.tests.middleware import (
ExplodingResponseMiddleware,
ExplodingRequestMiddleware,
)

View File

@ -7,5 +7,7 @@ class ExplodingRequestMiddleware(SpiderwebMiddleware):
class ExplodingResponseMiddleware(SpiderwebMiddleware):
def process_response(self, request: Request, response: HttpResponse) -> HttpResponse | None:
def process_response(
self, request: Request, response: HttpResponse
) -> HttpResponse | None:
raise UnusedMiddleware("Unfinished!")

View File

@ -1,11 +1,16 @@
from io import BytesIO, BufferedReader
from datetime import timedelta
import pytest
from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError
from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf
from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf
# app = SpiderwebRouter(
# middleware=[
@ -18,19 +23,20 @@ from spiderweb.tests.helpers import setup
# ],
# )
def index(request):
if "value" in request.SESSION:
request.SESSION['value'] += 1
request.SESSION["value"] += 1
else:
request.SESSION['value'] = 0
return HttpResponse(body=str(request.SESSION['value']))
request.SESSION["value"] = 0
return HttpResponse(body=str(request.SESSION["value"]))
def test_session_middleware():
_, environ, start_response = setup()
app = SpiderwebRouter(
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
db=SqliteDatabase("spiderweb-tests.db")
db=SqliteDatabase("spiderweb-tests.db"),
)
app.add_route("/", index)
@ -46,11 +52,12 @@ def test_session_middleware():
assert app(environ, start_response) == [bytes(str(1), DEFAULT_ENCODING)]
assert app(environ, start_response) == [bytes(str(2), DEFAULT_ENCODING)]
def test_expired_session():
_, environ, start_response = setup()
app = SpiderwebRouter(
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
db=SqliteDatabase("spiderweb-tests.db")
db=SqliteDatabase("spiderweb-tests.db"),
)
app.add_route("/", index)
@ -80,9 +87,170 @@ def test_exploding_middleware():
"spiderweb.tests.middleware.ExplodingRequestMiddleware",
"spiderweb.tests.middleware.ExplodingResponseMiddleware",
],
db=SqliteDatabase("spiderweb-tests.db")
db=SqliteDatabase("spiderweb-tests.db"),
)
app.add_route("/", index)
assert app(environ, start_response) == [bytes(str(0), DEFAULT_ENCODING)]
# make sure it kicked out the middleware and isn't just ignoring it
assert len(app.middleware) == 0
def test_csrf_middleware_without_session_middleware():
    """Enabling CSRFMiddleware without SessionMiddleware fails at startup."""
    _, environ, start_response = setup()
    csrf_only = ["spiderweb.middleware.csrf.CSRFMiddleware"]

    with pytest.raises(ConfigError) as raised:
        SpiderwebRouter(
            middleware=csrf_only,
            db=SqliteDatabase("spiderweb-tests.db"),
        )

    # Checked after the ``with`` block so the assertion actually executes.
    expected = csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND
    assert raised.value.args[0] == expected
def test_csrf_middleware_above_session_middleware():
    """Listing CSRFMiddleware before SessionMiddleware fails at startup."""
    _, environ, start_response = setup()
    wrong_order = [
        "spiderweb.middleware.csrf.CSRFMiddleware",
        "spiderweb.middleware.sessions.SessionMiddleware",
    ]

    with pytest.raises(ConfigError) as raised:
        SpiderwebRouter(
            middleware=wrong_order,
            db=SqliteDatabase("spiderweb-tests.db"),
        )

    # Checked after the ``with`` block so the assertion actually executes.
    expected = csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF
    assert raised.value.args[0] == expected
def test_csrf_middleware():
    """
    Exercise CSRFMiddleware end to end.

    A rendered form carries a hidden token; a POST with that token succeeds;
    a token that cannot be decrypted and a well-formed token bound to a
    different session are both rejected.
    """
    _, environ, start_response = setup()
    app = SpiderwebRouter(
        middleware=[
            "spiderweb.middleware.sessions.SessionMiddleware",
            "spiderweb.middleware.csrf.CSRFMiddleware",
        ],
        db=SqliteDatabase("spiderweb-tests.db"),
    )
    app.add_route("/", form_view_with_csrf, ["GET", "POST"])

    def post_form(formdata):
        """Send *formdata* as the request body and return the decoded response."""
        payload = formdata.encode(DEFAULT_ENCODING)
        # Keep the declared length in step with the body actually being sent;
        # previously it was set once and left stale for the later requests.
        environ["CONTENT_LENGTH"] = len(payload)
        environ["wsgi.input"] = BufferedReader(BytesIO(payload))
        return app(environ, start_response)[0].decode(DEFAULT_ENCODING)

    environ["HTTP_USER_AGENT"] = "hi"
    environ["REMOTE_ADDR"] = "1.1.1.1"
    resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    assert "<form" in resp
    assert '<input type="hidden" name="csrf_token"' in resp

    token = resp.split('value="')[1].split('"')[0]
    environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
    # Use the last session row returned by the query — presumably the session
    # created by the GET above.
    environ["HTTP_COOKIE"] = (
        f"swsession={list(Session.select().dicts())[-1]['session_key']}"
    )
    environ["REQUEST_METHOD"] = "POST"
    environ["HTTP_X_CSRF_TOKEN"] = token

    # the valid token is accepted
    assert "bob" in post_form(f"name=bob&csrf_token={token}")

    # test that it raises a CSRF error on wrong token
    assert "CSRF token is invalid" in post_form("name=bob&csrf_token=badtoken")

    # test that the wrong session also raises a CSRF error: keep the real
    # timestamp but bind the token to a session id that is not ours
    timestamp = app.decrypt(token).split("::")[0]
    forged = app.encrypt(f"{timestamp}::badsession").decode(DEFAULT_ENCODING)
    assert "CSRF token is invalid" in post_form(f"name=bob&csrf_token={forged}")
def test_csrf_expired_token():
    """A token older than ``CSRF_EXPIRY`` is rejected even for the right session."""
    _, environ, start_response = setup()
    app = SpiderwebRouter(
        middleware=[
            "spiderweb.middleware.sessions.SessionMiddleware",
            "spiderweb.middleware.csrf.CSRFMiddleware",
        ],
        db=SqliteDatabase("spiderweb-tests.db"),
    )
    # Second configured middleware is CSRFMiddleware; a negative expiry
    # window makes every token count as already expired.
    app.middleware[1].CSRF_EXPIRY = -1
    app.add_route("/", form_view_with_csrf, ["GET", "POST"])

    environ["HTTP_USER_AGENT"] = "hi"
    environ["REMOTE_ADDR"] = "1.1.1.1"
    page = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    token = page.split('value="')[1].split('"')[0]

    formdata = f"name=bob&csrf_token={token}"
    session_rows = list(Session.select().dicts())
    environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
    environ["HTTP_COOKIE"] = f"swsession={session_rows[-1]['session_key']}"
    environ["REQUEST_METHOD"] = "POST"
    environ["HTTP_X_CSRF_TOKEN"] = token
    environ["CONTENT_LENGTH"] = len(formdata)

    # BytesIO seeded with the payload starts at position 0, ready to read.
    body = BytesIO(formdata.encode(DEFAULT_ENCODING))
    environ["wsgi.input"] = BufferedReader(body)

    resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    assert "CSRF token is invalid" in resp
def test_csrf_exempt():
    """
    A ``csrf_exempt`` view accepts a token-less POST, while an ordinary view
    receiving the same request is rejected.
    """
    _, environ, start_response = setup()
    app = SpiderwebRouter(
        middleware=[
            "spiderweb.middleware.sessions.SessionMiddleware",
            "spiderweb.middleware.csrf.CSRFMiddleware",
        ],
        db=SqliteDatabase("spiderweb-tests.db"),
    )
    app.add_route("/", form_csrf_exempt, ["GET", "POST"])
    app.add_route("/2", form_view_without_csrf, ["GET", "POST"])

    formdata = "name=bob"
    environ["HTTP_USER_AGENT"] = "hi"
    environ["REMOTE_ADDR"] = "1.1.1.1"
    environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
    environ["REQUEST_METHOD"] = "POST"
    environ["CONTENT_LENGTH"] = len(formdata)

    # BytesIO seeded with the payload starts at position 0, ready to read.
    body = BytesIO(formdata.encode(DEFAULT_ENCODING))
    environ["wsgi.input"] = BufferedReader(body)

    exempt_resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    assert "bob" in exempt_resp

    # same request, now aimed at the view that is not exempt
    environ["PATH_INFO"] = "/2"
    protected_resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    assert "CSRF token is invalid" in protected_resp

View File

@ -3,7 +3,12 @@ import pytest
from spiderweb import SpiderwebRouter, ConfigError
from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.exceptions import NoResponseError, SpiderwebNetworkException
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
from spiderweb.response import (
HttpResponse,
JsonResponse,
TemplateResponse,
RedirectResponse,
)
from hypothesis import given, strategies as st
from spiderweb.tests.helpers import setup
@ -27,7 +32,9 @@ def test_json_response():
def index(request):
return JsonResponse(data={"message": "text"})
assert app(environ, start_response) == [bytes('{"message": "text"}', DEFAULT_ENCODING)]
assert app(environ, start_response) == [
bytes('{"message": "text"}', DEFAULT_ENCODING)
]
def test_dict_response():
@ -51,7 +58,9 @@ def test_template_response(text):
request, template_string=template, context={"message": text}
)
assert app(environ, start_response) == [b"MESSAGE: " + bytes(text, DEFAULT_ENCODING)]
assert app(environ, start_response) == [
b"MESSAGE: " + bytes(text, DEFAULT_ENCODING)
]
def test_redirect_response():
@ -61,7 +70,7 @@ def test_redirect_response():
def index(request):
return RedirectResponse(location="/redirected")
assert app(environ, start_response) == [b'None']
assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/redirected"
@ -74,12 +83,14 @@ def test_add_route_at_server_start():
def view2(request):
return HttpResponse("View 2")
app = SpiderwebRouter(routes=[
app = SpiderwebRouter(
routes=[
("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}),
("/view2", view2),
])
]
)
assert app(environ, start_response) == [b'None']
assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/redirected"
@ -92,7 +103,7 @@ def test_redirect_on_append_slash():
pass
environ["PATH_INFO"] = f"/hello"
assert app(environ, start_response) == [b'None']
assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/hello/"
@ -104,11 +115,11 @@ def test_template_response_with_template(text):
@app.route("/")
def index(request):
return TemplateResponse(
request, "test.html", context={"message": text}
)
return TemplateResponse(request, "test.html", context={"message": text})
assert app(environ, start_response) == [b"TEMPLATE! " + bytes(text, DEFAULT_ENCODING)]
assert app(environ, start_response) == [
b"TEMPLATE! " + bytes(text, DEFAULT_ENCODING)
]
def test_view_returns_none():
@ -119,7 +130,7 @@ def test_view_returns_none():
pass
with pytest.raises(NoResponseError):
assert app(environ, start_response) == [b'None']
assert app(environ, start_response) == [b"None"]
def test_exploding_view():
@ -130,9 +141,10 @@ def test_exploding_view():
raise SpiderwebNetworkException("Boom!")
assert app(environ, start_response) == [
b'Something went wrong.\n\nCode: Boom!\n\nMsg: None\n\nDesc: None'
b"Something went wrong.\n\nCode: Boom!\n\nMsg: None\n\nDesc: None"
]
def test_missing_view():
app, environ, start_response = setup()
@ -146,20 +158,19 @@ def test_missing_view_with_custom_404():
def custom_404(request):
return HttpResponse("Custom 404")
assert app(environ, start_response) == [b'Custom 404']
assert app(environ, start_response) == [b"Custom 404"]
def test_duplicate_error_view():
app, environ, start_response = setup()
@app.error(404)
def custom_404(request):
...
def custom_404(request): ...
with pytest.raises(ConfigError):
@app.error(404)
def custom_404(request):
...
def custom_404(request): ...
def test_missing_view_with_custom_404_alt():
@ -170,4 +181,4 @@ def test_missing_view_with_custom_404_alt():
app = SpiderwebRouter(error_routes={404: custom_404})
assert app(environ, start_response) == [b'Custom 404 2']
assert app(environ, start_response) == [b"Custom 404 2"]

View File

@ -3,7 +3,12 @@ import pytest
from spiderweb import SpiderwebRouter
from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.exceptions import ParseError, ConfigError
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
from spiderweb.response import (
HttpResponse,
JsonResponse,
TemplateResponse,
RedirectResponse,
)
from hypothesis import given, strategies as st, assume
from peewee import SqliteDatabase
@ -43,6 +48,7 @@ def test_unknown_converter():
app, environ, start_response = setup()
with pytest.raises(ParseError):
@app.route("/<asdf:test_input>")
def index(request, test_input: str):
return HttpResponse(test_input)
@ -52,19 +58,19 @@ def test_duplicate_route():
app, environ, start_response = setup()
@app.route("/")
def index(request):
...
def index(request): ...
with pytest.raises(ConfigError):
@app.route("/")
def index(request):
...
def index(request): ...
def test_url_with_double_underscore():
app, environ, start_response = setup()
with pytest.raises(ConfigError):
@app.route("/<asdf:test__input>")
def index(request, test_input: str):
return HttpResponse(test_input)

View File

@ -0,0 +1,40 @@
from spiderweb.decorators import csrf_exempt
from spiderweb.response import JsonResponse, TemplateResponse
# Template string for a minimal POST form with a single "name" field and no
# CSRF token.
EXAMPLE_HTML_FORM = """
<form action="" method="post">
<input type="text" name="name" />
<input type="submit" />
</form>
"""
# The same form with a ``{{ csrf_token }}`` template placeholder added inside
# the <form> element.
EXAMPLE_HTML_FORM_WITH_CSRF = """
<form action="" method="post">
<input type="text" name="name" />
<input type="submit" />
{{ csrf_token }}
</form>
"""
def form_view_without_csrf(request):
    """
    Test view: render the token-less form, or echo the POST data as JSON.

    Any method other than POST gets the ``EXAMPLE_HTML_FORM`` template.
    """
    if request.method != "POST":
        return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM)
    return JsonResponse(data=request.POST)
@csrf_exempt
def form_csrf_exempt(request):
    """
    Test view marked ``csrf_exempt``: render the CSRF form, or echo the POST
    data as JSON.

    Any method other than POST gets the ``EXAMPLE_HTML_FORM_WITH_CSRF``
    template.
    """
    if request.method != "POST":
        return TemplateResponse(
            request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF
        )
    return JsonResponse(data=request.POST)
def form_view_with_csrf(request):
    """
    Test view: render the form containing the CSRF placeholder, or echo the
    POST data as JSON.

    Any method other than POST gets the ``EXAMPLE_HTML_FORM_WITH_CSRF``
    template.
    """
    if request.method != "POST":
        return TemplateResponse(
            request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF
        )
    return JsonResponse(data=request.POST)