Compare commits

...

3 Commits

Author SHA1 Message Date
9b9e1c8da0 Merge remote-tracking branch 'origin/main' 2024-09-09 01:44:21 -04:00
7ee119d42d 🔖 release 1.1.0 2024-09-09 01:44:09 -04:00
c9f3129b02 add app.reverse() function 2024-09-09 01:43:55 -04:00
10 changed files with 217 additions and 63 deletions

View File

@ -149,4 +149,31 @@ app = SpiderwebRouter(
error_routes={405: http405}, error_routes={405: http405},
) )
``` ```
As with the `routes` argument, as many routes as you'd like can be registered here without issue. As with the `routes` argument, as many routes as you'd like can be registered here without issue.
## Finding Routes Again
> New in 1.1.0
If you need to find the path that's associated with a route (for example, for a RedirectResponse), you can use the `app.reverse()` function to find it. This function takes the name of the view and returns the path that it's associated with. For example:
```python
@app.route("/example", name="example")
def example(request):
return HttpResponse(body="Example")
path = app.reverse("example")
print(path) # -> "/example"
```
If you have a route that takes arguments, you can pass them in as keyword arguments:
```python
@app.route("/example/<int:id>", name="example")
def example(request, id):
return HttpResponse(body=f"Example with id {id}")
path = app.reverse("example", id=3)
print(path) # -> "/example/3"
```
The arguments you pass in must match what the path expects, or you'll get a `SpiderwebException`. If there's no route with that name, you'll get a `ReverseNotFound` exception instead.

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "spiderweb-framework" name = "spiderweb-framework"
version = "1.0.0" version = "1.1.0"
description = "A small web framework, just big enough for a spider." description = "A small web framework, just big enough for a spider."
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"] authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
readme = "README.md" readme = "README.md"

View File

@ -2,7 +2,7 @@ from peewee import DatabaseProxy
DEFAULT_ALLOWED_METHODS = ["POST", "GET", "PUT", "PATCH", "DELETE"] DEFAULT_ALLOWED_METHODS = ["POST", "GET", "PUT", "PATCH", "DELETE"]
DEFAULT_ENCODING = "UTF-8" DEFAULT_ENCODING = "UTF-8"
__version__ = "1.0.0" __version__ = "1.1.0"
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"

View File

@ -90,3 +90,7 @@ class NoResponseError(SpiderwebException):
class StartupErrors(ExceptionGroup): class StartupErrors(ExceptionGroup):
pass pass
class ReverseNotFound(SpiderwebException):
pass

View File

@ -101,7 +101,10 @@ class CorsMiddleware(SpiderwebMiddleware):
) )
if ( if (
self.server.cors_allow_private_network self.server.cors_allow_private_network
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK.replace("-", "_")) == "true" and request.headers.get(
ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK.replace("-", "_")
)
== "true"
): ):
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true" response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"

View File

