✨ add server checks & fix csrf middleware
This commit is contained in:
parent
aabe20cff7
commit
c451aff1e2
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "spiderweb-framework"
|
||||
version = "0.10.0"
|
||||
version = "0.11.0"
|
||||
description = "A small web framework, just big enough for a spider."
|
||||
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
|
||||
readme = "README.md"
|
||||
|
@ -2,7 +2,7 @@ from peewee import DatabaseProxy
|
||||
|
||||
DEFAULT_ALLOWED_METHODS = ["GET"]
|
||||
DEFAULT_ENCODING = "UTF-8"
|
||||
__version__ = "0.10.0"
|
||||
__version__ = "0.11.0"
|
||||
|
||||
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
|
||||
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
|
||||
|
@ -5,7 +5,7 @@ class SpiderwebException(Exception):
|
||||
msg = self.args[0] if len(self.args) > 0 else ""
|
||||
if msg:
|
||||
return f"{name}() - {msg}"
|
||||
return f"{self.__class__.__name__}()"
|
||||
return f"{name}()"
|
||||
|
||||
|
||||
class SpiderwebNetworkException(SpiderwebException):
|
||||
|
@ -32,7 +32,7 @@ from spiderweb.routes import RoutesMixin
|
||||
from spiderweb.secrets import FernetMixin
|
||||
from spiderweb.utils import get_http_status_by_code
|
||||
|
||||
file_logger = logging.getLogger(__name__)
|
||||
console_logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
session_cookie_same_site="lax",
|
||||
session_cookie_path="/",
|
||||
log=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self._routes = {}
|
||||
self.routes = routes
|
||||
@ -69,7 +69,8 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
self.append_slash = append_slash
|
||||
self.templates_dirs = templates_dirs
|
||||
self.staticfiles_dirs = staticfiles_dirs
|
||||
self.middleware = middleware if middleware else []
|
||||
self._middleware: list[str] = middleware if middleware else []
|
||||
self.middleware: list[Callable] = []
|
||||
self.secret_key = secret_key if secret_key else self.generate_key()
|
||||
|
||||
self.extra_data = kwargs
|
||||
@ -84,7 +85,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
|
||||
self.DEFAULT_ENCODING = DEFAULT_ENCODING
|
||||
self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS
|
||||
self.log: logging.Logger = log if log else file_logger
|
||||
self.log: logging.Logger = log if log else console_logger
|
||||
|
||||
# for using .start() and .stop()
|
||||
self._thread: Optional[Thread] = None
|
||||
@ -108,7 +109,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
self.add_error_routes()
|
||||
|
||||
if self.templates_dirs:
|
||||
self.template_loader = Environment(loader=FileSystemLoader(self.templates_dirs))
|
||||
self.template_loader = Environment(
|
||||
loader=FileSystemLoader(self.templates_dirs)
|
||||
)
|
||||
else:
|
||||
self.template_loader = None
|
||||
self.string_loader = Environment(loader=BaseLoader())
|
||||
@ -117,12 +120,16 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
for static_dir in self.staticfiles_dirs:
|
||||
static_dir = pathlib.Path(static_dir)
|
||||
if not pathlib.Path(self.BASE_DIR / static_dir).exists():
|
||||
log.error(
|
||||
self.log.error(
|
||||
f"Static files directory '{str(static_dir)}' does not exist."
|
||||
)
|
||||
raise ConfigError
|
||||
self.add_route(r"/static/<str:filename>", send_file) # noqa: F405
|
||||
|
||||
# finally, run the startup checks to verify everything is correct and happy.
|
||||
self.log.info("Run startup checks...")
|
||||
self.run_middleware_checks()
|
||||
|
||||
def fire_response(self, start_response, request: Request, resp: HttpResponse):
|
||||
try:
|
||||
status = get_http_status_by_code(resp.status_code)
|
||||
@ -190,7 +197,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
def prepare_and_fire_response(self, start_response, request, resp) -> list[bytes]:
|
||||
try:
|
||||
if isinstance(resp, dict):
|
||||
return self.fire_response(start_response, request, JsonResponse(data=resp))
|
||||
return self.fire_response(
|
||||
start_response, request, JsonResponse(data=resp)
|
||||
)
|
||||
if isinstance(resp, TemplateResponse):
|
||||
resp.set_template_loader(self.template_loader)
|
||||
resp.set_string_loader(self.string_loader)
|
||||
|
@ -12,19 +12,26 @@ from ..utils import import_by_string
|
||||
class MiddlewareMixin:
|
||||
"""Cannot be called on its own. Requires context of SpiderwebRouter."""
|
||||
|
||||
_middleware: list[str]
|
||||
middleware: list[ClassVar]
|
||||
fire_response: Callable
|
||||
|
||||
def init_middleware(self):
|
||||
if self.middleware:
|
||||
if self._middleware:
|
||||
middleware_by_reference = []
|
||||
for m in self.middleware:
|
||||
for m in self._middleware:
|
||||
try:
|
||||
middleware_by_reference.append(import_by_string(m)(server=self))
|
||||
except ImportError:
|
||||
raise ConfigError(f"Middleware '{m}' not found.")
|
||||
self.middleware = middleware_by_reference
|
||||
|
||||
def run_middleware_checks(self):
|
||||
for middleware in self.middleware:
|
||||
if hasattr(middleware, "checks"):
|
||||
for check in middleware.checks:
|
||||
check(server=self).check()
|
||||
|
||||
def process_request_middleware(self, request: Request) -> None | bool:
|
||||
for middleware in self.middleware:
|
||||
try:
|
||||
|
@ -1,25 +1,42 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from spiderweb.exceptions import CSRFError
|
||||
from spiderweb.exceptions import CSRFError, ConfigError
|
||||
from spiderweb.middleware import SpiderwebMiddleware
|
||||
from spiderweb.request import Request
|
||||
from spiderweb.response import HttpResponse
|
||||
from spiderweb.server_checks import ServerCheck
|
||||
|
||||
|
||||
class SessionCheck(ServerCheck):
|
||||
|
||||
SESSION_MIDDLEWARE_NOT_FOUND = (
|
||||
"Session middleware is not enabled. It must be listed above"
|
||||
"CSRFMiddleware in the middleware list."
|
||||
)
|
||||
SESSION_MIDDLEWARE_BELOW_CSRF = (
|
||||
"SessionMiddleware is enabled, but it must be listed above"
|
||||
"CSRFMiddleware in the middleware list."
|
||||
)
|
||||
|
||||
def check(self):
|
||||
|
||||
if (
|
||||
"spiderweb.middleware.sessions.SessionMiddleware"
|
||||
not in self.server._middleware
|
||||
):
|
||||
raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
|
||||
|
||||
if self.server._middleware.index(
|
||||
"spiderweb.middleware.sessions.SessionMiddleware"
|
||||
) > self.server._middleware.index(
|
||||
"spiderweb.middleware.csrf.CSRFMiddleware"
|
||||
):
|
||||
raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
|
||||
|
||||
|
||||
class CSRFMiddleware(SpiderwebMiddleware):
|
||||
"""
|
||||
tl;dr: this is a naive implementation going off just what I could think of
|
||||
at the time. It is very vulnerable to CSRF Forgery and should be updated.
|
||||
|
||||
Eventually I'll probably just pull everything out of Django and use their
|
||||
implementation, as it's written by people who know a lot more about these
|
||||
things than I do, but in the meantime, this is still here until I get
|
||||
around to making it more solid.
|
||||
|
||||
todo: fix
|
||||
|
||||
https://en.wikipedia.org/wiki/Cross-site_request_forgery
|
||||
"""
|
||||
checks = [SessionCheck]
|
||||
|
||||
CSRF_EXPIRY = 60 * 60 # 1 hour
|
||||
|
||||
@ -33,14 +50,14 @@ class CSRFMiddleware(SpiderwebMiddleware):
|
||||
or request.GET.get("csrf_token")
|
||||
or request.POST.get("csrf_token")
|
||||
)
|
||||
if self.is_csrf_valid(csrf_token):
|
||||
if self.is_csrf_valid(request, csrf_token):
|
||||
return None
|
||||
else:
|
||||
raise CSRFError()
|
||||
return None
|
||||
|
||||
def process_response(self, request: Request, response: HttpResponse) -> None:
|
||||
token = self.get_csrf_token()
|
||||
token = self.get_csrf_token(request)
|
||||
# do we need it in both places?
|
||||
response.headers["X-CSRF-TOKEN"] = token
|
||||
response.context |= {
|
||||
@ -48,17 +65,22 @@ class CSRFMiddleware(SpiderwebMiddleware):
|
||||
"raw_csrf_token": token, # in case they want to format it themselves
|
||||
}
|
||||
|
||||
def get_csrf_token(self):
|
||||
return self.server.encrypt(str(datetime.now().isoformat())).decode(
|
||||
self.server.DEFAULT_ENCODING
|
||||
)
|
||||
def get_csrf_token(self, request):
|
||||
# the session key should be here because we've processed the session first
|
||||
session_key = request._session["id"]
|
||||
return self.server.encrypt(
|
||||
f"{str(datetime.now().isoformat())}::{session_key}"
|
||||
).decode(self.server.DEFAULT_ENCODING)
|
||||
|
||||
def is_csrf_valid(self, key):
|
||||
def is_csrf_valid(self, request, key):
|
||||
try:
|
||||
decoded = self.server.decrypt(key)
|
||||
timestamp, session_key = decoded.split("::")
|
||||
if session_key != request._session["id"]:
|
||||
return False
|
||||
if datetime.now() - timedelta(
|
||||
seconds=self.CSRF_EXPIRY
|
||||
) > datetime.fromisoformat(decoded):
|
||||
) > datetime.fromisoformat(timestamp):
|
||||
return False
|
||||
return True
|
||||
except Exception:
|
||||
|
@ -30,7 +30,7 @@ class RoutesMixin:
|
||||
# ones that start with underscores are the compiled versions, non-underscores
|
||||
# are the user-supplied versions
|
||||
_routes: dict
|
||||
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
|
||||
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,)
|
||||
_error_routes: dict
|
||||
error_routes: dict[int, Callable]
|
||||
append_slash: bool
|
||||
|
18
spiderweb/server_checks.py
Normal file
18
spiderweb/server_checks.py
Normal file
@ -0,0 +1,18 @@
|
||||
class ServerCheck:
|
||||
"""
|
||||
Single server check base class.
|
||||
|
||||
During startup, each middleware can request checks to be run against the
|
||||
current state of the server. These are usually used to verify that things
|
||||
are configured correctly, but can also be used for database setup or other
|
||||
similar things.
|
||||
|
||||
To fail a check, raise any error that makes sense to raise. This will halt
|
||||
startup so the error can be fixed.
|
||||
"""
|
||||
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
|
||||
def check(self):
|
||||
pass
|
@ -1 +1,4 @@
|
||||
from spiderweb.tests.middleware import ExplodingResponseMiddleware, ExplodingRequestMiddleware
|
||||
from spiderweb.tests.middleware import (
|
||||
ExplodingResponseMiddleware,
|
||||
ExplodingRequestMiddleware,
|
||||
)
|
||||
|
@ -7,5 +7,7 @@ class ExplodingRequestMiddleware(SpiderwebMiddleware):
|
||||
|
||||
|
||||
class ExplodingResponseMiddleware(SpiderwebMiddleware):
|
||||
def process_response(self, request: Request, response: HttpResponse) -> HttpResponse | None:
|
||||
def process_response(
|
||||
self, request: Request, response: HttpResponse
|
||||
) -> HttpResponse | None:
|
||||
raise UnusedMiddleware("Unfinished!")
|
@ -1,11 +1,16 @@
|
||||
from io import BytesIO, BufferedReader
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
from peewee import SqliteDatabase
|
||||
|
||||
from spiderweb import SpiderwebRouter, HttpResponse
|
||||
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError
|
||||
from spiderweb.constants import DEFAULT_ENCODING
|
||||
from spiderweb.middleware.sessions import Session
|
||||
from spiderweb.middleware import csrf
|
||||
from spiderweb.tests.helpers import setup
|
||||
from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf
|
||||
|
||||
|
||||
# app = SpiderwebRouter(
|
||||
# middleware=[
|
||||
@ -18,19 +23,20 @@ from spiderweb.tests.helpers import setup
|
||||
# ],
|
||||
# )
|
||||
|
||||
|
||||
def index(request):
|
||||
if "value" in request.SESSION:
|
||||
request.SESSION['value'] += 1
|
||||
request.SESSION["value"] += 1
|
||||
else:
|
||||
request.SESSION['value'] = 0
|
||||
return HttpResponse(body=str(request.SESSION['value']))
|
||||
request.SESSION["value"] = 0
|
||||
return HttpResponse(body=str(request.SESSION["value"]))
|
||||
|
||||
|
||||
def test_session_middleware():
|
||||
_, environ, start_response = setup()
|
||||
app = SpiderwebRouter(
|
||||
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
|
||||
db=SqliteDatabase("spiderweb-tests.db")
|
||||
db=SqliteDatabase("spiderweb-tests.db"),
|
||||
)
|
||||
|
||||
app.add_route("/", index)
|
||||
@ -46,11 +52,12 @@ def test_session_middleware():
|
||||
assert app(environ, start_response) == [bytes(str(1), DEFAULT_ENCODING)]
|
||||
assert app(environ, start_response) == [bytes(str(2), DEFAULT_ENCODING)]
|
||||
|
||||
|
||||
def test_expired_session():
|
||||
_, environ, start_response = setup()
|
||||
app = SpiderwebRouter(
|
||||
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
|
||||
db=SqliteDatabase("spiderweb-tests.db")
|
||||
db=SqliteDatabase("spiderweb-tests.db"),
|
||||
)
|
||||
|
||||
app.add_route("/", index)
|
||||
@ -80,9 +87,170 @@ def test_exploding_middleware():
|
||||
"spiderweb.tests.middleware.ExplodingRequestMiddleware",
|
||||
"spiderweb.tests.middleware.ExplodingResponseMiddleware",
|
||||
],
|
||||
db=SqliteDatabase("spiderweb-tests.db")
|
||||
db=SqliteDatabase("spiderweb-tests.db"),
|
||||
)
|
||||
|
||||
app.add_route("/", index)
|
||||
|
||||
assert app(environ, start_response) == [bytes(str(0), DEFAULT_ENCODING)]
|
||||
# make sure it kicked out the middleware and isn't just ignoring it
|
||||
assert len(app.middleware) == 0
|
||||
|
||||
|
||||
def test_csrf_middleware_without_session_middleware():
|
||||
_, environ, start_response = setup()
|
||||
with pytest.raises(ConfigError) as e:
|
||||
SpiderwebRouter(
|
||||
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
|
||||
db=SqliteDatabase("spiderweb-tests.db"),
|
||||
)
|
||||
|
||||
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND
|
||||
|
||||
|
||||
def test_csrf_middleware_above_session_middleware():
|
||||
_, environ, start_response = setup()
|
||||
with pytest.raises(ConfigError) as e:
|
||||
SpiderwebRouter(
|
||||
middleware=[
|
||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||
],
|
||||
db=SqliteDatabase("spiderweb-tests.db"),
|
||||
)
|
||||
|
||||
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF
|
||||
|
||||
|
||||
def test_csrf_middleware():
|
||||
_, environ, start_response = setup()
|
||||
app = SpiderwebRouter(
|
||||
middleware=[
|
||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||
],
|
||||
db=SqliteDatabase("spiderweb-tests.db"),
|
||||
)
|
||||
|
||||
app.add_route("/", form_view_with_csrf, ["GET", "POST"])
|
||||
|
||||
environ["HTTP_USER_AGENT"] = "hi"
|
||||
environ["REMOTE_ADDR"] = "1.1.1.1"
|
||||
|
||||
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||
|
||||
assert "<form" in resp
|
||||
assert '<input type="hidden" name="csrf_token"' in resp
|
||||
|
||||
token = resp.split('value="')[1].split('"')[0]
|
||||
|
||||
formdata = f"name=bob&csrf_token={token}"
|
||||
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
|
||||
environ["HTTP_COOKIE"] = (
|
||||
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
|
||||
)
|
||||
environ["REQUEST_METHOD"] = "POST"
|
||||
environ["HTTP_X_CSRF_TOKEN"] = token
|
||||
environ["CONTENT_LENGTH"] = len(formdata)
|
||||
|
||||
# setup form data
|
||||
b_handle = BytesIO()
|
||||
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||
b_handle.seek(0)
|
||||
|
||||
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||
|
||||
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||
|
||||
assert "bob" in resp2
|
||||
|
||||
# test that it raises a CSRF error on wrong token
|
||||
formdata = f"name=bob&csrf_token=badtoken"
|
||||
b_handle = BytesIO()
|
||||
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||
b_handle.seek(0)
|
||||
|
||||
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||
resp3 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||
assert "CSRF token is invalid" in resp3
|
||||
|
||||
# test that the wrong session also raises a CSRF error
|
||||
token = app.decrypt(token).split("::")[0]
|
||||
token = app.encrypt(f"{token}::badsession").decode(DEFAULT_ENCODING)
|
||||
formdata = f"name=bob&csrf_token={token}"
|
||||
b_handle = BytesIO()
|
||||
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||
b_handle.seek(0)
|
||||
|
||||
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||
resp4 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||
assert "CSRF token is invalid" in resp4
|
||||
|
||||
|
||||
def test_csrf_expired_token():
|
||||
_, environ, start_response = setup()
|
||||
app = SpiderwebRouter(
|
||||
middleware=[
|
||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||
],
|
||||
db=SqliteDatabase("spiderweb-tests.db"),
|
||||
)
|
||||
app.middleware[1].CSRF_EXPIRY = -1
|
||||
|
||||
app.add_route("/", form_view_with_csrf, ["GET", "POST"])
|
||||
|
||||
environ["HTTP_USER_AGENT"] = "hi"
|
||||
environ["REMOTE_ADDR"] = "1.1.1.1"
|
||||
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||
token = resp.split('value="')[1].split('"')[0]
|
||||
|
||||
formdata = f"name=bob&csrf_token={token}"
|
||||
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
|
||||
environ["HTTP_COOKIE"] = (
|
||||
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
|
||||
)
|
||||
environ["REQUEST_METHOD"] = "POST"
|
||||
environ["HTTP_X_CSRF_TOKEN"] = token
|
||||
environ["CONTENT_LENGTH"] = len(formdata)
|
||||
|
||||
b_handle = BytesIO()
|
||||
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||
b_handle.seek(0)
|
||||
|
||||
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||
assert "CSRF token is invalid" in resp
|
||||
|
||||
|
||||
def test_csrf_exempt():
|
||||
_, environ, start_response = setup()
|
||||
app = SpiderwebRouter(
|
||||
middleware=[
|
||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||
],
|
||||
db=SqliteDatabase("spiderweb-tests.db"),
|
||||
)
|
||||
|
||||
app.add_route("/", form_csrf_exempt, ["GET", "POST"])
|
||||
app.add_route("/2", form_view_without_csrf, ["GET", "POST"])
|
||||
|
||||
environ["HTTP_USER_AGENT"] = "hi"
|
||||
environ["REMOTE_ADDR"] = "1.1.1.1"
|
||||
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
|
||||
environ["REQUEST_METHOD"] = "POST"
|
||||
|
||||
formdata = "name=bob"
|
||||
environ["CONTENT_LENGTH"] = len(formdata)
|
||||
b_handle = BytesIO()
|
||||
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||
b_handle.seek(0)
|
||||
|
||||
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||
assert "bob" in resp
|
||||
|
||||
environ["PATH_INFO"] = "/2"
|
||||
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||
assert "CSRF token is invalid" in resp2
|
||||
|
@ -3,7 +3,12 @@ import pytest
|
||||
from spiderweb import SpiderwebRouter, ConfigError
|
||||
from spiderweb.constants import DEFAULT_ENCODING
|
||||
from spiderweb.exceptions import NoResponseError, SpiderwebNetworkException
|
||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
||||
from spiderweb.response import (
|
||||
HttpResponse,
|
||||
JsonResponse,
|
||||
TemplateResponse,
|
||||
RedirectResponse,
|
||||
)
|
||||
from hypothesis import given, strategies as st
|
||||
|
||||
from spiderweb.tests.helpers import setup
|
||||
@ -27,7 +32,9 @@ def test_json_response():
|
||||
def index(request):
|
||||
return JsonResponse(data={"message": "text"})
|
||||
|
||||
assert app(environ, start_response) == [bytes('{"message": "text"}', DEFAULT_ENCODING)]
|
||||
assert app(environ, start_response) == [
|
||||
bytes('{"message": "text"}', DEFAULT_ENCODING)
|
||||
]
|
||||
|
||||
|
||||
def test_dict_response():
|
||||
@ -51,7 +58,9 @@ def test_template_response(text):
|
||||
request, template_string=template, context={"message": text}
|
||||
)
|
||||
|
||||
assert app(environ, start_response) == [b"MESSAGE: " + bytes(text, DEFAULT_ENCODING)]
|
||||
assert app(environ, start_response) == [
|
||||
b"MESSAGE: " + bytes(text, DEFAULT_ENCODING)
|
||||
]
|
||||
|
||||
|
||||
def test_redirect_response():
|
||||
@ -61,7 +70,7 @@ def test_redirect_response():
|
||||
def index(request):
|
||||
return RedirectResponse(location="/redirected")
|
||||
|
||||
assert app(environ, start_response) == [b'None']
|
||||
assert app(environ, start_response) == [b"None"]
|
||||
assert start_response.get_headers()["Location"] == "/redirected"
|
||||
|
||||
|
||||
@ -74,12 +83,14 @@ def test_add_route_at_server_start():
|
||||
def view2(request):
|
||||
return HttpResponse("View 2")
|
||||
|
||||
app = SpiderwebRouter(routes=[
|
||||
app = SpiderwebRouter(
|
||||
routes=[
|
||||
("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}),
|
||||
("/view2", view2),
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
assert app(environ, start_response) == [b'None']
|
||||
assert app(environ, start_response) == [b"None"]
|
||||
assert start_response.get_headers()["Location"] == "/redirected"
|
||||
|
||||
|
||||
@ -92,7 +103,7 @@ def test_redirect_on_append_slash():
|
||||
pass
|
||||
|
||||
environ["PATH_INFO"] = f"/hello"
|
||||
assert app(environ, start_response) == [b'None']
|
||||
assert app(environ, start_response) == [b"None"]
|
||||
assert start_response.get_headers()["Location"] == "/hello/"
|
||||
|
||||
|
||||
@ -104,11 +115,11 @@ def test_template_response_with_template(text):
|
||||
|
||||
@app.route("/")
|
||||
def index(request):
|
||||
return TemplateResponse(
|
||||
request, "test.html", context={"message": text}
|
||||
)
|
||||
return TemplateResponse(request, "test.html", context={"message": text})
|
||||
|
||||
assert app(environ, start_response) == [b"TEMPLATE! " + bytes(text, DEFAULT_ENCODING)]
|
||||
assert app(environ, start_response) == [
|
||||
b"TEMPLATE! " + bytes(text, DEFAULT_ENCODING)
|
||||
]
|
||||
|
||||
|
||||
def test_view_returns_none():
|
||||
@ -119,7 +130,7 @@ def test_view_returns_none():
|
||||
pass
|
||||
|
||||
with pytest.raises(NoResponseError):
|
||||
assert app(environ, start_response) == [b'None']
|
||||
assert app(environ, start_response) == [b"None"]
|
||||
|
||||
|
||||
def test_exploding_view():
|
||||
@ -130,9 +141,10 @@ def test_exploding_view():
|
||||
raise SpiderwebNetworkException("Boom!")
|
||||
|
||||
assert app(environ, start_response) == [
|
||||
b'Something went wrong.\n\nCode: Boom!\n\nMsg: None\n\nDesc: None'
|
||||
b"Something went wrong.\n\nCode: Boom!\n\nMsg: None\n\nDesc: None"
|
||||
]
|
||||
|
||||
|
||||
def test_missing_view():
|
||||
app, environ, start_response = setup()
|
||||
|
||||
@ -146,20 +158,19 @@ def test_missing_view_with_custom_404():
|
||||
def custom_404(request):
|
||||
return HttpResponse("Custom 404")
|
||||
|
||||
assert app(environ, start_response) == [b'Custom 404']
|
||||
assert app(environ, start_response) == [b"Custom 404"]
|
||||
|
||||
|
||||
def test_duplicate_error_view():
|
||||
app, environ, start_response = setup()
|
||||
|
||||
@app.error(404)
|
||||
def custom_404(request):
|
||||
...
|
||||
def custom_404(request): ...
|
||||
|
||||
with pytest.raises(ConfigError):
|
||||
|
||||
@app.error(404)
|
||||
def custom_404(request):
|
||||
...
|
||||
def custom_404(request): ...
|
||||
|
||||
|
||||
def test_missing_view_with_custom_404_alt():
|
||||
@ -170,4 +181,4 @@ def test_missing_view_with_custom_404_alt():
|
||||
|
||||
app = SpiderwebRouter(error_routes={404: custom_404})
|
||||
|
||||
assert app(environ, start_response) == [b'Custom 404 2']
|
||||
assert app(environ, start_response) == [b"Custom 404 2"]
|
||||
|
@ -3,7 +3,12 @@ import pytest
|
||||
from spiderweb import SpiderwebRouter
|
||||
from spiderweb.constants import DEFAULT_ENCODING
|
||||
from spiderweb.exceptions import ParseError, ConfigError
|
||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
||||
from spiderweb.response import (
|
||||
HttpResponse,
|
||||
JsonResponse,
|
||||
TemplateResponse,
|
||||
RedirectResponse,
|
||||
)
|
||||
from hypothesis import given, strategies as st, assume
|
||||
|
||||
from peewee import SqliteDatabase
|
||||
@ -43,6 +48,7 @@ def test_unknown_converter():
|
||||
app, environ, start_response = setup()
|
||||
|
||||
with pytest.raises(ParseError):
|
||||
|
||||
@app.route("/<asdf:test_input>")
|
||||
def index(request, test_input: str):
|
||||
return HttpResponse(test_input)
|
||||
@ -52,19 +58,19 @@ def test_duplicate_route():
|
||||
app, environ, start_response = setup()
|
||||
|
||||
@app.route("/")
|
||||
def index(request):
|
||||
...
|
||||
def index(request): ...
|
||||
|
||||
with pytest.raises(ConfigError):
|
||||
|
||||
@app.route("/")
|
||||
def index(request):
|
||||
...
|
||||
def index(request): ...
|
||||
|
||||
|
||||
def test_url_with_double_underscore():
|
||||
app, environ, start_response = setup()
|
||||
|
||||
with pytest.raises(ConfigError):
|
||||
|
||||
@app.route("/<asdf:test__input>")
|
||||
def index(request, test_input: str):
|
||||
return HttpResponse(test_input)
|
||||
|
40
spiderweb/tests/views_for_tests.py
Normal file
40
spiderweb/tests/views_for_tests.py
Normal file
@ -0,0 +1,40 @@
|
||||
from spiderweb.decorators import csrf_exempt
|
||||
from spiderweb.response import JsonResponse, TemplateResponse
|
||||
|
||||
|
||||
EXAMPLE_HTML_FORM = """
|
||||
<form action="" method="post">
|
||||
<input type="text" name="name" />
|
||||
<input type="submit" />
|
||||
</form>
|
||||
"""
|
||||
|
||||
EXAMPLE_HTML_FORM_WITH_CSRF = """
|
||||
<form action="" method="post">
|
||||
<input type="text" name="name" />
|
||||
<input type="submit" />
|
||||
{{ csrf_token }}
|
||||
</form>
|
||||
"""
|
||||
|
||||
|
||||
def form_view_without_csrf(request):
|
||||
if request.method == "POST":
|
||||
return JsonResponse(data=request.POST)
|
||||
else:
|
||||
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM)
|
||||
|
||||
|
||||
@csrf_exempt
|
||||
def form_csrf_exempt(request):
|
||||
if request.method == "POST":
|
||||
return JsonResponse(data=request.POST)
|
||||
else:
|
||||
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF)
|
||||
|
||||
|
||||
def form_view_with_csrf(request):
|
||||
if request.method == "POST":
|
||||
return JsonResponse(data=request.POST)
|
||||
else:
|
||||
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF)
|
Loading…
Reference in New Issue
Block a user