add server checks & fix csrf middleware

This commit is contained in:
Joe Kaufeld 2024-08-29 17:29:28 -04:00
parent aabe20cff7
commit c451aff1e2
16 changed files with 360 additions and 74 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "spiderweb-framework" name = "spiderweb-framework"
version = "0.10.0" version = "0.11.0"
description = "A small web framework, just big enough for a spider." description = "A small web framework, just big enough for a spider."
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"] authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
readme = "README.md" readme = "README.md"

View File

@ -2,7 +2,7 @@ from peewee import DatabaseProxy
DEFAULT_ALLOWED_METHODS = ["GET"] DEFAULT_ALLOWED_METHODS = ["GET"]
DEFAULT_ENCODING = "UTF-8" DEFAULT_ENCODING = "UTF-8"
__version__ = "0.10.0" __version__ = "0.11.0"
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"

View File

@ -5,7 +5,7 @@ class SpiderwebException(Exception):
msg = self.args[0] if len(self.args) > 0 else "" msg = self.args[0] if len(self.args) > 0 else ""
if msg: if msg:
return f"{name}() - {msg}" return f"{name}() - {msg}"
return f"{self.__class__.__name__}()" return f"{name}()"
class SpiderwebNetworkException(SpiderwebException): class SpiderwebNetworkException(SpiderwebException):

View File

@ -32,7 +32,7 @@ from spiderweb.routes import RoutesMixin
from spiderweb.secrets import FernetMixin from spiderweb.secrets import FernetMixin
from spiderweb.utils import get_http_status_by_code from spiderweb.utils import get_http_status_by_code
file_logger = logging.getLogger(__name__) console_logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -57,7 +57,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
session_cookie_same_site="lax", session_cookie_same_site="lax",
session_cookie_path="/", session_cookie_path="/",
log=None, log=None,
**kwargs **kwargs,
): ):
self._routes = {} self._routes = {}
self.routes = routes self.routes = routes
@ -69,7 +69,8 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.append_slash = append_slash self.append_slash = append_slash
self.templates_dirs = templates_dirs self.templates_dirs = templates_dirs
self.staticfiles_dirs = staticfiles_dirs self.staticfiles_dirs = staticfiles_dirs
self.middleware = middleware if middleware else [] self._middleware: list[str] = middleware if middleware else []
self.middleware: list[Callable] = []
self.secret_key = secret_key if secret_key else self.generate_key() self.secret_key = secret_key if secret_key else self.generate_key()
self.extra_data = kwargs self.extra_data = kwargs
@ -84,7 +85,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.DEFAULT_ENCODING = DEFAULT_ENCODING self.DEFAULT_ENCODING = DEFAULT_ENCODING
self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS
self.log: logging.Logger = log if log else file_logger self.log: logging.Logger = log if log else console_logger
# for using .start() and .stop() # for using .start() and .stop()
self._thread: Optional[Thread] = None self._thread: Optional[Thread] = None
@ -108,7 +109,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.add_error_routes() self.add_error_routes()
if self.templates_dirs: if self.templates_dirs:
self.template_loader = Environment(loader=FileSystemLoader(self.templates_dirs)) self.template_loader = Environment(
loader=FileSystemLoader(self.templates_dirs)
)
else: else:
self.template_loader = None self.template_loader = None
self.string_loader = Environment(loader=BaseLoader()) self.string_loader = Environment(loader=BaseLoader())
@ -117,12 +120,16 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
for static_dir in self.staticfiles_dirs: for static_dir in self.staticfiles_dirs:
static_dir = pathlib.Path(static_dir) static_dir = pathlib.Path(static_dir)
if not pathlib.Path(self.BASE_DIR / static_dir).exists(): if not pathlib.Path(self.BASE_DIR / static_dir).exists():
log.error( self.log.error(
f"Static files directory '{str(static_dir)}' does not exist." f"Static files directory '{str(static_dir)}' does not exist."
) )
raise ConfigError raise ConfigError
self.add_route(r"/static/<str:filename>", send_file) # noqa: F405 self.add_route(r"/static/<str:filename>", send_file) # noqa: F405
# finally, run the startup checks to verify everything is correct and happy.
self.log.info("Run startup checks...")
self.run_middleware_checks()
def fire_response(self, start_response, request: Request, resp: HttpResponse): def fire_response(self, start_response, request: Request, resp: HttpResponse):
try: try:
status = get_http_status_by_code(resp.status_code) status = get_http_status_by_code(resp.status_code)
@ -190,7 +197,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
def prepare_and_fire_response(self, start_response, request, resp) -> list[bytes]: def prepare_and_fire_response(self, start_response, request, resp) -> list[bytes]:
try: try:
if isinstance(resp, dict): if isinstance(resp, dict):
return self.fire_response(start_response, request, JsonResponse(data=resp)) return self.fire_response(
start_response, request, JsonResponse(data=resp)
)
if isinstance(resp, TemplateResponse): if isinstance(resp, TemplateResponse):
resp.set_template_loader(self.template_loader) resp.set_template_loader(self.template_loader)
resp.set_string_loader(self.string_loader) resp.set_string_loader(self.string_loader)