@ -4,7 +4,13 @@ from typing import Callable, Any, Optional, Sequence
from spiderweb.constants import DEFAULT_ALLOWED_METHODS from spiderweb.constants import DEFAULT_ALLOWED_METHODS
from spiderweb.converters import * # noqa: F403 from spiderweb.converters import * # noqa: F403
from spiderweb.default_views import * # noqa: F403 from spiderweb.default_views import * # noqa: F403
from spiderweb.exceptions import NotFound, ConfigError, ParseError from spiderweb.exceptions import (
NotFound,
ConfigError,
ParseError,
SpiderwebException,
ReverseNotFound,
)
from spiderweb.response import RedirectResponse from spiderweb.response import RedirectResponse
@ -35,7 +41,7 @@ class RoutesMixin:
error_routes: dict[int, Callable] error_routes: dict[int, Callable]
append_slash: bool append_slash: bool
def route(self, path, allowed_methods=None) -> Callable: def route(self, path, allowed_methods=None, name=None) -> Callable:
""" """
Decorator for adding a route to a view. Decorator for adding a route to a view.
@ -49,11 +55,12 @@ class RoutesMixin:
:param path: str :param path: str
:param allowed_methods: list[str] :param allowed_methods: list[str]
:param name: str
:return: Callable :return: Callable
""" """
def outer(func): def outer(func):
self.add_route(path, func, allowed_methods) self.add_route(path, func, allowed_methods, name)
return func return func
return outer return outer
@ -115,7 +122,11 @@ class RoutesMixin:
return re.compile(rf"^{'/'.join(parts)}$") return re.compile(rf"^{'/'.join(parts)}$")
def add_route( def add_route(
self, path: str, method: Callable, allowed_methods: None | list[str] = None self,
path: str,
method: Callable,
allowed_methods: None | list[str] = None,
name: str = None,
): ):
"""Add a route to the server.""" """Add a route to the server."""
allowed_methods = ( allowed_methods = (
@ -124,24 +135,27 @@ class RoutesMixin:
or DEFAULT_ALLOWED_METHODS or DEFAULT_ALLOWED_METHODS
) )
reverse_path = re.sub(r"<(.*?):(.*?)>", r"{\2}", path) if "<" in path else path
def get_packet(func):
return {
"func": func,
"allowed_methods": allowed_methods,
"name": name,
"reverse": reverse_path,
}
if self.append_slash and not path.endswith("/"): if self.append_slash and not path.endswith("/"):
updated_path = path + "/" updated_path = path + "/"
self.check_for_route_duplicates(updated_path) self.check_for_route_duplicates(updated_path)
self.check_for_route_duplicates(path) self.check_for_route_duplicates(path)
self._routes[self.convert_path(path)] = { self._routes[self.convert_path(path)] = get_packet(
"func": DummyRedirectRoute(updated_path), DummyRedirectRoute(updated_path)
"allowed_methods": allowed_methods, )
} self._routes[self.convert_path(updated_path)] = get_packet(method)
self._routes[self.convert_path(updated_path)] = {
"func": method,
"allowed_methods": allowed_methods,
}
else: else:
self.check_for_route_duplicates(path) self.check_for_route_duplicates(path)
self._routes[self.convert_path(path)] = { self._routes[self.convert_path(path)] = get_packet(method)
"func": method,
"allowed_methods": allowed_methods,
}
def add_routes(self): def add_routes(self):
for line in self.routes: for line in self.routes:
@ -156,3 +170,27 @@ class RoutesMixin:
def add_error_routes(self): def add_error_routes(self):
for code, func in self.error_routes.items(): for code, func in self.error_routes.items():
self.add_error_route(int(code), func) self.add_error_route(int(code), func)
def reverse(
self, view_name: str, data: dict[str, Any] = None, query: dict[str, Any] = None
) -> str:
# take in a view name and return the path
for option in self._routes.values():
if option["name"] == view_name:
path = option["reverse"]
if args := re.findall(r"{(.*?)}", path):
if not data:
raise SpiderwebException(
f"Missing arguments for reverse: {args}"
)
for arg in args:
if arg not in data:
raise SpiderwebException(
f"Missing argument '{arg}' for reverse."
)
path = path.replace(f"{{{arg}}}", str(data[arg]))
if query:
path += "?" + "&".join([f"{k}={str(v)}" for k, v in query.items()])
return path
raise ReverseNotFound(f"View '{view_name}' not found.")

View File

@ -15,4 +15,4 @@ class ExplodingResponseMiddleware(SpiderwebMiddleware):
class InterruptingMiddleware(SpiderwebMiddleware): class InterruptingMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse: def process_request(self, request: Request) -> HttpResponse:
return HttpResponse("Moo!") return HttpResponse("Moo!")

View File

