From 675743bf8d0cc80472f2e9319409ad4935c77b49 Mon Sep 17 00:00:00 2001 From: Joe Kaufeld Date: Sun, 18 Aug 2024 17:45:38 -0400 Subject: [PATCH] :sparkles: convert to wsgi with gunicorn support --- example.py | 15 ++- example2.py | 78 ++++++++++++ poetry.lock | 61 ++++++--- pyproject.toml | 5 +- spiderweb/__init__.py | 3 +- spiderweb/constants.py | 1 + spiderweb/decorators.py | 4 + spiderweb/default_responses.py | 19 --- spiderweb/default_views.py | 33 +++++ spiderweb/exceptions.py | 1 + spiderweb/local_server.py | 56 +++++++++ spiderweb/main.py | 27 +++- spiderweb/middleware/__init__.py | 45 +++++++ spiderweb/middleware/csrf.py | 8 +- spiderweb/request.py | 63 ++++++++-- spiderweb/response.py | 7 +- spiderweb/routes.py | 158 +++++++++++++++++++++++ spiderweb/secrets.py | 26 ++++ spiderweb/utils.py | 26 ++++ spiderweb/wsgi_main.py | 208 +++++++++++++++++++++++++++++++ templates/form.html | 2 +- 21 files changed, 780 insertions(+), 66 deletions(-) create mode 100644 example2.py create mode 100644 spiderweb/decorators.py delete mode 100644 spiderweb/default_responses.py create mode 100644 spiderweb/default_views.py create mode 100644 spiderweb/local_server.py create mode 100644 spiderweb/routes.py create mode 100644 spiderweb/secrets.py create mode 100644 spiderweb/wsgi_main.py diff --git a/example.py b/example.py index 351495b..a7e593e 100644 --- a/example.py +++ b/example.py @@ -1,4 +1,5 @@ -from spiderweb import WebServer +from spiderweb.decorators import csrf_exempt +from spiderweb.wsgi_main import SpiderwebRouter from spiderweb.exceptions import ServerError from spiderweb.response import ( HttpResponse, @@ -8,7 +9,7 @@ from spiderweb.response import ( ) -app = WebServer( +app = SpiderwebRouter( templates_dirs=["templates"], middleware=[ "spiderweb.middleware.csrf.CSRFMiddleware", @@ -23,7 +24,6 @@ app = WebServer( @app.route("/") def index(request): - print(app.BASE_DIR) return TemplateResponse(request, "test.html", context={"value": "TEST!"}) @@ -59,7 +59,8 @@ def http405(request) -> HttpResponse: return HttpResponse(body="Method not allowed", status_code=405) -@app.route("/form", allowed_methods=["POST"]) +@csrf_exempt +@app.route("/form", allowed_methods=["GET", "POST"]) def form(request): if request.method == "POST": return JsonResponse(data=request.POST) @@ -70,4 +71,8 @@ def form(request): if __name__ == "__main__": # can also add routes like this: # app.add_route("/", index) - app.start(blocking=True) + # + # If gunicorn is installed, you can run this file directly through gunicorn with + # `gunicorn --workers=2 "example:app"` -- the biggest thing here is that all + # configuration must be done using decorators or top level in the file. + app.start() diff --git a/example2.py b/example2.py new file mode 100644 index 0000000..59181a3 --- /dev/null +++ b/example2.py @@ -0,0 +1,78 @@ +from spiderweb.decorators import csrf_exempt +from spiderweb.wsgi_main import SpiderwebRouter +from spiderweb.exceptions import ServerError +from spiderweb.response import ( + HttpResponse, + JsonResponse, + TemplateResponse, + RedirectResponse, +) + + +def index(request): + return TemplateResponse(request, "test.html", context={"value": "TEST!"}) + + +def redirect(request): + return RedirectResponse("/") + + +def json(request): + return JsonResponse(data={"key": "value"}) + + +def error(request): + raise ServerError + + +def middleware(request): + return HttpResponse( + body="We'll never hit this because it's redirected in middleware" + ) + + +def example(request, id): + return HttpResponse(body=f"Example with id {id}") + + +def http405(request) -> HttpResponse: + return HttpResponse(body="Method not allowed", status_code=405) + + +def form(request): + if request.method == "POST": + return JsonResponse(data=request.POST) + else: + return TemplateResponse(request, "form.html") + + +app = SpiderwebRouter( + templates_dirs=["templates"], + middleware=[ + "spiderweb.middleware.csrf.CSRFMiddleware", + "example_middleware.TestMiddleware", + "example_middleware.RedirectMiddleware", + "example_middleware.ExplodingMiddleware", + ], + staticfiles_dirs=["static_files"], + routes=[ + ["/", index], + ["/redirect", redirect], + ["/json", json], + ["/error", error], + ["/middleware", middleware], + ["/example/", example], + ["/form", form, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}], + ], + error_routes={"405": http405}, +) + + +if __name__ == "__main__": + # can also add routes like this: + # app.add_route("/", index) + # + # If gunicorn is installed, you can run this file directly through gunicorn with + # `gunicorn --workers=2 "example:app"` -- the biggest thing here is that all + # configuration must be done using decorators or top level in the file. + app.start() diff --git a/poetry.lock b/poetry.lock index ad84957..a8fdd73 100644 --- a/poetry.lock +++ b/poetry.lock @@ -197,6 +197,27 @@ ssh = ["bcrypt (>=3.1.5)"] test = ["certifi", "cryptography-vectors (==43.0.0)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] test-randomorder = ["pytest-randomly"] +[[package]] +name = "gunicorn" +version = "23.0.0" +description = "WSGI HTTP Server for UNIX" +optional = false +python-versions = ">=3.7" +files = [ + {file = "gunicorn-23.0.0-py3-none-any.whl", hash = "sha256:ec400d38950de4dfd418cff8328b2c8faed0edb0d517d3394e457c317908ca4d"}, + {file = "gunicorn-23.0.0.tar.gz", hash = "sha256:f014447a0101dc57e294f6c18ca6b40227a4c90e9bdb586042628030cba004ec"}, +] + +[package.dependencies] +packaging = "*" + +[package.extras] +eventlet = ["eventlet (>=0.24.1,!=0.36.0)"] +gevent = ["gevent (>=1.4.0)"] +setproctitle = ["setproctitle"] +testing = ["coverage", "eventlet", "gevent", "pytest", "pytest-cov"] +tornado = ["tornado (>=0.2)"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -401,32 +422,32 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments [[package]] name = "ruff" -version = "0.5.5" +version = "0.5.7" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.5.5-py3-none-linux_armv6l.whl", hash = "sha256:605d589ec35d1da9213a9d4d7e7a9c761d90bba78fc8790d1c5e65026c1b9eaf"}, - {file = "ruff-0.5.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00817603822a3e42b80f7c3298c8269e09f889ee94640cd1fc7f9329788d7bf8"}, - {file = "ruff-0.5.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:187a60f555e9f865a2ff2c6984b9afeffa7158ba6e1eab56cb830404c942b0f3"}, - {file = "ruff-0.5.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe26fc46fa8c6e0ae3f47ddccfbb136253c831c3289bba044befe68f467bfb16"}, - {file = "ruff-0.5.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4ad25dd9c5faac95c8e9efb13e15803cd8bbf7f4600645a60ffe17c73f60779b"}, - {file = "ruff-0.5.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f70737c157d7edf749bcb952d13854e8f745cec695a01bdc6e29c29c288fc36e"}, - {file = "ruff-0.5.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:cfd7de17cef6ab559e9f5ab859f0d3296393bc78f69030967ca4d87a541b97a0"}, - {file = "ruff-0.5.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a09b43e02f76ac0145f86a08e045e2ea452066f7ba064fd6b0cdccb486f7c3e7"}, - {file = "ruff-0.5.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0b856cb19c60cd40198be5d8d4b556228e3dcd545b4f423d1ad812bfdca5884"}, - {file = "ruff-0.5.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3687d002f911e8a5faf977e619a034d159a8373514a587249cc00f211c67a091"}, - {file = "ruff-0.5.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ac9dc814e510436e30d0ba535f435a7f3dc97f895f844f5b3f347ec8c228a523"}, - {file = "ruff-0.5.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:af9bdf6c389b5add40d89b201425b531e0a5cceb3cfdcc69f04d3d531c6be74f"}, - {file = "ruff-0.5.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:d40a8533ed545390ef8315b8e25c4bb85739b90bd0f3fe1280a29ae364cc55d8"}, - {file = "ruff-0.5.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cab904683bf9e2ecbbe9ff235bfe056f0eba754d0168ad5407832928d579e7ab"}, - {file = "ruff-0.5.5-py3-none-win32.whl", hash = "sha256:696f18463b47a94575db635ebb4c178188645636f05e934fdf361b74edf1bb2d"}, - {file = "ruff-0.5.5-py3-none-win_amd64.whl", hash = "sha256:50f36d77f52d4c9c2f1361ccbfbd09099a1b2ea5d2b2222c586ab08885cf3445"}, - {file = "ruff-0.5.5-py3-none-win_arm64.whl", hash = "sha256:3191317d967af701f1b73a31ed5788795936e423b7acce82a2b63e26eb3e89d6"}, - {file = "ruff-0.5.5.tar.gz", hash = "sha256:cc5516bdb4858d972fbc31d246bdb390eab8df1a26e2353be2dbc0c2d7f5421a"}, + {file = "ruff-0.5.7-py3-none-linux_armv6l.whl", hash = "sha256:548992d342fc404ee2e15a242cdbea4f8e39a52f2e7752d0e4cbe88d2d2f416a"}, + {file = "ruff-0.5.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00cc8872331055ee017c4f1071a8a31ca0809ccc0657da1d154a1d2abac5c0be"}, + {file = "ruff-0.5.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf3d86a1fdac1aec8a3417a63587d93f906c678bb9ed0b796da7b59c1114a1e"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a01c34400097b06cf8a6e61b35d6d456d5bd1ae6961542de18ec81eaf33b4cb8"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcc8054f1a717e2213500edaddcf1dbb0abad40d98e1bd9d0ad364f75c763eea"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f70284e73f36558ef51602254451e50dd6cc479f8b6f8413a95fcb5db4a55fc"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:a78ad870ae3c460394fc95437d43deb5c04b5c29297815a2a1de028903f19692"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ccd078c66a8e419475174bfe60a69adb36ce04f8d4e91b006f1329d5cd44bcf"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e31c9bad4ebf8fdb77b59cae75814440731060a09a0e0077d559a556453acbb"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d796327eed8e168164346b769dd9a27a70e0298d667b4ecee6877ce8095ec8e"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a09ea2c3f7778cc635e7f6edf57d566a8ee8f485f3c4454db7771efb692c499"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a36d8dcf55b3a3bc353270d544fb170d75d2dff41eba5df57b4e0b67a95bb64e"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9369c218f789eefbd1b8d82a8cf25017b523ac47d96b2f531eba73770971c9e5"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b88ca3db7eb377eb24fb7c82840546fb7acef75af4a74bd36e9ceb37a890257e"}, + {file = "ruff-0.5.7-py3-none-win32.whl", hash = "sha256:33d61fc0e902198a3e55719f4be6b375b28f860b09c281e4bdbf783c0566576a"}, + {file = "ruff-0.5.7-py3-none-win_amd64.whl", hash = "sha256:083bbcbe6fadb93cd86709037acc510f86eed5a314203079df174c40bbbca6b3"}, + {file = "ruff-0.5.7-py3-none-win_arm64.whl", hash = "sha256:2dca26154ff9571995107221d0aeaad0e75a77b5a682d6236cf89a58c70b76f4"}, + {file = "ruff-0.5.7.tar.gz", hash = "sha256:8dfc0a458797f5d9fb622dd0efc52d796f23f0a1493a9527f4e49a550ae9a7e5"}, ] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "96cb529cc8a301c9ac0920582fb7ccb26bc789b0c5ccbc4135fc2d8d6936bb75" +content-hash = "e74f9bbb0dad671b46a8e80cbc5776de00b043b8e93ba102eab9be8b4aef2fac" diff --git a/pyproject.toml b/pyproject.toml index 2c5a640..250b235 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "spiderweb" -version = "0.7.0" +version = "0.8.0" description = "A small web framework, just big enough to hold your average spider." authors = ["Joe Kaufeld "] readme = "README.md" @@ -15,7 +15,10 @@ cryptography = "^43.0.0" ruff = "^0.5.5" pytest = "^8.3.2" black = "^24.8.0" +gunicorn = "^23.0.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.poetry_bumpversion.file."spiderweb/constants.py"] diff --git a/spiderweb/__init__.py b/spiderweb/__init__.py index eff7bf0..762f1bd 100644 --- a/spiderweb/__init__.py +++ b/spiderweb/__init__.py @@ -1,2 +1,3 @@ -from spiderweb.main import route, WebServer # noqa: F401 +from spiderweb.wsgi_main import SpiderwebRouter # noqa: F401 from spiderweb.middleware import * # noqa: F401, F403 +from spiderweb.constants import __version__ diff --git a/spiderweb/constants.py b/spiderweb/constants.py index 8f1a223..a2d6a18 100644 --- a/spiderweb/constants.py +++ b/spiderweb/constants.py @@ -1,2 +1,3 @@ DEFAULT_ALLOWED_METHODS = ["GET"] DEFAULT_ENCODING = "ISO-8859-1" +__version__ = "0.8.0" diff --git a/spiderweb/decorators.py b/spiderweb/decorators.py new file mode 100644 index 0000000..056b618 --- /dev/null +++ b/spiderweb/decorators.py @@ -0,0 +1,4 @@ +def csrf_exempt(func): + """Mark a view as not requiring CSRF verification on POST requests.""" + func.csrf_exempt = True + return func diff --git a/spiderweb/default_responses.py b/spiderweb/default_responses.py deleted file mode 100644 index 017ed37..0000000 --- a/spiderweb/default_responses.py +++ /dev/null @@ -1,19 +0,0 @@ -from spiderweb.response import JsonResponse - - -def http403(request): - return JsonResponse(data={"error": "Forbidden"}, status_code=403) - - -def http404(request): - return JsonResponse( - data={"error": f"Route {request.url} not found"}, status_code=404 - ) - - -def http405(request): - return JsonResponse(data={"error": "Method not allowed"}, status_code=405) - - -def http500(request): - return JsonResponse(data={"error": "Internal server error"}, status_code=500) diff --git a/spiderweb/default_views.py b/spiderweb/default_views.py new file mode 100644 index 0000000..1a26917 --- /dev/null +++ b/spiderweb/default_views.py @@ -0,0 +1,33 @@ +import os + +from spiderweb.exceptions import NotFound +from spiderweb.response import JsonResponse, FileResponse +from spiderweb.utils import is_safe_path + + +def http403(request): + return JsonResponse(data={"error": "Forbidden"}, status_code=403) + + +def http404(request): + return JsonResponse( + data={"error": f"Route {request.url} not found"}, status_code=404 + ) + + +def http405(request): + return JsonResponse(data={"error": "Method not allowed"}, status_code=405) + + +def http500(request): + return JsonResponse(data={"error": "Internal server error"}, status_code=500) + + +def send_file(request, filename: str) -> FileResponse: + for folder in request.server.staticfiles_dirs: + requested_path = request.server.BASE_DIR / folder / filename + if os.path.exists(requested_path): + if not is_safe_path(requested_path): + raise NotFound + return FileResponse(filename=requested_path) + raise NotFound diff --git a/spiderweb/exceptions.py b/spiderweb/exceptions.py index 3acc7f1..4969560 100644 --- a/spiderweb/exceptions.py +++ b/spiderweb/exceptions.py @@ -15,6 +15,7 @@ class SpiderwebNetworkException(SpiderwebException): def __str__(self): return f"{self.__class__.__name__}({self.code}, {self.msg})" + class APIError(SpiderwebNetworkException): pass diff --git a/spiderweb/local_server.py b/spiderweb/local_server.py new file mode 100644 index 0000000..cfca6e9 --- /dev/null +++ b/spiderweb/local_server.py @@ -0,0 +1,56 @@ +import signal +import threading +import time +from logging import Logger +from threading import Thread +from typing import NoReturn, Callable, Any +from wsgiref.simple_server import WSGIServer, WSGIRequestHandler + +from spiderweb.constants import __version__ + + +class SpiderwebRequestHandler(WSGIRequestHandler): + server_version = "spiderweb/" + __version__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class LocalServerMiddleware: + """Cannot be called on its own. Requires context of SpiderwebRouter.""" + + addr: str + port: int + log: Logger + _server: WSGIServer + _thread: Thread + + def create_server(self): + server = WSGIServer((self.addr, self.port), SpiderwebRequestHandler) + server.set_app(self) + return server + + def signal_handler(self, sig, frame) -> NoReturn: + self.log.warning("Shutting down!") + self.stop() + + def start(self, blocking=False): + signal.signal(signal.SIGINT, self.signal_handler) + self.log.info(f"Starting server on {self.addr}:{self.port}") + self.log.info("Press CTRL+C to stop the server.") + self._server = self.create_server() + self._thread = threading.Thread(target=self._server.serve_forever) + self._thread.start() + + if not blocking: + return self._thread + else: + while self._thread.is_alive(): + try: + time.sleep(0.2) + except KeyboardInterrupt: + self.stop() + + def stop(self): + self._server.shutdown() + self._server.socket.close() diff --git a/spiderweb/main.py b/spiderweb/main.py index 6039490..97d3c84 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -14,6 +14,7 @@ import urllib.parse as urlparse import threading import logging from typing import Callable, Any, NoReturn +from wsgiref.simple_server import WSGIRequestHandler, WSGIServer from cryptography.fernet import Fernet from jinja2 import Environment, FileSystemLoader @@ -36,7 +37,8 @@ from spiderweb.response import ( HttpResponse, JsonResponse, TemplateResponse, - RedirectResponse, FileResponse, + RedirectResponse, + FileResponse, ) from spiderweb.utils import import_by_string, is_safe_path @@ -149,7 +151,9 @@ class WebServer(HTTPServer): for static_dir in self.staticfiles_dirs: static_dir = pathlib.Path(static_dir) if not pathlib.Path(self.BASE_DIR / static_dir).exists(): - log.error(f"Static files directory '{str(static_dir)}' does not exist.") + log.error( + f"Static files directory '{str(static_dir)}' does not exist." + ) raise ConfigError self.add_route(r"/static/", send_file) @@ -190,12 +194,16 @@ class WebServer(HTTPServer): if self.convert_path(path) in self.handler_class._routes: raise ConfigError(f"Route '{path}' already exists.") - def add_route(self, path: str, method: Callable, allowed_methods: None|list[str] = None): + def add_route( + self, path: str, method: Callable, allowed_methods: None | list[str] = None + ): """Add a route to the server.""" if not hasattr(self.handler_class, "_routes"): setattr(self.handler_class, "_routes", {}) - allowed_methods = allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS + allowed_methods = ( + allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS + ) if self.append_slash and not path.endswith("/"): updated_path = path + "/" @@ -366,7 +374,12 @@ class RequestHandler(BaseHTTPRequestHandler): return http500 return view - def _fire_response(self, status: int=200, content: str=None, headers: dict[str, str | int]=None): + def _fire_response( + self, + status: int = 200, + content: str = None, + headers: dict[str, str | int] = None, + ): self.send_response(status) self.send_header("Content-Length", str(len(content))) if headers: @@ -377,7 +390,9 @@ class RequestHandler(BaseHTTPRequestHandler): def fire_response(self, request: Request, resp: HttpResponse): try: - self._fire_response(status=resp.status_code, content=resp.render(), headers=resp.headers) + self._fire_response( + status=resp.status_code, content=resp.render(), headers=resp.headers + ) except APIError: raise except ConnectionAbortedError as e: diff --git a/spiderweb/middleware/__init__.py b/spiderweb/middleware/__init__.py index 35964e5..89a8ac2 100644 --- a/spiderweb/middleware/__init__.py +++ b/spiderweb/middleware/__init__.py @@ -1,2 +1,47 @@ +from typing import Callable, ClassVar + from .base import SpiderwebMiddleware from .csrf import CSRFMiddleware +from ..exceptions import ConfigError, UnusedMiddleware +from ..request import Request +from ..response import HttpResponse +from ..utils import import_by_string + + +class MiddlewareMiddleware: + """Cannot be called on its own. Requires context of SpiderwebRouter.""" + + middleware: list[ClassVar] + fire_response: Callable + + def init_middleware(self): + if self.middleware: + middleware_by_reference = [] + for m in self.middleware: + try: + middleware_by_reference.append(import_by_string(m)(server=self)) + except ImportError: + raise ConfigError(f"Middleware '{m}' not found.") + self.middleware = middleware_by_reference + + def process_request_middleware(self, request: Request) -> None | bool: + for middleware in self.middleware: + try: + resp = middleware.process_request(request) + except UnusedMiddleware: + self.middleware.remove(middleware) + continue + if resp: + self.process_response_middleware(request, resp) + self.fire_response(request, resp) + return True # abort further processing + + def process_response_middleware( + self, request: Request, response: HttpResponse + ) -> None: + for middleware in self.middleware: + try: + middleware.process_response(request, response) + except UnusedMiddleware: + self.middleware.remove(middleware) + continue diff --git a/spiderweb/middleware/csrf.py b/spiderweb/middleware/csrf.py index 0c502f7..fd95e42 100644 --- a/spiderweb/middleware/csrf.py +++ b/spiderweb/middleware/csrf.py @@ -11,6 +11,9 @@ class CSRFMiddleware(SpiderwebMiddleware): def process_request(self, request: Request) -> HttpResponse | None: if request.method == "POST": + if hasattr(request.handler, "csrf_exempt"): + if request.handler.csrf_exempt is True: + return csrf_token = ( request.headers.get("X-CSRF-TOKEN") or request.GET.get("csrf_token") @@ -26,7 +29,10 @@ class CSRFMiddleware(SpiderwebMiddleware): token = self.get_csrf_token() # do we need it in both places? response.headers["X-CSRF-TOKEN"] = token - request.csrf_token = token + response.context |= { + "csrf_token": f"""""", + "raw_csrf_token": token, # in case they want to format it themselves + } def get_csrf_token(self): return self.server.encrypt(str(datetime.now().isoformat())).decode( diff --git a/spiderweb/request.py b/spiderweb/request.py index 5215efb..930dea9 100644 --- a/spiderweb/request.py +++ b/spiderweb/request.py @@ -1,28 +1,73 @@ import json from urllib.parse import urlparse +from spiderweb.constants import DEFAULT_ENCODING + class Request: def __init__( self, + environ=None, content=None, - body=None, - method=None, headers=None, path=None, - query_params=None, - server=None + server=None, + handler=None, ): + self.environ = environ self.content: str = content - self.body: str = body - self.method: str = method - self.headers: dict[str] = headers - self.path: str = path + self.method: str = environ["REQUEST_METHOD"] + self.headers: dict[str, str] = headers if headers else {} + self.path: str = path if path else environ["PATH_INFO"] self.url = urlparse(path) - self.query_params = query_params + self.query_params = [] self.server = server + self.handler = handler # the view function that will be called self.GET = {} self.POST = {} + self.META = {} + + self.populate_headers() + self.populate_meta() + + content_length = int(self.headers.get("CONTENT_LENGTH") or 0) + if content_length: + self.content = ( + self.environ["wsgi.input"].read(content_length).decode(DEFAULT_ENCODING) + ) + + def populate_headers(self) -> None: + self.headers |= { + "CONTENT_TYPE": self.environ.get("CONTENT_TYPE"), + "CONTENT_LENGTH": self.environ.get("CONTENT_LENGTH"), + } + for k, v in self.environ.items(): + if k.startswith("HTTP_"): + self.headers[k] = v + + def populate_meta(self) -> None: + fields = [ + "SERVER_PROTOCOL", + "SERVER_SOFTWARE", + "REQUEST_METHOD", + "PATH_INFO", + "QUERY_STRING", + "REMOTE_HOST", + "REMOTE_ADDR", + "SERVER_NAME", + "GATEWAY_INTERFACE", + "SERVER_PORT", + "CONTENT_LENGTH", + "SCRIPT_NAME", + ] + for f in fields: + self.META[f] = self.environ.get(f) def json(self): return json.loads(self.content) + + def is_form_request(self) -> bool: + return ( + "CONTENT_TYPE" in self.headers + and self.headers["CONTENT_TYPE"] == "application/x-www-form-urlencoded" + ) diff --git a/spiderweb/response.py b/spiderweb/response.py index 01cf493..06f6b19 100644 --- a/spiderweb/response.py +++ b/spiderweb/response.py @@ -2,6 +2,7 @@ import datetime import json from typing import Any import mimetypes +from wsgiref.util import FileWrapper from spiderweb.constants import DEFAULT_ENCODING from spiderweb.exceptions import GeneralException @@ -46,9 +47,9 @@ class FileResponse(HttpResponse): self.content_type = mimetypes.guess_type(self.filename)[0] self.headers["Content-Type"] = self.content_type - def render(self) -> str: - with open(self.filename, 'rb') as f: - self.body = f.read().decode(DEFAULT_ENCODING) + def render(self) -> list[bytes]: + with open(self.filename, "rb") as f: + self.body = [chunk for chunk in FileWrapper(f)] return self.body diff --git a/spiderweb/routes.py b/spiderweb/routes.py new file mode 100644 index 0000000..407eadc --- /dev/null +++ b/spiderweb/routes.py @@ -0,0 +1,158 @@ +import re +from typing import Callable, Any + +from spiderweb.constants import DEFAULT_ALLOWED_METHODS +from spiderweb.converters import * # noqa: F403 +from spiderweb.default_views import * # noqa: F403 +from spiderweb.exceptions import NotFound, ConfigError, ParseError +from spiderweb.response import RedirectResponse + + +def convert_match_to_dict(match: dict): + """Convert a match object to a dict with the proper converted types for each match.""" + return { + k.split("__")[0]: globals()[k.split("__")[1]]().to_python(v) + for k, v in match.items() + } + + +class DummyRedirectRoute: + def __init__(self, location): + self.location = location + + def __call__(self, request): + return RedirectResponse(self.location) + + +class RoutesMiddleware: + """Cannot be called on its own. Requires context of SpiderwebRouter.""" + + # ones that start with underscores are the compiled versions, non-underscores + # are the user-supplied versions + _routes: dict + routes: list[list[str | Callable | dict]] + _error_routes: dict + error_routes: dict[str, Callable] + append_slash: bool + + def route(self, path, allowed_methods=None) -> Callable: + """ + Decorator for adding a route to a view. + + Usage: + + app = WebServer() + + @app.route("/hello") + def index(request): + return HttpResponse(content="Hello, world!") + + :param path: str + :param allowed_methods: list[str] + :return: Callable + """ + + def outer(func): + self.add_route(path, func, allowed_methods) + return func + + return outer + + def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]: + for option in self._routes.keys(): + if match_data := option.match(path): + return ( + self._routes[option]["func"], + convert_match_to_dict(match_data.groupdict()), + self._routes[option]["allowed_methods"], + ) + raise NotFound() + + def add_error_route(self, code: int, method: Callable): + """Add an error route to the server.""" + if code not in self._error_routes: + self._error_routes[code] = method + else: + raise ConfigError(f"Error route for code {code} already exists.") + + def error(self, code: int) -> Callable: + def outer(func): + self.add_error_route(code, func) + return func + + return outer + + def get_error_route(self, code: int) -> Callable: + view = self._error_routes.get(code) or globals().get(f"http{code}") + if not view: + return http500 + return view + + def check_for_route_duplicates(self, path: str): + if self.convert_path(path) in self._routes: + raise ConfigError(f"Route '{path}' already exists.") + + def convert_path(self, path: str): + """Convert a path to a regex.""" + parts = path.split("/") + for i, part in enumerate(parts): + if part.startswith("<") and part.endswith(">"): + name = part[1:-1] + if "__" in name: + raise ConfigError( + f"Cannot use `__` (double underscore) in path variable." + f" Please fix '{name}'." + ) + if ":" in name: + converter, name = name.split(":") + try: + converter = globals()[converter.title() + "Converter"] + except KeyError: + raise ParseError(f"Unknown converter {converter}") + else: + converter = StrConverter # noqa: F405 + parts[i] = rf"(?P<{name}__{str(converter.__name__)}>{converter.regex})" + return re.compile(rf"^{'/'.join(parts)}$") + + def add_route( + self, path: str, method: Callable, allowed_methods: None | list[str] = None + ): + """Add a route to the server.""" + allowed_methods = ( + getattr(method, "allowed_methods", None) + or allowed_methods + or DEFAULT_ALLOWED_METHODS + ) + + if self.append_slash and not path.endswith("/"): + updated_path = path + "/" + self.check_for_route_duplicates(updated_path) + self.check_for_route_duplicates(path) + self._routes[self.convert_path(path)] = { + "func": DummyRedirectRoute(updated_path), + "allowed_methods": allowed_methods, + } + self._routes[self.convert_path(updated_path)] = { + "func": method, + "allowed_methods": allowed_methods, + } + else: + self.check_for_route_duplicates(path) + self._routes[self.convert_path(path)] = { + "func": method, + "allowed_methods": allowed_methods, + } + + def add_routes(self): + for line in self.routes: + if len(line) == 3: + path, func, kwargs = line + for k, v in kwargs.items(): + setattr(func, k, v) + else: + path, func = line + self.add_route(path, func) + + def add_error_routes(self): + for code, func in self.error_routes.items(): + self.add_error_route(int(code), func) diff --git a/spiderweb/secrets.py b/spiderweb/secrets.py new file mode 100644 index 0000000..b0e26ae --- /dev/null +++ b/spiderweb/secrets.py @@ -0,0 +1,26 @@ +from cryptography.fernet import Fernet + +from spiderweb.constants import DEFAULT_ENCODING + + +class FernetMiddleware: + """Cannot be called on its own. Requires context of SpiderwebRouter.""" + + fernet: Fernet + secret_key: str + + def init_fernet(self): + self.fernet = Fernet(self.secret_key) + + def generate_key(self): + return Fernet.generate_key() + + def encrypt(self, data: str): + return self.fernet.encrypt(bytes(data, DEFAULT_ENCODING)) + + def decrypt(self, data: str): + if isinstance(data, bytes): + return self.fernet.decrypt(data).decode(DEFAULT_ENCODING) + return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode( + DEFAULT_ENCODING + ) diff --git a/spiderweb/utils.py b/spiderweb/utils.py index 559b231..7ddc47a 100644 --- a/spiderweb/utils.py +++ b/spiderweb/utils.py @@ -1,3 +1,9 @@ +from http import HTTPStatus +from typing import Optional + +from spiderweb.request import Request + + def import_by_string(name): # https://stackoverflow.com/a/547867 components = name.split(".") @@ -10,3 +16,23 @@ def import_by_string(name): def is_safe_path(path: str) -> bool: # this cannot possibly catch all issues return not ".." in str(path) + + +def get_http_status_by_code(code: int) -> Optional[str]: + """ + Get the full HTTP status code required by WSGI by code. + + Example: + >>> get_http_status_by_code(200) + '200 OK' + """ + resp = HTTPStatus(code) + if resp: + return f"{resp.value} {resp.phrase}" + + +def is_form_request(request: Request) -> bool: + return ( + "Content-Type" in request.headers + and request.headers["Content-Type"] == "application/x-www-form-urlencoded" + ) diff --git a/spiderweb/wsgi_main.py b/spiderweb/wsgi_main.py new file mode 100644 index 0000000..371116f --- /dev/null +++ b/spiderweb/wsgi_main.py @@ -0,0 +1,208 @@ +import inspect +import logging +import pathlib +import traceback +import urllib.parse as urlparse +from threading import Thread +from typing import Optional, Callable +from wsgiref.simple_server import WSGIServer + +from jinja2 import Environment, FileSystemLoader + +from spiderweb.middleware import MiddlewareMiddleware +from spiderweb.constants import DEFAULT_ENCODING, DEFAULT_ALLOWED_METHODS +from spiderweb.default_views import * # noqa: F403 +from spiderweb.exceptions import ( + ConfigError, + NotFound, + APIError, + NoResponseError, + SpiderwebNetworkException, +) +from spiderweb.local_server import LocalServerMiddleware +from spiderweb.request import Request +from spiderweb.response import HttpResponse, TemplateResponse +from spiderweb.routes import RoutesMiddleware +from spiderweb.secrets import FernetMiddleware +from spiderweb.utils import get_http_status_by_code + +file_logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +class SpiderwebRouter( + LocalServerMiddleware, MiddlewareMiddleware, RoutesMiddleware, FernetMiddleware +): + def __init__( + self, + addr: str = None, + port: int = None, + templates_dirs: list[str] = None, + middleware: list[str] = None, + append_slash: bool = False, + staticfiles_dirs: list[str] = None, + routes: list[list[str | Callable | dict]] = None, + error_routes: dict[str, Callable] = None, + secret_key: str = None, + log=None, + ): + self._routes = {} + self.routes = routes + self._error_routes = {} + self.error_routes = error_routes + self.addr = addr if addr else "localhost" + self.port = port if port else 8000 + self.server_address = (self.addr, self.port) + self.append_slash = append_slash + self.templates_dirs = templates_dirs + self.staticfiles_dirs = staticfiles_dirs + self.middleware = middleware if middleware else [] + self.secret_key = secret_key if secret_key else self.generate_key() + + self.DEFAULT_ENCODING = DEFAULT_ENCODING + self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS + self.log = log if log else file_logger + + # for using .start() and .stop() + self._thread: Optional[Thread] = None + self._server: Optional[WSGIServer] = None + self.BASE_DIR = self.get_caller_filepath() + + self.init_fernet() + self.init_middleware() + + if self.routes: + self.add_routes() + + if self.templates_dirs: + self.env = Environment(loader=FileSystemLoader(self.templates_dirs)) + else: + self.env = None + + if self.staticfiles_dirs: + for static_dir in self.staticfiles_dirs: + static_dir = pathlib.Path(static_dir) + if not pathlib.Path(self.BASE_DIR / static_dir).exists(): + log.error( + f"Static files directory '{str(static_dir)}' does not exist." + ) + raise ConfigError + self.add_route(r"/static/", send_file) + + def fire_response(self, start_response, request: Request, resp: HttpResponse): + try: + status = get_http_status_by_code(resp.status_code) + headers = list(resp.headers.items()) + + start_response(status, headers) + + rendered_output = resp.render() + if not isinstance(rendered_output, list): + rendered_output = [rendered_output] + + encoded_resp = [ + chunk.encode(DEFAULT_ENCODING) if isinstance(chunk, str) else chunk + for chunk in rendered_output + ] + + return encoded_resp + except APIError: + raise + except ConnectionAbortedError as e: + self.log.error(f"GET {request.path} : {e}") + except Exception: + self.log.error(traceback.format_exc()) + return self.fire_response( + start_response, request, self.get_error_route(500)(request) + ) + + def get_caller_filepath(self): + """Figure out who called us and return their path.""" + stack = inspect.stack() + caller_frame = stack[1] + return pathlib.Path(caller_frame.filename).parent.parent + + def get_request(self, environ): + return Request( + content="", + environ=environ, + server=self, + ) + + def send_error_response( + self, start_response, request: Request, e: SpiderwebNetworkException + ): + try: + status = get_http_status_by_code(500) + headers = [("Content-type", "text/plain; charset=utf-8")] + + start_response(status, headers) + + resp = [ + f"Something went wrong.\n\nCode: {e.code}\n\nMsg: {e.msg}\n\nDesc: {e.desc}".encode( + DEFAULT_ENCODING + ) + ] + + return resp + except ConnectionAbortedError as e: + self.log.error(f"{request.method} {request.path} : {e}") + + def prepare_and_fire_response(self, start_response, request, resp) -> list[bytes]: + try: + if isinstance(resp, dict): + self.fire_response(request, JsonResponse(data=resp)) + if isinstance(resp, TemplateResponse): + resp.set_template_loader(self.env) + + for middleware in self.middleware: + middleware.process_response(request, resp) + + return self.fire_response(start_response, request, resp) + + except APIError: + raise + + except Exception: + self.log.error(traceback.format_exc()) + self.fire_response( + start_response, request, self.get_error_route(500)(request) + ) + + def __call__(self, environ, start_response, *args, **kwargs): + """Entry point for WSGI apps.""" + request = self.get_request(environ) + + try: + handler, additional_args, allowed_methods = self.get_route(request.path) + except NotFound: + handler = self.get_error_route(404) + additional_args = {} + allowed_methods = DEFAULT_ALLOWED_METHODS + request.handler = handler + + if request.method not in allowed_methods: + # replace the potentially valid handler with the error route + handler = self.get_error_route(405) + + if request.is_form_request(): + form_data = urlparse.parse_qs(request.content) + for key, value in form_data.items(): + if len(value) == 1: + form_data[key] = value[0] + setattr(request, request.method, form_data) + + try: + if handler: + abort = self.process_request_middleware(request) + if abort: + return + resp = handler(request, **additional_args) + if resp is None: + raise NoResponseError(f"View {handler} returned None.") + # run the response through the middleware and send it + return self.prepare_and_fire_response(start_response, request, resp) + else: + raise SpiderwebNetworkException(404) + except SpiderwebNetworkException as e: + return self.send_error_response(start_response, request, e) diff --git a/templates/form.html b/templates/form.html index 1701002..3cbe891 100644 --- a/templates/form.html +++ b/templates/form.html @@ -14,7 +14,7 @@ - +