View File

@ -12,19 +12,26 @@ from ..utils import import_by_string
class MiddlewareMixin: class MiddlewareMixin:
"""Cannot be called on its own. Requires context of SpiderwebRouter.""" """Cannot be called on its own. Requires context of SpiderwebRouter."""
_middleware: list[str]
middleware: list[ClassVar] middleware: list[ClassVar]
fire_response: Callable fire_response: Callable
def init_middleware(self): def init_middleware(self):
if self.middleware: if self._middleware:
middleware_by_reference = [] middleware_by_reference = []
for m in self.middleware: for m in self._middleware:
try: try:
middleware_by_reference.append(import_by_string(m)(server=self)) middleware_by_reference.append(import_by_string(m)(server=self))
except ImportError: except ImportError:
raise ConfigError(f"Middleware '{m}' not found.") raise ConfigError(f"Middleware '{m}' not found.")
self.middleware = middleware_by_reference self.middleware = middleware_by_reference
def run_middleware_checks(self):
for middleware in self.middleware:
if hasattr(middleware, "checks"):
for check in middleware.checks:
check(server=self).check()
def process_request_middleware(self, request: Request) -> None | bool: def process_request_middleware(self, request: Request) -> None | bool:
for middleware in self.middleware: for middleware in self.middleware:
try: try:

View File

@ -1,25 +1,42 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from spiderweb.exceptions import CSRFError from spiderweb.exceptions import CSRFError, ConfigError
from spiderweb.middleware import SpiderwebMiddleware from spiderweb.middleware import SpiderwebMiddleware
from spiderweb.request import Request from spiderweb.request import Request
from spiderweb.response import HttpResponse from spiderweb.response import HttpResponse
from spiderweb.server_checks import ServerCheck
class SessionCheck(ServerCheck):
SESSION_MIDDLEWARE_NOT_FOUND = (
"Session middleware is not enabled. It must be listed above"
"CSRFMiddleware in the middleware list."
)
SESSION_MIDDLEWARE_BELOW_CSRF = (
"SessionMiddleware is enabled, but it must be listed above"
"CSRFMiddleware in the middleware list."
)
def check(self):
if (
"spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware
):
raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
if self.server._middleware.index(
"spiderweb.middleware.sessions.SessionMiddleware"
) > self.server._middleware.index(
"spiderweb.middleware.csrf.CSRFMiddleware"
):
raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
class CSRFMiddleware(SpiderwebMiddleware): class CSRFMiddleware(SpiderwebMiddleware):
"""
tl;dr: this is a naive implementation going off just what I could think of
at the time. It is very vulnerable to CSRF Forgery and should be updated.
Eventually I'll probably just pull everything out of Django and use their checks = [SessionCheck]
implementation, as it's written by people who know a lot more about these
things than I do, but in the meantime, this is still here until I get
around to making it more solid.
todo: fix
https://en.wikipedia.org/wiki/Cross-site_request_forgery
"""
CSRF_EXPIRY = 60 * 60 # 1 hour CSRF_EXPIRY = 60 * 60 # 1 hour
@ -33,14 +50,14 @@ class CSRFMiddleware(SpiderwebMiddleware):
or request.GET.get("csrf_token") or request.GET.get("csrf_token")
or request.POST.get("csrf_token") or request.POST.get("csrf_token")
) )
if self.is_csrf_valid(csrf_token): if self.is_csrf_valid(request, csrf_token):
return None return None
else: else:
raise CSRFError() raise CSRFError()
return None return None
def process_response(self, request: Request, response: HttpResponse) -> None: def process_response(self, request: Request, response: HttpResponse) -> None:
token = self.get_csrf_token() token = self.get_csrf_token(request)
# do we need it in both places? # do we need it in both places?
response.headers["X-CSRF-TOKEN"] = token response.headers["X-CSRF-TOKEN"] = token
response.context |= { response.context |= {
@ -48,17 +65,22 @@ class CSRFMiddleware(SpiderwebMiddleware):
"raw_csrf_token": token, # in case they want to format it themselves "raw_csrf_token": token, # in case they want to format it themselves
} }
def get_csrf_token(self): def get_csrf_token(self, request):
return self.server.encrypt(str(datetime.now().isoformat())).decode( # the session key should be here because we've processed the session first
self.server.DEFAULT_ENCODING session_key = request._session["id"]
) return self.server.encrypt(
f"{str(datetime.now().isoformat())}::{session_key}"
).decode(self.server.DEFAULT_ENCODING)
def is_csrf_valid(self, key): def is_csrf_valid(self, request, key):
try: try:
decoded = self.server.decrypt(key) decoded = self.server.decrypt(key)
timestamp, session_key = decoded.split("::")
if session_key != request._session["id"]:
return False
if datetime.now() - timedelta( if datetime.now() - timedelta(
seconds=self.CSRF_EXPIRY seconds=self.CSRF_EXPIRY
) > datetime.fromisoformat(decoded): ) > datetime.fromisoformat(timestamp):
return False return False
return True return True
except Exception: except Exception:

View File

@ -30,7 +30,7 @@ class RoutesMixin:
# ones that start with underscores are the compiled versions, non-underscores # ones that start with underscores are the compiled versions, non-underscores
# are the user-supplied versions # are the user-supplied versions
_routes: dict _routes: dict
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None, routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,)
_error_routes: dict _error_routes: dict
error_routes: dict[int, Callable] error_routes: dict[int, Callable]
append_slash: bool append_slash: bool

View File

@ -0,0 +1,18 @@
class ServerCheck:
"""
Single server check base class.
During startup, each middleware can request checks to be run against the
current state of the server. These are usually used to verify that things
are configured correctly, but can also be used for database setup or other
similar things.
To fail a check, raise any error that makes sense to raise. This will halt
startup so the error can be fixed.
"""
def __init__(self, server):
self.server = server
def check(self):
pass

View File

@ -1 +1,4 @@
from spiderweb.tests.middleware import ExplodingResponseMiddleware, ExplodingRequestMiddleware from spiderweb.tests.middleware import (
ExplodingResponseMiddleware,
ExplodingRequestMiddleware,
)

View File

@ -7,5 +7,7 @@ class ExplodingRequestMiddleware(SpiderwebMiddleware):
class ExplodingResponseMiddleware(SpiderwebMiddleware): class ExplodingResponseMiddleware(SpiderwebMiddleware):
def process_response(self, request: Request, response: HttpResponse) -> HttpResponse | None: def process_response(
self, request: Request, response: HttpResponse
) -> HttpResponse | None:
raise UnusedMiddleware("Unfinished!") raise UnusedMiddleware("Unfinished!")

View File

