✨ add server checks & fix csrf middleware
This commit is contained in:
parent
aabe20cff7
commit
c451aff1e2
@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "spiderweb-framework"
|
name = "spiderweb-framework"
|
||||||
version = "0.10.0"
|
version = "0.11.0"
|
||||||
description = "A small web framework, just big enough for a spider."
|
description = "A small web framework, just big enough for a spider."
|
||||||
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
|
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -2,7 +2,7 @@ from peewee import DatabaseProxy
|
|||||||
|
|
||||||
DEFAULT_ALLOWED_METHODS = ["GET"]
|
DEFAULT_ALLOWED_METHODS = ["GET"]
|
||||||
DEFAULT_ENCODING = "UTF-8"
|
DEFAULT_ENCODING = "UTF-8"
|
||||||
__version__ = "0.10.0"
|
__version__ = "0.11.0"
|
||||||
|
|
||||||
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
|
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
|
||||||
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
|
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
|
||||||
|
@ -5,7 +5,7 @@ class SpiderwebException(Exception):
|
|||||||
msg = self.args[0] if len(self.args) > 0 else ""
|
msg = self.args[0] if len(self.args) > 0 else ""
|
||||||
if msg:
|
if msg:
|
||||||
return f"{name}() - {msg}"
|
return f"{name}() - {msg}"
|
||||||
return f"{self.__class__.__name__}()"
|
return f"{name}()"
|
||||||
|
|
||||||
|
|
||||||
class SpiderwebNetworkException(SpiderwebException):
|
class SpiderwebNetworkException(SpiderwebException):
|
||||||
|
@ -32,7 +32,7 @@ from spiderweb.routes import RoutesMixin
|
|||||||
from spiderweb.secrets import FernetMixin
|
from spiderweb.secrets import FernetMixin
|
||||||
from spiderweb.utils import get_http_status_by_code
|
from spiderweb.utils import get_http_status_by_code
|
||||||
|
|
||||||
file_logger = logging.getLogger(__name__)
|
console_logger = logging.getLogger(__name__)
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
@ -57,7 +57,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
session_cookie_same_site="lax",
|
session_cookie_same_site="lax",
|
||||||
session_cookie_path="/",
|
session_cookie_path="/",
|
||||||
log=None,
|
log=None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
self._routes = {}
|
self._routes = {}
|
||||||
self.routes = routes
|
self.routes = routes
|
||||||
@ -69,7 +69,8 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
self.append_slash = append_slash
|
self.append_slash = append_slash
|
||||||
self.templates_dirs = templates_dirs
|
self.templates_dirs = templates_dirs
|
||||||
self.staticfiles_dirs = staticfiles_dirs
|
self.staticfiles_dirs = staticfiles_dirs
|
||||||
self.middleware = middleware if middleware else []
|
self._middleware: list[str] = middleware if middleware else []
|
||||||
|
self.middleware: list[Callable] = []
|
||||||
self.secret_key = secret_key if secret_key else self.generate_key()
|
self.secret_key = secret_key if secret_key else self.generate_key()
|
||||||
|
|
||||||
self.extra_data = kwargs
|
self.extra_data = kwargs
|
||||||
@ -84,7 +85,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
|
|
||||||
self.DEFAULT_ENCODING = DEFAULT_ENCODING
|
self.DEFAULT_ENCODING = DEFAULT_ENCODING
|
||||||
self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS
|
self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS
|
||||||
self.log: logging.Logger = log if log else file_logger
|
self.log: logging.Logger = log if log else console_logger
|
||||||
|
|
||||||
# for using .start() and .stop()
|
# for using .start() and .stop()
|
||||||
self._thread: Optional[Thread] = None
|
self._thread: Optional[Thread] = None
|
||||||
@ -108,7 +109,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
self.add_error_routes()
|
self.add_error_routes()
|
||||||
|
|
||||||
if self.templates_dirs:
|
if self.templates_dirs:
|
||||||
self.template_loader = Environment(loader=FileSystemLoader(self.templates_dirs))
|
self.template_loader = Environment(
|
||||||
|
loader=FileSystemLoader(self.templates_dirs)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.template_loader = None
|
self.template_loader = None
|
||||||
self.string_loader = Environment(loader=BaseLoader())
|
self.string_loader = Environment(loader=BaseLoader())
|
||||||
@ -117,12 +120,16 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
for static_dir in self.staticfiles_dirs:
|
for static_dir in self.staticfiles_dirs:
|
||||||
static_dir = pathlib.Path(static_dir)
|
static_dir = pathlib.Path(static_dir)
|
||||||
if not pathlib.Path(self.BASE_DIR / static_dir).exists():
|
if not pathlib.Path(self.BASE_DIR / static_dir).exists():
|
||||||
log.error(
|
self.log.error(
|
||||||
f"Static files directory '{str(static_dir)}' does not exist."
|
f"Static files directory '{str(static_dir)}' does not exist."
|
||||||
)
|
)
|
||||||
raise ConfigError
|
raise ConfigError
|
||||||
self.add_route(r"/static/<str:filename>", send_file) # noqa: F405
|
self.add_route(r"/static/<str:filename>", send_file) # noqa: F405
|
||||||
|
|
||||||
|
# finally, run the startup checks to verify everything is correct and happy.
|
||||||
|
self.log.info("Run startup checks...")
|
||||||
|
self.run_middleware_checks()
|
||||||
|
|
||||||
def fire_response(self, start_response, request: Request, resp: HttpResponse):
|
def fire_response(self, start_response, request: Request, resp: HttpResponse):
|
||||||
try:
|
try:
|
||||||
status = get_http_status_by_code(resp.status_code)
|
status = get_http_status_by_code(resp.status_code)
|
||||||
@ -190,7 +197,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
def prepare_and_fire_response(self, start_response, request, resp) -> list[bytes]:
|
def prepare_and_fire_response(self, start_response, request, resp) -> list[bytes]:
|
||||||
try:
|
try:
|
||||||
if isinstance(resp, dict):
|
if isinstance(resp, dict):
|
||||||
return self.fire_response(start_response, request, JsonResponse(data=resp))
|
return self.fire_response(
|
||||||
|
start_response, request, JsonResponse(data=resp)
|
||||||
|
)
|
||||||
if isinstance(resp, TemplateResponse):
|
if isinstance(resp, TemplateResponse):
|
||||||
resp.set_template_loader(self.template_loader)
|
resp.set_template_loader(self.template_loader)
|
||||||
resp.set_string_loader(self.string_loader)
|
resp.set_string_loader(self.string_loader)
|
||||||
|
@ -12,19 +12,26 @@ from ..utils import import_by_string
|
|||||||
class MiddlewareMixin:
|
class MiddlewareMixin:
|
||||||
"""Cannot be called on its own. Requires context of SpiderwebRouter."""
|
"""Cannot be called on its own. Requires context of SpiderwebRouter."""
|
||||||
|
|
||||||
|
_middleware: list[str]
|
||||||
middleware: list[ClassVar]
|
middleware: list[ClassVar]
|
||||||
fire_response: Callable
|
fire_response: Callable
|
||||||
|
|
||||||
def init_middleware(self):
|
def init_middleware(self):
|
||||||
if self.middleware:
|
if self._middleware:
|
||||||
middleware_by_reference = []
|
middleware_by_reference = []
|
||||||
for m in self.middleware:
|
for m in self._middleware:
|
||||||
try:
|
try:
|
||||||
middleware_by_reference.append(import_by_string(m)(server=self))
|
middleware_by_reference.append(import_by_string(m)(server=self))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ConfigError(f"Middleware '{m}' not found.")
|
raise ConfigError(f"Middleware '{m}' not found.")
|
||||||
self.middleware = middleware_by_reference
|
self.middleware = middleware_by_reference
|
||||||
|
|
||||||
|
def run_middleware_checks(self):
|
||||||
|
for middleware in self.middleware:
|
||||||
|
if hasattr(middleware, "checks"):
|
||||||
|
for check in middleware.checks:
|
||||||
|
check(server=self).check()
|
||||||
|
|
||||||
def process_request_middleware(self, request: Request) -> None | bool:
|
def process_request_middleware(self, request: Request) -> None | bool:
|
||||||
for middleware in self.middleware:
|
for middleware in self.middleware:
|
||||||
try:
|
try:
|
||||||
|
@ -1,25 +1,42 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from spiderweb.exceptions import CSRFError
|
from spiderweb.exceptions import CSRFError, ConfigError
|
||||||
from spiderweb.middleware import SpiderwebMiddleware
|
from spiderweb.middleware import SpiderwebMiddleware
|
||||||
from spiderweb.request import Request
|
from spiderweb.request import Request
|
||||||
from spiderweb.response import HttpResponse
|
from spiderweb.response import HttpResponse
|
||||||
|
from spiderweb.server_checks import ServerCheck
|
||||||
|
|
||||||
|
|
||||||
|
class SessionCheck(ServerCheck):
|
||||||
|
|
||||||
|
SESSION_MIDDLEWARE_NOT_FOUND = (
|
||||||
|
"Session middleware is not enabled. It must be listed above"
|
||||||
|
"CSRFMiddleware in the middleware list."
|
||||||
|
)
|
||||||
|
SESSION_MIDDLEWARE_BELOW_CSRF = (
|
||||||
|
"SessionMiddleware is enabled, but it must be listed above"
|
||||||
|
"CSRFMiddleware in the middleware list."
|
||||||
|
)
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
|
||||||
|
if (
|
||||||
|
"spiderweb.middleware.sessions.SessionMiddleware"
|
||||||
|
not in self.server._middleware
|
||||||
|
):
|
||||||
|
raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
|
||||||
|
|
||||||
|
if self.server._middleware.index(
|
||||||
|
"spiderweb.middleware.sessions.SessionMiddleware"
|
||||||
|
) > self.server._middleware.index(
|
||||||
|
"spiderweb.middleware.csrf.CSRFMiddleware"
|
||||||
|
):
|
||||||
|
raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
|
||||||
|
|
||||||
|
|
||||||
class CSRFMiddleware(SpiderwebMiddleware):
|
class CSRFMiddleware(SpiderwebMiddleware):
|
||||||
"""
|
|
||||||
tl;dr: this is a naive implementation going off just what I could think of
|
|
||||||
at the time. It is very vulnerable to CSRF Forgery and should be updated.
|
|
||||||
|
|
||||||
Eventually I'll probably just pull everything out of Django and use their
|
checks = [SessionCheck]
|
||||||
implementation, as it's written by people who know a lot more about these
|
|
||||||
things than I do, but in the meantime, this is still here until I get
|
|
||||||
around to making it more solid.
|
|
||||||
|
|
||||||
todo: fix
|
|
||||||
|
|
||||||
https://en.wikipedia.org/wiki/Cross-site_request_forgery
|
|
||||||
"""
|
|
||||||
|
|
||||||
CSRF_EXPIRY = 60 * 60 # 1 hour
|
CSRF_EXPIRY = 60 * 60 # 1 hour
|
||||||
|
|
||||||
@ -33,14 +50,14 @@ class CSRFMiddleware(SpiderwebMiddleware):
|
|||||||
or request.GET.get("csrf_token")
|
or request.GET.get("csrf_token")
|
||||||
or request.POST.get("csrf_token")
|
or request.POST.get("csrf_token")
|
||||||
)
|
)
|
||||||
if self.is_csrf_valid(csrf_token):
|
if self.is_csrf_valid(request, csrf_token):
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
raise CSRFError()
|
raise CSRFError()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def process_response(self, request: Request, response: HttpResponse) -> None:
|
def process_response(self, request: Request, response: HttpResponse) -> None:
|
||||||
token = self.get_csrf_token()
|
token = self.get_csrf_token(request)
|
||||||
# do we need it in both places?
|
# do we need it in both places?
|
||||||
response.headers["X-CSRF-TOKEN"] = token
|
response.headers["X-CSRF-TOKEN"] = token
|
||||||
response.context |= {
|
response.context |= {
|
||||||
@ -48,17 +65,22 @@ class CSRFMiddleware(SpiderwebMiddleware):
|
|||||||
"raw_csrf_token": token, # in case they want to format it themselves
|
"raw_csrf_token": token, # in case they want to format it themselves
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_csrf_token(self):
|
def get_csrf_token(self, request):
|
||||||
return self.server.encrypt(str(datetime.now().isoformat())).decode(
|
# the session key should be here because we've processed the session first
|
||||||
self.server.DEFAULT_ENCODING
|
session_key = request._session["id"]
|
||||||
)
|
return self.server.encrypt(
|
||||||
|
f"{str(datetime.now().isoformat())}::{session_key}"
|
||||||
|
).decode(self.server.DEFAULT_ENCODING)
|
||||||
|
|
||||||
def is_csrf_valid(self, key):
|
def is_csrf_valid(self, request, key):
|
||||||
try:
|
try:
|
||||||
decoded = self.server.decrypt(key)
|
decoded = self.server.decrypt(key)
|
||||||
|
timestamp, session_key = decoded.split("::")
|
||||||
|
if session_key != request._session["id"]:
|
||||||
|
return False
|
||||||
if datetime.now() - timedelta(
|
if datetime.now() - timedelta(
|
||||||
seconds=self.CSRF_EXPIRY
|
seconds=self.CSRF_EXPIRY
|
||||||
) > datetime.fromisoformat(decoded):
|
) > datetime.fromisoformat(timestamp):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -30,7 +30,7 @@ class RoutesMixin:
|
|||||||
# ones that start with underscores are the compiled versions, non-underscores
|
# ones that start with underscores are the compiled versions, non-underscores
|
||||||
# are the user-supplied versions
|
# are the user-supplied versions
|
||||||
_routes: dict
|
_routes: dict
|
||||||
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
|
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,)
|
||||||
_error_routes: dict
|
_error_routes: dict
|
||||||
error_routes: dict[int, Callable]
|
error_routes: dict[int, Callable]
|
||||||
append_slash: bool
|
append_slash: bool
|
||||||
|
18
spiderweb/server_checks.py
Normal file
18
spiderweb/server_checks.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
class ServerCheck:
|
||||||
|
"""
|
||||||
|
Single server check base class.
|
||||||
|
|
||||||
|
During startup, each middleware can request checks to be run against the
|
||||||
|
current state of the server. These are usually used to verify that things
|
||||||
|
are configured correctly, but can also be used for database setup or other
|
||||||
|
similar things.
|
||||||
|
|
||||||
|
To fail a check, raise any error that makes sense to raise. This will halt
|
||||||
|
startup so the error can be fixed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, server):
|
||||||
|
self.server = server
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
pass
|
@ -1 +1,4 @@
|
|||||||
from spiderweb.tests.middleware import ExplodingResponseMiddleware, ExplodingRequestMiddleware
|
from spiderweb.tests.middleware import (
|
||||||
|
ExplodingResponseMiddleware,
|
||||||
|
ExplodingRequestMiddleware,
|
||||||
|
)
|
||||||
|
@ -7,5 +7,7 @@ class ExplodingRequestMiddleware(SpiderwebMiddleware):
|
|||||||
|
|
||||||
|
|
||||||
class ExplodingResponseMiddleware(SpiderwebMiddleware):
|
class ExplodingResponseMiddleware(SpiderwebMiddleware):
|
||||||
def process_response(self, request: Request, response: HttpResponse) -> HttpResponse | None:
|
def process_response(
|
||||||
|
self, request: Request, response: HttpResponse
|
||||||
|
) -> HttpResponse | None:
|
||||||
raise UnusedMiddleware("Unfinished!")
|
raise UnusedMiddleware("Unfinished!")
|
@ -1,11 +1,16 @@
|
|||||||
|
from io import BytesIO, BufferedReader
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
from peewee import SqliteDatabase
|
from peewee import SqliteDatabase
|
||||||
|
|
||||||
from spiderweb import SpiderwebRouter, HttpResponse
|
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError
|
||||||
from spiderweb.constants import DEFAULT_ENCODING
|
from spiderweb.constants import DEFAULT_ENCODING
|
||||||
from spiderweb.middleware.sessions import Session
|
from spiderweb.middleware.sessions import Session
|
||||||
|
from spiderweb.middleware import csrf
|
||||||
from spiderweb.tests.helpers import setup
|
from spiderweb.tests.helpers import setup
|
||||||
|
from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf
|
||||||
|
|
||||||
|
|
||||||
# app = SpiderwebRouter(
|
# app = SpiderwebRouter(
|
||||||
# middleware=[
|
# middleware=[
|
||||||
@ -18,19 +23,20 @@ from spiderweb.tests.helpers import setup
|
|||||||
# ],
|
# ],
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
||||||
def index(request):
|
def index(request):
|
||||||
if "value" in request.SESSION:
|
if "value" in request.SESSION:
|
||||||
request.SESSION['value'] += 1
|
request.SESSION["value"] += 1
|
||||||
else:
|
else:
|
||||||
request.SESSION['value'] = 0
|
request.SESSION["value"] = 0
|
||||||
return HttpResponse(body=str(request.SESSION['value']))
|
return HttpResponse(body=str(request.SESSION["value"]))
|
||||||
|
|
||||||
|
|
||||||
def test_session_middleware():
|
def test_session_middleware():
|
||||||
_, environ, start_response = setup()
|
_, environ, start_response = setup()
|
||||||
app = SpiderwebRouter(
|
app = SpiderwebRouter(
|
||||||
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
|
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
|
||||||
db=SqliteDatabase("spiderweb-tests.db")
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_route("/", index)
|
app.add_route("/", index)
|
||||||
@ -46,11 +52,12 @@ def test_session_middleware():
|
|||||||
assert app(environ, start_response) == [bytes(str(1), DEFAULT_ENCODING)]
|
assert app(environ, start_response) == [bytes(str(1), DEFAULT_ENCODING)]
|
||||||
assert app(environ, start_response) == [bytes(str(2), DEFAULT_ENCODING)]
|
assert app(environ, start_response) == [bytes(str(2), DEFAULT_ENCODING)]
|
||||||
|
|
||||||
|
|
||||||
def test_expired_session():
|
def test_expired_session():
|
||||||
_, environ, start_response = setup()
|
_, environ, start_response = setup()
|
||||||
app = SpiderwebRouter(
|
app = SpiderwebRouter(
|
||||||
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
|
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
|
||||||
db=SqliteDatabase("spiderweb-tests.db")
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_route("/", index)
|
app.add_route("/", index)
|
||||||
@ -80,9 +87,170 @@ def test_exploding_middleware():
|
|||||||
"spiderweb.tests.middleware.ExplodingRequestMiddleware",
|
"spiderweb.tests.middleware.ExplodingRequestMiddleware",
|
||||||
"spiderweb.tests.middleware.ExplodingResponseMiddleware",
|
"spiderweb.tests.middleware.ExplodingResponseMiddleware",
|
||||||
],
|
],
|
||||||
db=SqliteDatabase("spiderweb-tests.db")
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_route("/", index)
|
app.add_route("/", index)
|
||||||
|
|
||||||
assert app(environ, start_response) == [bytes(str(0), DEFAULT_ENCODING)]
|
assert app(environ, start_response) == [bytes(str(0), DEFAULT_ENCODING)]
|
||||||
|
# make sure it kicked out the middleware and isn't just ignoring it
|
||||||
|
assert len(app.middleware) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_middleware_without_session_middleware():
|
||||||
|
_, environ, start_response = setup()
|
||||||
|
with pytest.raises(ConfigError) as e:
|
||||||
|
SpiderwebRouter(
|
||||||
|
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_middleware_above_session_middleware():
|
||||||
|
_, environ, start_response = setup()
|
||||||
|
with pytest.raises(ConfigError) as e:
|
||||||
|
SpiderwebRouter(
|
||||||
|
middleware=[
|
||||||
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
|
],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_middleware():
|
||||||
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(
|
||||||
|
middleware=[
|
||||||
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
|
],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_route("/", form_view_with_csrf, ["GET", "POST"])
|
||||||
|
|
||||||
|
environ["HTTP_USER_AGENT"] = "hi"
|
||||||
|
environ["REMOTE_ADDR"] = "1.1.1.1"
|
||||||
|
|
||||||
|
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
|
|
||||||
|
assert "<form" in resp
|
||||||
|
assert '<input type="hidden" name="csrf_token"' in resp
|
||||||
|
|
||||||
|
token = resp.split('value="')[1].split('"')[0]
|
||||||
|
|
||||||
|
formdata = f"name=bob&csrf_token={token}"
|
||||||
|
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
|
||||||
|
environ["HTTP_COOKIE"] = (
|
||||||
|
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
|
||||||
|
)
|
||||||
|
environ["REQUEST_METHOD"] = "POST"
|
||||||
|
environ["HTTP_X_CSRF_TOKEN"] = token
|
||||||
|
environ["CONTENT_LENGTH"] = len(formdata)
|
||||||
|
|
||||||
|
# setup form data
|
||||||
|
b_handle = BytesIO()
|
||||||
|
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||||
|
b_handle.seek(0)
|
||||||
|
|
||||||
|
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||||
|
|
||||||
|
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
|
|
||||||
|
assert "bob" in resp2
|
||||||
|
|
||||||
|
# test that it raises a CSRF error on wrong token
|
||||||
|
formdata = f"name=bob&csrf_token=badtoken"
|
||||||
|
b_handle = BytesIO()
|
||||||
|
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||||
|
b_handle.seek(0)
|
||||||
|
|
||||||
|
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||||
|
resp3 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
|
assert "CSRF token is invalid" in resp3
|
||||||
|
|
||||||
|
# test that the wrong session also raises a CSRF error
|
||||||
|
token = app.decrypt(token).split("::")[0]
|
||||||
|
token = app.encrypt(f"{token}::badsession").decode(DEFAULT_ENCODING)
|
||||||
|
formdata = f"name=bob&csrf_token={token}"
|
||||||
|
b_handle = BytesIO()
|
||||||
|
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||||
|
b_handle.seek(0)
|
||||||
|
|
||||||
|
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||||
|
resp4 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
|
assert "CSRF token is invalid" in resp4
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_expired_token():
|
||||||
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(
|
||||||
|
middleware=[
|
||||||
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
|
],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
|
)
|
||||||
|
app.middleware[1].CSRF_EXPIRY = -1
|
||||||
|
|
||||||
|
app.add_route("/", form_view_with_csrf, ["GET", "POST"])
|
||||||
|
|
||||||
|
environ["HTTP_USER_AGENT"] = "hi"
|
||||||
|
environ["REMOTE_ADDR"] = "1.1.1.1"
|
||||||
|
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
|
token = resp.split('value="')[1].split('"')[0]
|
||||||
|
|
||||||
|
formdata = f"name=bob&csrf_token={token}"
|
||||||
|
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
|
||||||
|
environ["HTTP_COOKIE"] = (
|
||||||
|
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
|
||||||
|
)
|
||||||
|
environ["REQUEST_METHOD"] = "POST"
|
||||||
|
environ["HTTP_X_CSRF_TOKEN"] = token
|
||||||
|
environ["CONTENT_LENGTH"] = len(formdata)
|
||||||
|
|
||||||
|
b_handle = BytesIO()
|
||||||
|
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||||
|
b_handle.seek(0)
|
||||||
|
|
||||||
|
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||||
|
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
|
assert "CSRF token is invalid" in resp
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_exempt():
|
||||||
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(
|
||||||
|
middleware=[
|
||||||
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
|
],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_route("/", form_csrf_exempt, ["GET", "POST"])
|
||||||
|
app.add_route("/2", form_view_without_csrf, ["GET", "POST"])
|
||||||
|
|
||||||
|
environ["HTTP_USER_AGENT"] = "hi"
|
||||||
|
environ["REMOTE_ADDR"] = "1.1.1.1"
|
||||||
|
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
|
||||||
|
environ["REQUEST_METHOD"] = "POST"
|
||||||
|
|
||||||
|
formdata = "name=bob"
|
||||||
|
environ["CONTENT_LENGTH"] = len(formdata)
|
||||||
|
b_handle = BytesIO()
|
||||||
|
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||||
|
b_handle.seek(0)
|
||||||
|
|
||||||
|
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||||
|
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
|
assert "bob" in resp
|
||||||
|
|
||||||
|
environ["PATH_INFO"] = "/2"
|
||||||
|
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
|
assert "CSRF token is invalid" in resp2
|
||||||
|
@ -3,7 +3,12 @@ import pytest
|
|||||||
from spiderweb import SpiderwebRouter, ConfigError
|
from spiderweb import SpiderwebRouter, ConfigError
|
||||||
from spiderweb.constants import DEFAULT_ENCODING
|
from spiderweb.constants import DEFAULT_ENCODING
|
||||||
from spiderweb.exceptions import NoResponseError, SpiderwebNetworkException
|
from spiderweb.exceptions import NoResponseError, SpiderwebNetworkException
|
||||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
from spiderweb.response import (
|
||||||
|
HttpResponse,
|
||||||
|
JsonResponse,
|
||||||
|
TemplateResponse,
|
||||||
|
RedirectResponse,
|
||||||
|
)
|
||||||
from hypothesis import given, strategies as st
|
from hypothesis import given, strategies as st
|
||||||
|
|
||||||
from spiderweb.tests.helpers import setup
|
from spiderweb.tests.helpers import setup
|
||||||
@ -27,7 +32,9 @@ def test_json_response():
|
|||||||
def index(request):
|
def index(request):
|
||||||
return JsonResponse(data={"message": "text"})
|
return JsonResponse(data={"message": "text"})
|
||||||
|
|
||||||
assert app(environ, start_response) == [bytes('{"message": "text"}', DEFAULT_ENCODING)]
|
assert app(environ, start_response) == [
|
||||||
|
bytes('{"message": "text"}', DEFAULT_ENCODING)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_dict_response():
|
def test_dict_response():
|
||||||
@ -51,7 +58,9 @@ def test_template_response(text):
|
|||||||
request, template_string=template, context={"message": text}
|
request, template_string=template, context={"message": text}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert app(environ, start_response) == [b"MESSAGE: " + bytes(text, DEFAULT_ENCODING)]
|
assert app(environ, start_response) == [
|
||||||
|
b"MESSAGE: " + bytes(text, DEFAULT_ENCODING)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_redirect_response():
|
def test_redirect_response():
|
||||||
@ -61,7 +70,7 @@ def test_redirect_response():
|
|||||||
def index(request):
|
def index(request):
|
||||||
return RedirectResponse(location="/redirected")
|
return RedirectResponse(location="/redirected")
|
||||||
|
|
||||||
assert app(environ, start_response) == [b'None']
|
assert app(environ, start_response) == [b"None"]
|
||||||
assert start_response.get_headers()["Location"] == "/redirected"
|
assert start_response.get_headers()["Location"] == "/redirected"
|
||||||
|
|
||||||
|
|
||||||
@ -74,12 +83,14 @@ def test_add_route_at_server_start():
|
|||||||
def view2(request):
|
def view2(request):
|
||||||
return HttpResponse("View 2")
|
return HttpResponse("View 2")
|
||||||
|
|
||||||
app = SpiderwebRouter(routes=[
|
app = SpiderwebRouter(
|
||||||
|
routes=[
|
||||||
("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}),
|
("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}),
|
||||||
("/view2", view2),
|
("/view2", view2),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
assert app(environ, start_response) == [b'None']
|
assert app(environ, start_response) == [b"None"]
|
||||||
assert start_response.get_headers()["Location"] == "/redirected"
|
assert start_response.get_headers()["Location"] == "/redirected"
|
||||||
|
|
||||||
|
|
||||||
@ -92,7 +103,7 @@ def test_redirect_on_append_slash():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
environ["PATH_INFO"] = f"/hello"
|
environ["PATH_INFO"] = f"/hello"
|
||||||
assert app(environ, start_response) == [b'None']
|
assert app(environ, start_response) == [b"None"]
|
||||||
assert start_response.get_headers()["Location"] == "/hello/"
|
assert start_response.get_headers()["Location"] == "/hello/"
|
||||||
|
|
||||||
|
|
||||||
@ -104,11 +115,11 @@ def test_template_response_with_template(text):
|
|||||||
|
|
||||||
@app.route("/")
|
@app.route("/")
|
||||||
def index(request):
|
def index(request):
|
||||||
return TemplateResponse(
|
return TemplateResponse(request, "test.html", context={"message": text})
|
||||||
request, "test.html", context={"message": text}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert app(environ, start_response) == [b"TEMPLATE! " + bytes(text, DEFAULT_ENCODING)]
|
assert app(environ, start_response) == [
|
||||||
|
b"TEMPLATE! " + bytes(text, DEFAULT_ENCODING)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_view_returns_none():
|
def test_view_returns_none():
|
||||||
@ -119,7 +130,7 @@ def test_view_returns_none():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
with pytest.raises(NoResponseError):
|
with pytest.raises(NoResponseError):
|
||||||
assert app(environ, start_response) == [b'None']
|
assert app(environ, start_response) == [b"None"]
|
||||||
|
|
||||||
|
|
||||||
def test_exploding_view():
|
def test_exploding_view():
|
||||||
@ -130,9 +141,10 @@ def test_exploding_view():
|
|||||||
raise SpiderwebNetworkException("Boom!")
|
raise SpiderwebNetworkException("Boom!")
|
||||||
|
|
||||||
assert app(environ, start_response) == [
|
assert app(environ, start_response) == [
|
||||||
b'Something went wrong.\n\nCode: Boom!\n\nMsg: None\n\nDesc: None'
|
b"Something went wrong.\n\nCode: Boom!\n\nMsg: None\n\nDesc: None"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_missing_view():
|
def test_missing_view():
|
||||||
app, environ, start_response = setup()
|
app, environ, start_response = setup()
|
||||||
|
|
||||||
@ -146,20 +158,19 @@ def test_missing_view_with_custom_404():
|
|||||||
def custom_404(request):
|
def custom_404(request):
|
||||||
return HttpResponse("Custom 404")
|
return HttpResponse("Custom 404")
|
||||||
|
|
||||||
assert app(environ, start_response) == [b'Custom 404']
|
assert app(environ, start_response) == [b"Custom 404"]
|
||||||
|
|
||||||
|
|
||||||
def test_duplicate_error_view():
|
def test_duplicate_error_view():
|
||||||
app, environ, start_response = setup()
|
app, environ, start_response = setup()
|
||||||
|
|
||||||
@app.error(404)
|
@app.error(404)
|
||||||
def custom_404(request):
|
def custom_404(request): ...
|
||||||
...
|
|
||||||
|
|
||||||
with pytest.raises(ConfigError):
|
with pytest.raises(ConfigError):
|
||||||
|
|
||||||
@app.error(404)
|
@app.error(404)
|
||||||
def custom_404(request):
|
def custom_404(request): ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def test_missing_view_with_custom_404_alt():
|
def test_missing_view_with_custom_404_alt():
|
||||||
@ -170,4 +181,4 @@ def test_missing_view_with_custom_404_alt():
|
|||||||
|
|
||||||
app = SpiderwebRouter(error_routes={404: custom_404})
|
app = SpiderwebRouter(error_routes={404: custom_404})
|
||||||
|
|
||||||
assert app(environ, start_response) == [b'Custom 404 2']
|
assert app(environ, start_response) == [b"Custom 404 2"]
|
||||||
|
@ -3,7 +3,12 @@ import pytest
|
|||||||
from spiderweb import SpiderwebRouter
|
from spiderweb import SpiderwebRouter
|
||||||
from spiderweb.constants import DEFAULT_ENCODING
|
from spiderweb.constants import DEFAULT_ENCODING
|
||||||
from spiderweb.exceptions import ParseError, ConfigError
|
from spiderweb.exceptions import ParseError, ConfigError
|
||||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
from spiderweb.response import (
|
||||||
|
HttpResponse,
|
||||||
|
JsonResponse,
|
||||||
|
TemplateResponse,
|
||||||
|
RedirectResponse,
|
||||||
|
)
|
||||||
from hypothesis import given, strategies as st, assume
|
from hypothesis import given, strategies as st, assume
|
||||||
|
|
||||||
from peewee import SqliteDatabase
|
from peewee import SqliteDatabase
|
||||||
@ -43,6 +48,7 @@ def test_unknown_converter():
|
|||||||
app, environ, start_response = setup()
|
app, environ, start_response = setup()
|
||||||
|
|
||||||
with pytest.raises(ParseError):
|
with pytest.raises(ParseError):
|
||||||
|
|
||||||
@app.route("/<asdf:test_input>")
|
@app.route("/<asdf:test_input>")
|
||||||
def index(request, test_input: str):
|
def index(request, test_input: str):
|
||||||
return HttpResponse(test_input)
|
return HttpResponse(test_input)
|
||||||
@ -52,19 +58,19 @@ def test_duplicate_route():
|
|||||||
app, environ, start_response = setup()
|
app, environ, start_response = setup()
|
||||||
|
|
||||||
@app.route("/")
|
@app.route("/")
|
||||||
def index(request):
|
def index(request): ...
|
||||||
...
|
|
||||||
|
|
||||||
with pytest.raises(ConfigError):
|
with pytest.raises(ConfigError):
|
||||||
|
|
||||||
@app.route("/")
|
@app.route("/")
|
||||||
def index(request):
|
def index(request): ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def test_url_with_double_underscore():
|
def test_url_with_double_underscore():
|
||||||
app, environ, start_response = setup()
|
app, environ, start_response = setup()
|
||||||
|
|
||||||
with pytest.raises(ConfigError):
|
with pytest.raises(ConfigError):
|
||||||
|
|
||||||
@app.route("/<asdf:test__input>")
|
@app.route("/<asdf:test__input>")
|
||||||
def index(request, test_input: str):
|
def index(request, test_input: str):
|
||||||
return HttpResponse(test_input)
|
return HttpResponse(test_input)
|
||||||
|
40
spiderweb/tests/views_for_tests.py
Normal file
40
spiderweb/tests/views_for_tests.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from spiderweb.decorators import csrf_exempt
|
||||||
|
from spiderweb.response import JsonResponse, TemplateResponse
|
||||||
|
|
||||||
|
|
||||||
|
EXAMPLE_HTML_FORM = """
|
||||||
|
<form action="" method="post">
|
||||||
|
<input type="text" name="name" />
|
||||||
|
<input type="submit" />
|
||||||
|
</form>
|
||||||
|
"""
|
||||||
|
|
||||||
|
EXAMPLE_HTML_FORM_WITH_CSRF = """
|
||||||
|
<form action="" method="post">
|
||||||
|
<input type="text" name="name" />
|
||||||
|
<input type="submit" />
|
||||||
|
{{ csrf_token }}
|
||||||
|
</form>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def form_view_without_csrf(request):
|
||||||
|
if request.method == "POST":
|
||||||
|
return JsonResponse(data=request.POST)
|
||||||
|
else:
|
||||||
|
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM)
|
||||||
|
|
||||||
|
|
||||||
|
@csrf_exempt
|
||||||
|
def form_csrf_exempt(request):
|
||||||
|
if request.method == "POST":
|
||||||
|
return JsonResponse(data=request.POST)
|
||||||
|
else:
|
||||||
|
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF)
|
||||||
|
|
||||||
|
|
||||||
|
def form_view_with_csrf(request):
|
||||||
|
if request.method == "POST":
|
||||||
|
return JsonResponse(data=request.POST)
|
||||||
|
else:
|
||||||
|
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF)
|
Loading…
Reference in New Issue
Block a user