add app.reverse() function

This commit is contained in:
Joe Kaufeld 2024-09-09 01:43:55 -04:00
parent aa41df4577
commit c9f3129b02
8 changed files with 215 additions and 61 deletions

View File

@ -149,4 +149,31 @@ app = SpiderwebRouter(
error_routes={405: http405},
)
```
As with the `routes` argument, as many routes as you'd like can be registered here without issue.
As with the `routes` argument, as many routes as you'd like can be registered here without issue.
## Finding Routes Again
> New in 1.1.0
If you need to find the path that's associated with a route (for example, for a RedirectResponse), you can use the `app.reverse()` function to find it. This function takes the name of the view and returns the path that it's associated with. For example:
```python
@app.route("/example", name="example")
def example(request):
return HttpResponse(body="Example")
path = app.reverse("example")
print(path) # -> "/example"
```
If you have a route that takes arguments, you can pass them in as keyword arguments:
```python
@app.route("/example/<int:id>", name="example")
def example(request, id):
return HttpResponse(body=f"Example with id {id}")
path = app.reverse("example", id=3)
print(path) # -> "/example/3"
```
The arguments you pass in must match what the path expects, or you'll get a `SpiderwebException`. If there's no route with that name, you'll get a `ReverseNotFound` exception instead.

View File

@ -90,3 +90,7 @@ class NoResponseError(SpiderwebException):
class StartupErrors(ExceptionGroup):
pass
class ReverseNotFound(SpiderwebException):
pass

View File

@ -101,7 +101,10 @@ class CorsMiddleware(SpiderwebMiddleware):
)
if (
self.server.cors_allow_private_network
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK.replace("-", "_")) == "true"
and request.headers.get(
ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK.replace("-", "_")
)
== "true"
):
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"

View File

@ -4,7 +4,13 @@ from typing import Callable, Any, Optional, Sequence
from spiderweb.constants import DEFAULT_ALLOWED_METHODS
from spiderweb.converters import * # noqa: F403
from spiderweb.default_views import * # noqa: F403
from spiderweb.exceptions import NotFound, ConfigError, ParseError
from spiderweb.exceptions import (
NotFound,
ConfigError,
ParseError,
SpiderwebException,
ReverseNotFound,
)
from spiderweb.response import RedirectResponse
@ -35,7 +41,7 @@ class RoutesMixin:
error_routes: dict[int, Callable]
append_slash: bool
def route(self, path, allowed_methods=None) -> Callable:
def route(self, path, allowed_methods=None, name=None) -> Callable:
"""
Decorator for adding a route to a view.
@ -49,11 +55,12 @@ class RoutesMixin:
:param path: str
:param allowed_methods: list[str]
:param name: str
:return: Callable
"""
def outer(func):
self.add_route(path, func, allowed_methods)
self.add_route(path, func, allowed_methods, name)
return func
return outer
@ -115,7 +122,11 @@ class RoutesMixin:
return re.compile(rf"^{'/'.join(parts)}$")
def add_route(
self, path: str, method: Callable, allowed_methods: None | list[str] = None
self,
path: str,
method: Callable,
allowed_methods: None | list[str] = None,
name: str = None,
):
"""Add a route to the server."""
allowed_methods = (
@ -124,24 +135,27 @@ class RoutesMixin:
or DEFAULT_ALLOWED_METHODS
)
reverse_path = re.sub(r"<(.*?):(.*?)>", r"{\2}", path) if "<" in path else path
def get_packet(func):
return {
"func": func,
"allowed_methods": allowed_methods,
"name": name,
"reverse": reverse_path,
}
if self.append_slash and not path.endswith("/"):
updated_path = path + "/"
self.check_for_route_duplicates(updated_path)
self.check_for_route_duplicates(path)
self._routes[self.convert_path(path)] = {
"func": DummyRedirectRoute(updated_path),
"allowed_methods": allowed_methods,
}
self._routes[self.convert_path(updated_path)] = {
"func": method,
"allowed_methods": allowed_methods,
}
self._routes[self.convert_path(path)] = get_packet(
DummyRedirectRoute(updated_path)
)
self._routes[self.convert_path(updated_path)] = get_packet(method)
else:
self.check_for_route_duplicates(path)
self._routes[self.convert_path(path)] = {
"func": method,
"allowed_methods": allowed_methods,
}
self._routes[self.convert_path(path)] = get_packet(method)
def add_routes(self):
for line in self.routes:
@ -156,3 +170,27 @@ class RoutesMixin:
def add_error_routes(self):
for code, func in self.error_routes.items():
self.add_error_route(int(code), func)
def reverse(
self, view_name: str, data: dict[str, Any] = None, query: dict[str, Any] = None
) -> str:
# take in a view name and return the path
for option in self._routes.values():
if option["name"] == view_name:
path = option["reverse"]
if args := re.findall(r"{(.*?)}", path):
if not data:
raise SpiderwebException(
f"Missing arguments for reverse: {args}"
)
for arg in args:
if arg not in data:
raise SpiderwebException(
f"Missing argument '{arg}' for reverse."
)
path = path.replace(f"{{{arg}}}", str(data[arg]))
if query:
path += "?" + "&".join([f"{k}={str(v)}" for k, v in query.items()])
return path
raise ReverseNotFound(f"View '{view_name}' not found.")

View File

@ -15,4 +15,4 @@ class ExplodingResponseMiddleware(SpiderwebMiddleware):
class InterruptingMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse:
return HttpResponse("Moo!")
return HttpResponse("Moo!")

View File

@ -21,7 +21,9 @@ from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import (
form_view_with_csrf,
form_csrf_exempt,
form_view_without_csrf, text_view, unauthorized_view,
form_view_without_csrf,
text_view,
unauthorized_view,
)
@ -371,11 +373,14 @@ class TestCorsMiddleware:
app, environ, start_response = setup(
**self.middleware,
cors_allow_all_origins=True,
cors_expose_headers=["accept", "content-type"]
cors_expose_headers=["accept", "content-type"],
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_EXPOSE_HEADERS] == "accept, content-type"
assert (
start_response.get_headers()[ACCESS_CONTROL_EXPOSE_HEADERS]
== "accept, content-type"
)
def test_get_dont_expose_headers(self):
app, environ, start_response = setup(
@ -419,18 +424,20 @@ class TestCorsMiddleware:
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=True,
cors_allow_all_origins=True
cors_allow_all_origins=True,
)
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.com"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] == "true"
assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] == "true"
)
def test_allow_private_network_not_added_if_enabled_and_not_requested(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=True,
cors_allow_all_origins=True
cors_allow_all_origins=True,
)
environ["HTTP_ORIGIN"] = "http://example.com"
app(environ, start_response)
@ -440,19 +447,18 @@ class TestCorsMiddleware:
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=True,
cors_allowed_origins=["http://example.com"]
cors_allowed_origins=["http://example.com"],
)
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.org"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers()
def test_allow_private_network_not_added_if_disabled_and_requested(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=False,
cors_allow_all_origins=True
cors_allow_all_origins=True,
)
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.com"
@ -465,7 +471,7 @@ class TestCorsMiddleware:
cors_allow_headers=["content-type"],
cors_allow_methods=["GET", "OPTIONS"],
cors_preflight_max_age=1002,
cors_allow_all_origins=True
cors_allow_all_origins=True,
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com"
@ -474,26 +480,24 @@ class TestCorsMiddleware:
headers = start_response.get_headers()
assert start_response.status == '200 OK'
assert start_response.status == "200 OK"
assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
assert headers[ACCESS_CONTROL_MAX_AGE] == "1002"
def test_options_no_max_age(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_headers=["content-type"],
cors_allow_methods=["GET", "OPTIONS"],
cors_preflight_max_age=0,
cors_allow_all_origins=True
cors_allow_all_origins=True,
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
headers = start_response.get_headers()
assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
@ -501,34 +505,39 @@ class TestCorsMiddleware:
def test_options_allowed_origins_with_port(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://localhost:9000"]
**self.middleware, cors_allowed_origins=["https://localhost:9000"]
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://localhost:9000"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://localhost:9000"
assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://localhost:9000"
)
def test_options_adds_origin_when_domain_found_in_allowed_regexes(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origin_regexes=[r"^https://\w+\.example\.com$"]
cors_allowed_origin_regexes=[r"^https://\w+\.example\.com$"],
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://foo.example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com"
assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://foo.example.com"
)
def test_options_adds_origin_when_domain_found_in_allowed_regexes_second(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origin_regexes=[
r"^https://\w+\.example\.org$",
r"^https://\w+\.example\.com$",
r"^https://\w+\.example\.org$",
r"^https://\w+\.example\.com$",
],
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
@ -536,7 +545,10 @@ class TestCorsMiddleware:
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com"
assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://foo.example.com"
)
def test_options_doesnt_add_origin_when_domain_not_found_in_allowed_regexes(self):
app, environ, start_response = setup(
@ -562,14 +574,9 @@ class TestCorsMiddleware:
assert start_response.status == "200 OK"
def test_options_no_headers(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_all_origins=True,
routes=[
("/", text_view)
]
**self.middleware, cors_allow_all_origins=True, routes=[("/", text_view)]
)
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
@ -580,7 +587,7 @@ class TestCorsMiddleware:
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/", text_view)]
routes=[("/", text_view)],
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com"
@ -588,7 +595,10 @@ class TestCorsMiddleware:
app(environ, start_response)
assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://example.com"
)
assert start_response.get_headers()["vary"] == "origin"
def test_allow_all_origins_options(self):
@ -596,7 +606,7 @@ class TestCorsMiddleware:
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/", text_view)]
routes=[("/", text_view)],
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
@ -605,7 +615,10 @@ class TestCorsMiddleware:
app(environ, start_response)
assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://example.com"
)
assert start_response.get_headers()["vary"] == "origin"
def test_non_200_headers_still_set(self):
@ -618,15 +631,17 @@ class TestCorsMiddleware:
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/unauthorized", unauthorized_view)]
routes=[("/unauthorized", unauthorized_view)],
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["PATH_INFO"] = "/unauthorized"
app(environ, start_response)
assert start_response.status == "401 Unauthorized"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://example.com"
)
def test_auth_view_options(self):
"""
@ -637,7 +652,7 @@ class TestCorsMiddleware:
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/unauthorized", unauthorized_view)]
routes=[("/unauthorized", unauthorized_view)],
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
@ -646,10 +661,12 @@ class TestCorsMiddleware:
app(environ, start_response)
assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://example.com"
)
assert start_response.get_headers()["content-length"] == "0"
def test_get_short_circuit(self):
"""
Test a scenario when a middleware that returns a response is run before
@ -716,6 +733,7 @@ class TestCorsMiddleware:
Just in case something crazy happens in the view or other middleware,
check that get_response doesn't fall over if `_cors_enabled` is removed
"""
def yeet(request):
del request._cors_enabled
return HttpResponse("hahaha")
@ -723,7 +741,7 @@ class TestCorsMiddleware:
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com"],
routes=[('/yeet', yeet)]
routes=[("/yeet", yeet)],
)
environ["HTTP_ORIGIN"] = "https://example.com"

View File

@ -1,8 +1,13 @@
import pytest
from spiderweb import SpiderwebRouter, ConfigError
from spiderweb import ConfigError
from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.exceptions import NoResponseError, SpiderwebNetworkException
from spiderweb.exceptions import (
NoResponseError,
SpiderwebNetworkException,
SpiderwebException,
ReverseNotFound,
)
from spiderweb.response import (
HttpResponse,
JsonResponse,
@ -176,3 +181,62 @@ def test_missing_view_with_custom_404_alt():
app, environ, start_response = setup(error_routes={404: custom_404})
assert app(environ, start_response) == [b"Custom 404 2"]
def test_getting_nonexistent_error_view():
app, environ, start_response = setup()
assert app.get_error_route(10101).__name__ == "http500"
def test_view_gets_name():
app, environ, start_response = setup()
@app.route("/", name="asdfasdf")
def index(request): ...
assert [v for k, v in app._routes.items()][0]["name"] == "asdfasdf"
def test_view_can_be_reversed():
app, environ, start_response = setup()
@app.route("/", name="asdfasdf")
def index(request): ...
@app.route("/<int:hi>", name="qwer")
def index(request, hi): ...
assert app.reverse("asdfasdf") == "/"
assert app.reverse("asdfasdf", {"id": 1}) == "/"
assert app.reverse("asdfasdf", {"id": 1}, query={"key": "value"}) == "/?key=value"
assert app.reverse("qwer", {"hi": 1}) == "/1"
assert app.reverse("qwer", {"hi": 1}, query={"key": "value"}) == "/1?key=value"
def test_reversed_views_explode_when_missing_all_args():
app, environ, start_response = setup()
@app.route("/<int:hi>", name="qwer")
def index(request, hi): ...
with pytest.raises(SpiderwebException):
app.reverse("qwer")
def test_reversed_views_explode_when_missing_some_args():
app, environ, start_response = setup()
@app.route("/<int:hi>/<str:bye>", name="qwer")
def index(request, hi, bye): ...
with pytest.raises(SpiderwebException):
app.reverse("qwer", {"hi": 1})
def test_reverse_nonexistent_view():
app, environ, start_response = setup()
with pytest.raises(ReverseNotFound):
app.reverse("qwer")

View File

@ -46,4 +46,4 @@ def text_view(request):
def unauthorized_view(request):
return HttpResponse("Unauthorized", status_code=401)
return HttpResponse("Unauthorized", status_code=401)