@ -1,11 +1,16 @@
from io import BytesIO, BufferedReader
from datetime import timedelta from datetime import timedelta
import pytest
from peewee import SqliteDatabase from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse from spiderweb import SpiderwebRouter, HttpResponse, ConfigError
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.sessions import Session from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf
from spiderweb.tests.helpers import setup from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf
# app = SpiderwebRouter( # app = SpiderwebRouter(
# middleware=[ # middleware=[
@ -18,19 +23,20 @@ from spiderweb.tests.helpers import setup
# ], # ],
# ) # )
def index(request): def index(request):
if "value" in request.SESSION: if "value" in request.SESSION:
request.SESSION['value'] += 1 request.SESSION["value"] += 1
else: else:
request.SESSION['value'] = 0 request.SESSION["value"] = 0
return HttpResponse(body=str(request.SESSION['value'])) return HttpResponse(body=str(request.SESSION["value"]))
def test_session_middleware(): def test_session_middleware():
_, environ, start_response = setup() _, environ, start_response = setup()
app = SpiderwebRouter( app = SpiderwebRouter(
middleware=["spiderweb.middleware.sessions.SessionMiddleware"], middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
db=SqliteDatabase("spiderweb-tests.db") db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", index) app.add_route("/", index)
@ -46,11 +52,12 @@ def test_session_middleware():
assert app(environ, start_response) == [bytes(str(1), DEFAULT_ENCODING)] assert app(environ, start_response) == [bytes(str(1), DEFAULT_ENCODING)]
assert app(environ, start_response) == [bytes(str(2), DEFAULT_ENCODING)] assert app(environ, start_response) == [bytes(str(2), DEFAULT_ENCODING)]
def test_expired_session(): def test_expired_session():
_, environ, start_response = setup() _, environ, start_response = setup()
app = SpiderwebRouter( app = SpiderwebRouter(
middleware=["spiderweb.middleware.sessions.SessionMiddleware"], middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
db=SqliteDatabase("spiderweb-tests.db") db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", index) app.add_route("/", index)
@ -80,9 +87,170 @@ def test_exploding_middleware():
"spiderweb.tests.middleware.ExplodingRequestMiddleware", "spiderweb.tests.middleware.ExplodingRequestMiddleware",
"spiderweb.tests.middleware.ExplodingResponseMiddleware", "spiderweb.tests.middleware.ExplodingResponseMiddleware",
], ],
db=SqliteDatabase("spiderweb-tests.db") db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", index) app.add_route("/", index)
assert app(environ, start_response) == [bytes(str(0), DEFAULT_ENCODING)] assert app(environ, start_response) == [bytes(str(0), DEFAULT_ENCODING)]
# make sure it kicked out the middleware and isn't just ignoring it
assert len(app.middleware) == 0
def test_csrf_middleware_without_session_middleware():
_, environ, start_response = setup()
with pytest.raises(ConfigError) as e:
SpiderwebRouter(
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
db=SqliteDatabase("spiderweb-tests.db"),
)
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND
def test_csrf_middleware_above_session_middleware():
_, environ, start_response = setup()
with pytest.raises(ConfigError) as e:
SpiderwebRouter(
middleware=[
"spiderweb.middleware.csrf.CSRFMiddleware",
"spiderweb.middleware.sessions.SessionMiddleware",
],
db=SqliteDatabase("spiderweb-tests.db"),
)
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF
def test_csrf_middleware():
_, environ, start_response = setup()
app = SpiderwebRouter(
middleware=[
"spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware",
],
db=SqliteDatabase("spiderweb-tests.db"),
)
app.add_route("/", form_view_with_csrf, ["GET", "POST"])
environ["HTTP_USER_AGENT"] = "hi"
environ["REMOTE_ADDR"] = "1.1.1.1"
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "<form" in resp
assert '<input type="hidden" name="csrf_token"' in resp
token = resp.split('value="')[1].split('"')[0]
formdata = f"name=bob&csrf_token={token}"
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
environ["HTTP_COOKIE"] = (
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
)
environ["REQUEST_METHOD"] = "POST"
environ["HTTP_X_CSRF_TOKEN"] = token
environ["CONTENT_LENGTH"] = len(formdata)
# setup form data
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "bob" in resp2
# test that it raises a CSRF error on wrong token
formdata = f"name=bob&csrf_token=badtoken"
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
resp3 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp3
# test that the wrong session also raises a CSRF error
token = app.decrypt(token).split("::")[0]
token = app.encrypt(f"{token}::badsession").decode(DEFAULT_ENCODING)
formdata = f"name=bob&csrf_token={token}"
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
resp4 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp4
def test_csrf_expired_token():
_, environ, start_response = setup()
app = SpiderwebRouter(
middleware=[
"spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware",
],
db=SqliteDatabase("spiderweb-tests.db"),
)
app.middleware[1].CSRF_EXPIRY = -1
app.add_route("/", form_view_with_csrf, ["GET", "POST"])
environ["HTTP_USER_AGENT"] = "hi"
environ["REMOTE_ADDR"] = "1.1.1.1"
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
token = resp.split('value="')[1].split('"')[0]
formdata = f"name=bob&csrf_token={token}"
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
environ["HTTP_COOKIE"] = (
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
)
environ["REQUEST_METHOD"] = "POST"
environ["HTTP_X_CSRF_TOKEN"] = token
environ["CONTENT_LENGTH"] = len(formdata)
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp
def test_csrf_exempt():
_, environ, start_response = setup()
app = SpiderwebRouter(
middleware=[
"spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware",
],
db=SqliteDatabase("spiderweb-tests.db"),
)
app.add_route("/", form_csrf_exempt, ["GET", "POST"])
app.add_route("/2", form_view_without_csrf, ["GET", "POST"])
environ["HTTP_USER_AGENT"] = "hi"
environ["REMOTE_ADDR"] = "1.1.1.1"
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
environ["REQUEST_METHOD"] = "POST"
formdata = "name=bob"
environ["CONTENT_LENGTH"] = len(formdata)
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "bob" in resp
environ["PATH_INFO"] = "/2"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp2

View File

@ -3,7 +3,12 @@ import pytest
from spiderweb import SpiderwebRouter, ConfigError from spiderweb import SpiderwebRouter, ConfigError
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.exceptions import NoResponseError, SpiderwebNetworkException from spiderweb.exceptions import NoResponseError, SpiderwebNetworkException
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse from spiderweb.response import (
HttpResponse,
JsonResponse,
TemplateResponse,
RedirectResponse,
)
from hypothesis import given, strategies as st from hypothesis import given, strategies as st
from spiderweb.tests.helpers import setup from spiderweb.tests.helpers import setup
@ -27,7 +32,9 @@ def test_json_response():
def index(request): def index(request):
return JsonResponse(data={"message": "text"}) return JsonResponse(data={"message": "text"})
assert app(environ, start_response) == [bytes('{"message": "text"}', DEFAULT_ENCODING)] assert app(environ, start_response) == [
bytes('{"message": "text"}', DEFAULT_ENCODING)
]
def test_dict_response(): def test_dict_response():
@ -51,7 +58,9 @@ def test_template_response(text):
request, template_string=template, context={"message": text} request, template_string=template, context={"message": text}
) )
assert app(environ, start_response) == [b"MESSAGE: " + bytes(text, DEFAULT_ENCODING)] assert app(environ, start_response) == [
b"MESSAGE: " + bytes(text, DEFAULT_ENCODING)
]
def test_redirect_response(): def test_redirect_response():
@ -61,7 +70,7 @@ def test_redirect_response():
def index(request): def index(request):
return RedirectResponse(location="/redirected") return RedirectResponse(location="/redirected")
assert app(environ, start_response) == [b'None'] assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/redirected" assert start_response.get_headers()["Location"] == "/redirected"
@ -74,12 +83,14 @@ def test_add_route_at_server_start():
def view2(request): def view2(request):
return HttpResponse("View 2") return HttpResponse("View 2")
app = SpiderwebRouter(routes=[ app = SpiderwebRouter(
routes=[
("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}), ("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}),
("/view2", view2), ("/view2", view2),
]) ]
)
assert app(environ, start_response) == [b'None'] assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/redirected" assert start_response.get_headers()["Location"] == "/redirected"
@ -92,7 +103,7 @@ def test_redirect_on_append_slash():
pass pass
environ["PATH_INFO"] = f"/hello" environ["PATH_INFO"] = f"/hello"
assert app(environ, start_response) == [b'None'] assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/hello/" assert start_response.get_headers()["Location"] == "/hello/"
@ -104,11 +115,11 @@ def test_template_response_with_template(text):
@app.route("/") @app.route("/")
def index(request): def index(request):
return TemplateResponse( return TemplateResponse(request, "test.html", context={"message": text})
request, "test.html", context={"message": text}
)
assert app(environ, start_response) == [b"TEMPLATE! " + bytes(text, DEFAULT_ENCODING)] assert app(environ, start_response) == [
b"TEMPLATE! " + bytes(text, DEFAULT_ENCODING)
]
def test_view_returns_none(): def test_view_returns_none():
@ -119,7 +130,7 @@ def test_view_returns_none():
pass pass
with pytest.raises(NoResponseError): with pytest.raises(NoResponseError):
assert app(environ, start_response) == [b'None'] assert app(environ, start_response) == [b"None"]
def test_exploding_view(): def test_exploding_view():
@ -130,9 +141,10 @@ def test_exploding_view():
raise SpiderwebNetworkException("Boom!") raise SpiderwebNetworkException("Boom!")
assert app(environ, start_response) == [ assert app(environ, start_response) == [
b'Something went wrong.\n\nCode: Boom!\n\nMsg: None\n\nDesc: None' b"Something went wrong.\n\nCode: Boom!\n\nMsg: None\n\nDesc: None"
] ]
def test_missing_view(): def test_missing_view():
app, environ, start_response = setup() app, environ, start_response = setup()
@ -146,20 +158,19 @@ def test_missing_view_with_custom_404():
def custom_404(request): def custom_404(request):
return HttpResponse("Custom 404") return HttpResponse("Custom 404")
assert app(environ, start_response) == [b'Custom 404'] assert app(environ, start_response) == [b"Custom 404"]
def test_duplicate_error_view(): def test_duplicate_error_view():
app, environ, start_response = setup() app, environ, start_response = setup()
@app.error(404) @app.error(404)
def custom_404(request): def custom_404(request): ...
...
with pytest.raises(ConfigError): with pytest.raises(ConfigError):
@app.error(404) @app.error(404)
def custom_404(request): def custom_404(request): ...
...
def test_missing_view_with_custom_404_alt(): def test_missing_view_with_custom_404_alt():
@ -170,4 +181,4 @@ def test_missing_view_with_custom_404_alt():
app = SpiderwebRouter(error_routes={404: custom_404}) app = SpiderwebRouter(error_routes={404: custom_404})
assert app(environ, start_response) == [b'Custom 404 2'] assert app(environ, start_response) == [b"Custom 404 2"]

