From c9f3129b023222cc243d3f85e26c195b91c950bf Mon Sep 17 00:00:00 2001 From: Joe Kaufeld Date: Mon, 9 Sep 2024 01:43:55 -0400 Subject: [PATCH] :sparkles: add app.reverse() function --- docs/routes.md | 29 ++++++++- spiderweb/exceptions.py | 4 ++ spiderweb/middleware/cors.py | 5 +- spiderweb/routes.py | 70 +++++++++++++++++----- spiderweb/tests/middleware.py | 2 +- spiderweb/tests/test_middleware.py | 96 ++++++++++++++++++------------ spiderweb/tests/test_responses.py | 68 ++++++++++++++++++++- spiderweb/tests/views_for_tests.py | 2 +- 8 files changed, 215 insertions(+), 61 deletions(-) diff --git a/docs/routes.md b/docs/routes.md index 469e05d..6a43761 100644 --- a/docs/routes.md +++ b/docs/routes.md @@ -149,4 +149,31 @@ app = SpiderwebRouter( error_routes={405: http405}, ) ``` -As with the `routes` argument, as many routes as you'd like can be registered here without issue. \ No newline at end of file +As with the `routes` argument, as many routes as you'd like can be registered here without issue. + +## Finding Routes Again + +> New in 1.1.0 + +If you need to find the path that's associated with a route (for example, for a RedirectResponse), you can use the `app.reverse()` function to find it. This function takes the name of the view and returns the path that it's associated with. For example: + +```python +@app.route("/example", name="example") +def example(request): + return HttpResponse(body="Example") + +path = app.reverse("example") +print(path) # -> "/example" +``` + +If you have a route that takes arguments, you can pass them in as keyword arguments: + +```python +@app.route("/example/", name="example") +def example(request, id): + return HttpResponse(body=f"Example with id {id}") + +path = app.reverse("example", id=3) +print(path) # -> "/example/3" +``` +The arguments you pass in must match what the path expects, or you'll get a `SpiderwebException`. If there's no route with that name, you'll get a `ReverseNotFound` exception instead. \ No newline at end of file diff --git a/spiderweb/exceptions.py b/spiderweb/exceptions.py index f784c23..9ded4dd 100644 --- a/spiderweb/exceptions.py +++ b/spiderweb/exceptions.py @@ -90,3 +90,7 @@ class NoResponseError(SpiderwebException): class StartupErrors(ExceptionGroup): pass + + +class ReverseNotFound(SpiderwebException): + pass diff --git a/spiderweb/middleware/cors.py b/spiderweb/middleware/cors.py index a7e8672..1c98f30 100644 --- a/spiderweb/middleware/cors.py +++ b/spiderweb/middleware/cors.py @@ -101,7 +101,10 @@ class CorsMiddleware(SpiderwebMiddleware): ) if ( self.server.cors_allow_private_network - and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK.replace("-", "_")) == "true" + and request.headers.get( + ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK.replace("-", "_") + ) + == "true" ): response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true" diff --git a/spiderweb/routes.py b/spiderweb/routes.py index 3622d8b..a22a21a 100644 --- a/spiderweb/routes.py +++ b/spiderweb/routes.py @@ -4,7 +4,13 @@ from typing import Callable, Any, Optional, Sequence from spiderweb.constants import DEFAULT_ALLOWED_METHODS from spiderweb.converters import * # noqa: F403 from spiderweb.default_views import * # noqa: F403 -from spiderweb.exceptions import NotFound, ConfigError, ParseError +from spiderweb.exceptions import ( + NotFound, + ConfigError, + ParseError, + SpiderwebException, + ReverseNotFound, +) from spiderweb.response import RedirectResponse @@ -35,7 +41,7 @@ class RoutesMixin: error_routes: dict[int, Callable] append_slash: bool - def route(self, path, allowed_methods=None) -> Callable: + def route(self, path, allowed_methods=None, name=None) -> Callable: """ Decorator for adding a route to a view. @@ -49,11 +55,12 @@ class RoutesMixin: :param path: str :param allowed_methods: list[str] + :param name: str :return: Callable """ def outer(func): - self.add_route(path, func, allowed_methods) + self.add_route(path, func, allowed_methods, name) return func return outer @@ -115,7 +122,11 @@ class RoutesMixin: return re.compile(rf"^{'/'.join(parts)}$") def add_route( - self, path: str, method: Callable, allowed_methods: None | list[str] = None + self, + path: str, + method: Callable, + allowed_methods: None | list[str] = None, + name: str = None, ): """Add a route to the server.""" allowed_methods = ( @@ -124,24 +135,27 @@ class RoutesMixin: or DEFAULT_ALLOWED_METHODS ) + reverse_path = re.sub(r"<(.*?):(.*?)>", r"{\2}", path) if "<" in path else path + + def get_packet(func): + return { + "func": func, + "allowed_methods": allowed_methods, + "name": name, + "reverse": reverse_path, + } + if self.append_slash and not path.endswith("/"): updated_path = path + "/" self.check_for_route_duplicates(updated_path) self.check_for_route_duplicates(path) - self._routes[self.convert_path(path)] = { - "func": DummyRedirectRoute(updated_path), - "allowed_methods": allowed_methods, - } - self._routes[self.convert_path(updated_path)] = { - "func": method, - "allowed_methods": allowed_methods, - } + self._routes[self.convert_path(path)] = get_packet( + DummyRedirectRoute(updated_path) + ) + self._routes[self.convert_path(updated_path)] = get_packet(method) else: self.check_for_route_duplicates(path) - self._routes[self.convert_path(path)] = { - "func": method, - "allowed_methods": allowed_methods, - } + self._routes[self.convert_path(path)] = get_packet(method) def add_routes(self): for line in self.routes: @@ -156,3 +170,27 @@ class RoutesMixin: def add_error_routes(self): for code, func in self.error_routes.items(): self.add_error_route(int(code), func) + + def reverse( + self, view_name: str, data: dict[str, Any] = None, query: dict[str, Any] = None + ) -> str: + # take in a view name and return the path + for option in self._routes.values(): + if option["name"] == view_name: + path = option["reverse"] + if args := re.findall(r"{(.*?)}", path): + if not data: + raise SpiderwebException( + f"Missing arguments for reverse: {args}" + ) + for arg in args: + if arg not in data: + raise SpiderwebException( + f"Missing argument '{arg}' for reverse." + ) + path = path.replace(f"{{{arg}}}", str(data[arg])) + + if query: + path += "?" + "&".join([f"{k}={str(v)}" for k, v in query.items()]) + return path + raise ReverseNotFound(f"View '{view_name}' not found.") diff --git a/spiderweb/tests/middleware.py b/spiderweb/tests/middleware.py index 1c8f9da..a9950a1 100644 --- a/spiderweb/tests/middleware.py +++ b/spiderweb/tests/middleware.py @@ -15,4 +15,4 @@ class ExplodingResponseMiddleware(SpiderwebMiddleware): class InterruptingMiddleware(SpiderwebMiddleware): def process_request(self, request: Request) -> HttpResponse: - return HttpResponse("Moo!") \ No newline at end of file + return HttpResponse("Moo!") diff --git a/spiderweb/tests/test_middleware.py b/spiderweb/tests/test_middleware.py index ea4b517..ac558d6 100644 --- a/spiderweb/tests/test_middleware.py +++ b/spiderweb/tests/test_middleware.py @@ -21,7 +21,9 @@ from spiderweb.tests.helpers import setup from spiderweb.tests.views_for_tests import ( form_view_with_csrf, form_csrf_exempt, - form_view_without_csrf, text_view, unauthorized_view, + form_view_without_csrf, + text_view, + unauthorized_view, ) @@ -371,11 +373,14 @@ class TestCorsMiddleware: app, environ, start_response = setup( **self.middleware, cors_allow_all_origins=True, - cors_expose_headers=["accept", "content-type"] + cors_expose_headers=["accept", "content-type"], ) environ["HTTP_ORIGIN"] = "https://example.com" app(environ, start_response) - assert start_response.get_headers()[ACCESS_CONTROL_EXPOSE_HEADERS] == "accept, content-type" + assert ( + start_response.get_headers()[ACCESS_CONTROL_EXPOSE_HEADERS] + == "accept, content-type" + ) def test_get_dont_expose_headers(self): app, environ, start_response = setup( @@ -419,18 +424,20 @@ class TestCorsMiddleware: app, environ, start_response = setup( **self.middleware, cors_allow_private_network=True, - cors_allow_all_origins=True + cors_allow_all_origins=True, ) environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true" environ["HTTP_ORIGIN"] = "http://example.com" app(environ, start_response) - assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] == "true" + assert ( + start_response.get_headers()[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] == "true" + ) def test_allow_private_network_not_added_if_enabled_and_not_requested(self): app, environ, start_response = setup( **self.middleware, cors_allow_private_network=True, - cors_allow_all_origins=True + cors_allow_all_origins=True, ) environ["HTTP_ORIGIN"] = "http://example.com" app(environ, start_response) @@ -440,19 +447,18 @@ class TestCorsMiddleware: app, environ, start_response = setup( **self.middleware, cors_allow_private_network=True, - cors_allowed_origins=["http://example.com"] + cors_allowed_origins=["http://example.com"], ) environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true" environ["HTTP_ORIGIN"] = "http://example.org" app(environ, start_response) assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers() - def test_allow_private_network_not_added_if_disabled_and_requested(self): app, environ, start_response = setup( **self.middleware, cors_allow_private_network=False, - cors_allow_all_origins=True + cors_allow_all_origins=True, ) environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true" environ["HTTP_ORIGIN"] = "http://example.com" @@ -465,7 +471,7 @@ class TestCorsMiddleware: cors_allow_headers=["content-type"], cors_allow_methods=["GET", "OPTIONS"], cors_preflight_max_age=1002, - cors_allow_all_origins=True + cors_allow_all_origins=True, ) environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ORIGIN"] = "https://example.com" @@ -474,26 +480,24 @@ class TestCorsMiddleware: headers = start_response.get_headers() - assert start_response.status == '200 OK' + assert start_response.status == "200 OK" assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type" assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS" assert headers[ACCESS_CONTROL_MAX_AGE] == "1002" - def test_options_no_max_age(self): app, environ, start_response = setup( **self.middleware, cors_allow_headers=["content-type"], cors_allow_methods=["GET", "OPTIONS"], cors_preflight_max_age=0, - cors_allow_all_origins=True + cors_allow_all_origins=True, ) environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ORIGIN"] = "https://example.com" environ["REQUEST_METHOD"] = "OPTIONS" app(environ, start_response) - headers = start_response.get_headers() assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type" assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS" @@ -501,34 +505,39 @@ class TestCorsMiddleware: def test_options_allowed_origins_with_port(self): app, environ, start_response = setup( - **self.middleware, - cors_allowed_origins=["https://localhost:9000"] + **self.middleware, cors_allowed_origins=["https://localhost:9000"] ) environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ORIGIN"] = "https://localhost:9000" environ["REQUEST_METHOD"] = "OPTIONS" app(environ, start_response) - assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://localhost:9000" + assert ( + start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] + == "https://localhost:9000" + ) def test_options_adds_origin_when_domain_found_in_allowed_regexes(self): app, environ, start_response = setup( **self.middleware, - cors_allowed_origin_regexes=[r"^https://\w+\.example\.com$"] + cors_allowed_origin_regexes=[r"^https://\w+\.example\.com$"], ) environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ORIGIN"] = "https://foo.example.com" environ["REQUEST_METHOD"] = "OPTIONS" app(environ, start_response) - assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com" + assert ( + start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] + == "https://foo.example.com" + ) def test_options_adds_origin_when_domain_found_in_allowed_regexes_second(self): app, environ, start_response = setup( **self.middleware, cors_allowed_origin_regexes=[ - r"^https://\w+\.example\.org$", - r"^https://\w+\.example\.com$", + r"^https://\w+\.example\.org$", + r"^https://\w+\.example\.com$", ], ) environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" @@ -536,7 +545,10 @@ class TestCorsMiddleware: environ["REQUEST_METHOD"] = "OPTIONS" app(environ, start_response) - assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com" + assert ( + start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] + == "https://foo.example.com" + ) def test_options_doesnt_add_origin_when_domain_not_found_in_allowed_regexes(self): app, environ, start_response = setup( @@ -562,14 +574,9 @@ class TestCorsMiddleware: assert start_response.status == "200 OK" - def test_options_no_headers(self): app, environ, start_response = setup( - **self.middleware, - cors_allow_all_origins=True, - routes=[ - ("/", text_view) - ] + **self.middleware, cors_allow_all_origins=True, routes=[("/", text_view)] ) environ["REQUEST_METHOD"] = "OPTIONS" app(environ, start_response) @@ -580,7 +587,7 @@ class TestCorsMiddleware: **self.middleware, cors_allow_credentials=True, cors_allow_all_origins=True, - routes=[("/", text_view)] + routes=[("/", text_view)], ) environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" environ["HTTP_ORIGIN"] = "https://example.com" @@ -588,7 +595,10 @@ class TestCorsMiddleware: app(environ, start_response) assert start_response.status == "200 OK" - assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert ( + start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] + == "https://example.com" + ) assert start_response.get_headers()["vary"] == "origin" def test_allow_all_origins_options(self): @@ -596,7 +606,7 @@ class TestCorsMiddleware: **self.middleware, cors_allow_credentials=True, cors_allow_all_origins=True, - routes=[("/", text_view)] + routes=[("/", text_view)], ) environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" @@ -605,7 +615,10 @@ class TestCorsMiddleware: app(environ, start_response) assert start_response.status == "200 OK" - assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert ( + start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] + == "https://example.com" + ) assert start_response.get_headers()["vary"] == "origin" def test_non_200_headers_still_set(self): @@ -618,15 +631,17 @@ class TestCorsMiddleware: **self.middleware, cors_allow_credentials=True, cors_allow_all_origins=True, - routes=[("/unauthorized", unauthorized_view)] + routes=[("/unauthorized", unauthorized_view)], ) environ["HTTP_ORIGIN"] = "https://example.com" environ["PATH_INFO"] = "/unauthorized" app(environ, start_response) - assert start_response.status == "401 Unauthorized" - assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert ( + start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] + == "https://example.com" + ) def test_auth_view_options(self): """ @@ -637,7 +652,7 @@ class TestCorsMiddleware: **self.middleware, cors_allow_credentials=True, cors_allow_all_origins=True, - routes=[("/unauthorized", unauthorized_view)] + routes=[("/unauthorized", unauthorized_view)], ) environ["HTTP_ORIGIN"] = "https://example.com" environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" @@ -646,10 +661,12 @@ class TestCorsMiddleware: app(environ, start_response) assert start_response.status == "200 OK" - assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert ( + start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] + == "https://example.com" + ) assert start_response.get_headers()["content-length"] == "0" - def test_get_short_circuit(self): """ Test a scenario when a middleware that returns a response is run before @@ -716,6 +733,7 @@ class TestCorsMiddleware: Just in case something crazy happens in the view or other middleware, check that get_response doesn't fall over if `_cors_enabled` is removed """ + def yeet(request): del request._cors_enabled return HttpResponse("hahaha") @@ -723,7 +741,7 @@ class TestCorsMiddleware: app, environ, start_response = setup( **self.middleware, cors_allowed_origins=["https://example.com"], - routes=[('/yeet', yeet)] + routes=[("/yeet", yeet)], ) environ["HTTP_ORIGIN"] = "https://example.com" diff --git a/spiderweb/tests/test_responses.py b/spiderweb/tests/test_responses.py index d27e313..f135825 100644 --- a/spiderweb/tests/test_responses.py +++ b/spiderweb/tests/test_responses.py @@ -1,8 +1,13 @@ import pytest -from spiderweb import SpiderwebRouter, ConfigError +from spiderweb import ConfigError from spiderweb.constants import DEFAULT_ENCODING -from spiderweb.exceptions import NoResponseError, SpiderwebNetworkException +from spiderweb.exceptions import ( + NoResponseError, + SpiderwebNetworkException, + SpiderwebException, + ReverseNotFound, +) from spiderweb.response import ( HttpResponse, JsonResponse, @@ -176,3 +181,62 @@ def test_missing_view_with_custom_404_alt(): app, environ, start_response = setup(error_routes={404: custom_404}) assert app(environ, start_response) == [b"Custom 404 2"] + + +def test_getting_nonexistent_error_view(): + app, environ, start_response = setup() + + assert app.get_error_route(10101).__name__ == "http500" + + +def test_view_gets_name(): + app, environ, start_response = setup() + + @app.route("/", name="asdfasdf") + def index(request): ... + + assert [v for k, v in app._routes.items()][0]["name"] == "asdfasdf" + + +def test_view_can_be_reversed(): + app, environ, start_response = setup() + + @app.route("/", name="asdfasdf") + def index(request): ... + + @app.route("/", name="qwer") + def index(request, hi): ... + + assert app.reverse("asdfasdf") == "/" + assert app.reverse("asdfasdf", {"id": 1}) == "/" + assert app.reverse("asdfasdf", {"id": 1}, query={"key": "value"}) == "/?key=value" + + assert app.reverse("qwer", {"hi": 1}) == "/1" + assert app.reverse("qwer", {"hi": 1}, query={"key": "value"}) == "/1?key=value" + + +def test_reversed_views_explode_when_missing_all_args(): + app, environ, start_response = setup() + + @app.route("/", name="qwer") + def index(request, hi): ... + + with pytest.raises(SpiderwebException): + app.reverse("qwer") + + +def test_reversed_views_explode_when_missing_some_args(): + app, environ, start_response = setup() + + @app.route("//", name="qwer") + def index(request, hi, bye): ... + + with pytest.raises(SpiderwebException): + app.reverse("qwer", {"hi": 1}) + + +def test_reverse_nonexistent_view(): + app, environ, start_response = setup() + + with pytest.raises(ReverseNotFound): + app.reverse("qwer") diff --git a/spiderweb/tests/views_for_tests.py b/spiderweb/tests/views_for_tests.py index 62e21b1..3ae8990 100644 --- a/spiderweb/tests/views_for_tests.py +++ b/spiderweb/tests/views_for_tests.py @@ -46,4 +46,4 @@ def text_view(request): def unauthorized_view(request): - return HttpResponse("Unauthorized", status_code=401) \ No newline at end of file + return HttpResponse("Unauthorized", status_code=401)