@ -21,7 +21,9 @@ from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import ( from spiderweb.tests.views_for_tests import (
form_view_with_csrf, form_view_with_csrf,
form_csrf_exempt, form_csrf_exempt,
form_view_without_csrf, text_view, unauthorized_view, form_view_without_csrf,
text_view,
unauthorized_view,
) )
@ -371,11 +373,14 @@ class TestCorsMiddleware:
app, environ, start_response = setup( app, environ, start_response = setup(
**self.middleware, **self.middleware,
cors_allow_all_origins=True, cors_allow_all_origins=True,
cors_expose_headers=["accept", "content-type"] cors_expose_headers=["accept", "content-type"],
) )
environ["HTTP_ORIGIN"] = "https://example.com" environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response) app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_EXPOSE_HEADERS] == "accept, content-type" assert (
start_response.get_headers()[ACCESS_CONTROL_EXPOSE_HEADERS]
== "accept, content-type"
)
def test_get_dont_expose_headers(self): def test_get_dont_expose_headers(self):
app, environ, start_response = setup( app, environ, start_response = setup(
@ -419,18 +424,20 @@ class TestCorsMiddleware:
app, environ, start_response = setup( app, environ, start_response = setup(
**self.middleware, **self.middleware,
cors_allow_private_network=True, cors_allow_private_network=True,
cors_allow_all_origins=True cors_allow_all_origins=True,
) )
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true" environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.com" environ["HTTP_ORIGIN"] = "http://example.com"
app(environ, start_response) app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] == "true" assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] == "true"
)
def test_allow_private_network_not_added_if_enabled_and_not_requested(self): def test_allow_private_network_not_added_if_enabled_and_not_requested(self):
app, environ, start_response = setup( app, environ, start_response = setup(
**self.middleware, **self.middleware,
cors_allow_private_network=True, cors_allow_private_network=True,
cors_allow_all_origins=True cors_allow_all_origins=True,
) )
environ["HTTP_ORIGIN"] = "http://example.com" environ["HTTP_ORIGIN"] = "http://example.com"
app(environ, start_response) app(environ, start_response)
@ -440,19 +447,18 @@ class TestCorsMiddleware:
app, environ, start_response = setup( app, environ, start_response = setup(
**self.middleware, **self.middleware,
cors_allow_private_network=True, cors_allow_private_network=True,
cors_allowed_origins=["http://example.com"] cors_allowed_origins=["http://example.com"],
) )
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true" environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.org" environ["HTTP_ORIGIN"] = "http://example.org"
app(environ, start_response) app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers() assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers()
def test_allow_private_network_not_added_if_disabled_and_requested(self): def test_allow_private_network_not_added_if_disabled_and_requested(self):
app, environ, start_response = setup( app, environ, start_response = setup(
**self.middleware, **self.middleware,
cors_allow_private_network=False, cors_allow_private_network=False,
cors_allow_all_origins=True cors_allow_all_origins=True,
) )
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true" environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.com" environ["HTTP_ORIGIN"] = "http://example.com"
@ -465,7 +471,7 @@ class TestCorsMiddleware:
cors_allow_headers=["content-type"], cors_allow_headers=["content-type"],
cors_allow_methods=["GET", "OPTIONS"], cors_allow_methods=["GET", "OPTIONS"],
cors_preflight_max_age=1002, cors_preflight_max_age=1002,
cors_allow_all_origins=True cors_allow_all_origins=True,
) )
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com" environ["HTTP_ORIGIN"] = "https://example.com"
@ -474,26 +480,24 @@ class TestCorsMiddleware:
headers = start_response.get_headers() headers = start_response.get_headers()
assert start_response.status == '200 OK' assert start_response.status == "200 OK"
assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type" assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS" assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
assert headers[ACCESS_CONTROL_MAX_AGE] == "1002" assert headers[ACCESS_CONTROL_MAX_AGE] == "1002"
def test_options_no_max_age(self): def test_options_no_max_age(self):
app, environ, start_response = setup( app, environ, start_response = setup(
**self.middleware, **self.middleware,
cors_allow_headers=["content-type"], cors_allow_headers=["content-type"],
cors_allow_methods=["GET", "OPTIONS"], cors_allow_methods=["GET", "OPTIONS"],
cors_preflight_max_age=0, cors_preflight_max_age=0,
cors_allow_all_origins=True cors_allow_all_origins=True,
) )
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com" environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "OPTIONS" environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response) app(environ, start_response)
headers = start_response.get_headers() headers = start_response.get_headers()
assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type" assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS" assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
@ -501,34 +505,39 @@ class TestCorsMiddleware:
def test_options_allowed_origins_with_port(self): def test_options_allowed_origins_with_port(self):
app, environ, start_response = setup( app, environ, start_response = setup(
**self.middleware, **self.middleware, cors_allowed_origins=["https://localhost:9000"]
cors_allowed_origins=["https://localhost:9000"]
) )
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://localhost:9000" environ["HTTP_ORIGIN"] = "https://localhost:9000"
environ["REQUEST_METHOD"] = "OPTIONS" environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response) app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://localhost:9000" assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://localhost:9000"
)
def test_options_adds_origin_when_domain_found_in_allowed_regexes(self): def test_options_adds_origin_when_domain_found_in_allowed_regexes(self):
app, environ, start_response = setup( app, environ, start_response = setup(
**self.middleware, **self.middleware,
cors_allowed_origin_regexes=[r"^https://\w+\.example\.com$"] cors_allowed_origin_regexes=[r"^https://\w+\.example\.com$"],
) )
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://foo.example.com" environ["HTTP_ORIGIN"] = "https://foo.example.com"
environ["REQUEST_METHOD"] = "OPTIONS" environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response) app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com" assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://foo.example.com"
)
def test_options_adds_origin_when_domain_found_in_allowed_regexes_second(self): def test_options_adds_origin_when_domain_found_in_allowed_regexes_second(self):
app, environ, start_response = setup( app, environ, start_response = setup(
**self.middleware, **self.middleware,
cors_allowed_origin_regexes=[ cors_allowed_origin_regexes=[
r"^https://\w+\.example\.org$", r"^https://\w+\.example\.org$",
r"^https://\w+\.example\.com$", r"^https://\w+\.example\.com$",
], ],
) )
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
@ -536,7 +545,10 @@ class TestCorsMiddleware:
environ["REQUEST_METHOD"] = "OPTIONS" environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response) app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com" assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://foo.example.com"
)
def test_options_doesnt_add_origin_when_domain_not_found_in_allowed_regexes(self): def test_options_doesnt_add_origin_when_domain_not_found_in_allowed_regexes(self):
app, environ, start_response = setup( app, environ, start_response = setup(
@ -562,14 +574,9 @@ class TestCorsMiddleware:
assert start_response.status == "200 OK" assert start_response.status == "200 OK"
def test_options_no_headers(self): def test_options_no_headers(self):
app, environ, start_response = setup( app, environ, start_response = setup(
**self.middleware, **self.middleware, cors_allow_all_origins=True, routes=[("/", text_view)]
cors_allow_all_origins=True,
routes=[
("/", text_view)
]
) )
environ["REQUEST_METHOD"] = "OPTIONS" environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response) app(environ, start_response)
@ -580,7 +587,7 @@ class TestCorsMiddleware:
**self.middleware, **self.middleware,
cors_allow_credentials=True, cors_allow_credentials=True,
cors_allow_all_origins=True, cors_allow_all_origins=True,
routes=[("/", text_view)] routes=[("/", text_view)],
) )
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com" environ["HTTP_ORIGIN"] = "https://example.com"
@ -588,7 +595,10 @@ class TestCorsMiddleware:
app(environ, start_response) app(environ, start_response)
assert start_response.status == "200 OK" assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://example.com"
)
assert start_response.get_headers()["vary"] == "origin" assert start_response.get_headers()["vary"] == "origin"
def test_allow_all_origins_options(self): def test_allow_all_origins_options(self):
@ -596,7 +606,7 @@ class TestCorsMiddleware:
**self.middleware, **self.middleware,
cors_allow_credentials=True, cors_allow_credentials=True,
cors_allow_all_origins=True, cors_allow_all_origins=True,
routes=[("/", text_view)] routes=[("/", text_view)],
) )
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
@ -605,7 +615,10 @@ class TestCorsMiddleware:
app(environ, start_response) app(environ, start_response)
assert start_response.status == "200 OK" assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://example.com"
)
assert start_response.get_headers()["vary"] == "origin" assert start_response.get_headers()["vary"] == "origin"
def test_non_200_headers_still_set(self): def test_non_200_headers_still_set(self):
@ -618,15 +631,17 @@ class TestCorsMiddleware:
**self.middleware, **self.middleware,
cors_allow_credentials=True, cors_allow_credentials=True,
cors_allow_all_origins=True, cors_allow_all_origins=True,
routes=[("/unauthorized", unauthorized_view)] routes=[("/unauthorized", unauthorized_view)],
) )
environ["HTTP_ORIGIN"] = "https://example.com" environ["HTTP_ORIGIN"] = "https://example.com"
environ["PATH_INFO"] = "/unauthorized" environ["PATH_INFO"] = "/unauthorized"
app(environ, start_response) app(environ, start_response)
assert start_response.status == "401 Unauthorized" assert start_response.status == "401 Unauthorized"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://example.com"
)
def test_auth_view_options(self): def test_auth_view_options(self):
""" """
@ -637,7 +652,7 @@ class TestCorsMiddleware:
**self.middleware, **self.middleware,
cors_allow_credentials=True, cors_allow_credentials=True,
cors_allow_all_origins=True, cors_allow_all_origins=True,
routes=[("/unauthorized", unauthorized_view)] routes=[("/unauthorized", unauthorized_view)],
) )
environ["HTTP_ORIGIN"] = "https://example.com" environ["HTTP_ORIGIN"] = "https://example.com"
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
@ -646,10 +661,12 @@ class TestCorsMiddleware:
app(environ, start_response) app(environ, start_response)
assert start_response.status == "200 OK" assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://example.com"
)
assert start_response.get_headers()["content-length"] == "0" assert start_response.get_headers()["content-length"] == "0"
def test_get_short_circuit(self): def test_get_short_circuit(self):
""" """
Test a scenario when a middleware that returns a response is run before Test a scenario when a middleware that returns a response is run before
@ -716,6 +733,7 @@ class TestCorsMiddleware:
Just in case something crazy happens in the view or other middleware, Just in case something crazy happens in the view or other middleware,
check that get_response doesn't fall over if `_cors_enabled` is removed check that get_response doesn't fall over if `_cors_enabled` is removed
""" """
def yeet(request): def yeet(request):
del request._cors_enabled del request._cors_enabled
return HttpResponse("hahaha") return HttpResponse("hahaha")
@ -723,7 +741,7 @@ class TestCorsMiddleware:
app, environ, start_response = setup( app, environ, start_response = setup(
**self.middleware, **self.middleware,
cors_allowed_origins=["https://example.com"], cors_allowed_origins=["https://example.com"],
routes=[('/yeet', yeet)] routes=[("/yeet", yeet)],
) )
environ["HTTP_ORIGIN"] = "https://example.com" environ["HTTP_ORIGIN"] = "https://example.com"

View File

@ -1,8 +1,13 @@
import pytest import pytest
from spiderweb import SpiderwebRouter, ConfigError from spiderweb import ConfigError
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.exceptions import NoResponseError, SpiderwebNetworkException from spiderweb.exceptions import (
NoResponseError,
SpiderwebNetworkException,
SpiderwebException,
ReverseNotFound,
)
from spiderweb.response import ( from spiderweb.response import (
HttpResponse, HttpResponse,
JsonResponse, JsonResponse,
@ -176,3 +181,62 @@ def test_missing_view_with_custom_404_alt():
app, environ, start_response = setup(error_routes={404: custom_404}) app, environ, start_response = setup(error_routes={404: custom_404})
assert app(environ, start_response) == [b"Custom 404 2"] assert app(environ, start_response) == [b"Custom 404 2"]
def test_getting_nonexistent_error_view():
app, environ, start_response = setup()
assert app.get_error_route(10101).__name__ == "http500"
def test_view_gets_name():
app, environ, start_response = setup()
@app.route("/", name="asdfasdf")
def index(request): ...
assert [v for k, v in app._routes.items()][0]["name"] == "asdfasdf"
def test_view_can_be_reversed():
app, environ, start_response = setup()
@app.route("/", name="asdfasdf")
def index(request): ...
@app.route("/<int:hi>", name="qwer")
def index(request, hi): ...
assert app.reverse("asdfasdf") == "/"
assert app.reverse("asdfasdf", {"id": 1}) == "/"
assert app.reverse("asdfasdf", {"id": 1}, query={"key": "value"}) == "/?key=value"
assert app.reverse("qwer", {"hi": 1}) == "/1"
assert app.reverse("qwer", {"hi": 1}, query={"key": "value"}) == "/1?key=value"
def test_reversed_views_explode_when_missing_all_args():
app, environ, start_response = setup()
@app.route("/<int:hi>", name="qwer")
def index(request, hi): ...
with pytest.raises(SpiderwebException):
app.reverse("qwer")
def test_reversed_views_explode_when_missing_some_args():
app, environ, start_response = setup()
@app.route("/<int:hi>/<str:bye>", name="qwer")
def index(request, hi, bye): ...
with pytest.raises(SpiderwebException):
app.reverse("qwer", {"hi": 1})
def test_reverse_nonexistent_view():
app, environ, start_response = setup()
with pytest.raises(ReverseNotFound):
app.reverse("qwer")

View File

@ -46,4 +46,4 @@ def text_view(request):
def unauthorized_view(request): def unauthorized_view(request):
return HttpResponse("Unauthorized", status_code=401) return HttpResponse("Unauthorized", status_code=401)