View File

@ -3,7 +3,12 @@ import pytest
from spiderweb import SpiderwebRouter from spiderweb import SpiderwebRouter
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.exceptions import ParseError, ConfigError from spiderweb.exceptions import ParseError, ConfigError
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse from spiderweb.response import (
HttpResponse,
JsonResponse,
TemplateResponse,
RedirectResponse,
)
from hypothesis import given, strategies as st, assume from hypothesis import given, strategies as st, assume
from peewee import SqliteDatabase from peewee import SqliteDatabase
@ -43,6 +48,7 @@ def test_unknown_converter():
app, environ, start_response = setup() app, environ, start_response = setup()
with pytest.raises(ParseError): with pytest.raises(ParseError):
@app.route("/<asdf:test_input>") @app.route("/<asdf:test_input>")
def index(request, test_input: str): def index(request, test_input: str):
return HttpResponse(test_input) return HttpResponse(test_input)
@ -52,19 +58,19 @@ def test_duplicate_route():
app, environ, start_response = setup() app, environ, start_response = setup()
@app.route("/") @app.route("/")
def index(request): def index(request): ...
...
with pytest.raises(ConfigError): with pytest.raises(ConfigError):
@app.route("/") @app.route("/")
def index(request): def index(request): ...
...
def test_url_with_double_underscore(): def test_url_with_double_underscore():
app, environ, start_response = setup() app, environ, start_response = setup()
with pytest.raises(ConfigError): with pytest.raises(ConfigError):
@app.route("/<asdf:test__input>") @app.route("/<asdf:test__input>")
def index(request, test_input: str): def index(request, test_input: str):
return HttpResponse(test_input) return HttpResponse(test_input)

View File

@ -0,0 +1,40 @@
from spiderweb.decorators import csrf_exempt
from spiderweb.response import JsonResponse, TemplateResponse
EXAMPLE_HTML_FORM = """
<form action="" method="post">
<input type="text" name="name" />
<input type="submit" />
</form>
"""
EXAMPLE_HTML_FORM_WITH_CSRF = """
<form action="" method="post">
<input type="text" name="name" />
<input type="submit" />
{{ csrf_token }}
</form>
"""
def form_view_without_csrf(request):
if request.method == "POST":
return JsonResponse(data=request.POST)
else:
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM)
@csrf_exempt
def form_csrf_exempt(request):
if request.method == "POST":
return JsonResponse(data=request.POST)
else:
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF)
def form_view_with_csrf(request):
if request.method == "POST":
return JsonResponse(data=request.POST)
else:
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF)