diff --git a/example.py b/example.py index 7d103ef..508d8d4 100644 --- a/example.py +++ b/example.py @@ -1,6 +1,11 @@ from spiderweb import WebServer from spiderweb.exceptions import ServerError -from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse +from spiderweb.response import ( + HttpResponse, + JsonResponse, + TemplateResponse, + RedirectResponse, +) app = WebServer( @@ -11,7 +16,7 @@ app = WebServer( "example_middleware.RedirectMiddleware", "example_middleware.ExplodingMiddleware", ], - append_slash=False # default + append_slash=False, # default ) diff --git a/example_middleware.py b/example_middleware.py index f566b2a..4d30a3d 100644 --- a/example_middleware.py +++ b/example_middleware.py @@ -9,9 +9,7 @@ class TestMiddleware(SpiderwebMiddleware): # example of a middleware that sets a flag on the request request.spiderweb = True - def process_response( - self, request: Request, response: HttpResponse - ) -> None: + def process_response(self, request: Request, response: HttpResponse) -> None: # example of a middleware that sets a header on the resp if hasattr(request, "spiderweb"): response.headers["X-Spiderweb"] = "true" diff --git a/pyproject.toml b/pyproject.toml index 4738790..1687b6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "spiderweb" -version = "0.5.0" +version = "0.6.0" description = "A small web framework, just big enough to hold your average spider." authors = ["Joe Kaufeld "] readme = "README.md" diff --git a/spiderweb/exceptions.py b/spiderweb/exceptions.py index 36450d7..dd96ee0 100644 --- a/spiderweb/exceptions.py +++ b/spiderweb/exceptions.py @@ -28,7 +28,9 @@ class BadRequest(SpiderwebNetworkException): def __init__(self, desc=None): self.code = 400 self.msg = "Bad Request" - self.desc = desc if desc else "The request could not be understood by the server" + self.desc = ( + desc if desc else "The request could not be understood by the server" + ) class Unauthorized(SpiderwebNetworkException): @@ -59,7 +61,6 @@ class CSRFError(SpiderwebNetworkException): self.desc = desc if desc else "CSRF token is invalid" - class ConfigError(SpiderwebException): pass diff --git a/spiderweb/main.py b/spiderweb/main.py index de3a52e..e130686 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -23,10 +23,18 @@ from spiderweb.exceptions import ( ConfigError, ParseError, GeneralException, - NoResponseError, UnusedMiddleware, SpiderwebNetworkException, NotFound, + NoResponseError, + UnusedMiddleware, + SpiderwebNetworkException, + NotFound, ) from spiderweb.request import Request -from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse +from spiderweb.response import ( + HttpResponse, + JsonResponse, + TemplateResponse, + RedirectResponse, +) from spiderweb.utils import import_by_string @@ -163,11 +171,20 @@ class WebServer(HTTPServer): updated_path = path + "/" self.check_for_route_duplicates(updated_path) self.check_for_route_duplicates(path) - self.handler_class._routes[self.convert_path(path)] = {'func': DummyRedirectRoute(updated_path), 'allowed_methods': allowed_methods} - self.handler_class._routes[self.convert_path(updated_path)] = {'func': method, 'allowed_methods': allowed_methods} + self.handler_class._routes[self.convert_path(path)] = { + "func": DummyRedirectRoute(updated_path), + "allowed_methods": allowed_methods, + } + self.handler_class._routes[self.convert_path(updated_path)] = { + "func": method, + "allowed_methods": allowed_methods, + } else: self.check_for_route_duplicates(path) - self.handler_class._routes[self.convert_path(path)] = {'func': method, 'allowed_methods': allowed_methods} + self.handler_class._routes[self.convert_path(path)] = { + "func": method, + "allowed_methods": allowed_methods, + } def add_error_route(self, code: int, method: Callable): """Add an error route to the server.""" @@ -200,7 +217,7 @@ class WebServer(HTTPServer): self.add_route( path, func, - allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS + allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS, ) return func @@ -210,6 +227,7 @@ class WebServer(HTTPServer): def outer(func): self.add_error_route(code, func) return func + return outer @property @@ -230,7 +248,7 @@ class WebServer(HTTPServer): return self.__addr + ":" + str(self.port()) + "/" + path def signal_handler(self, sig, frame) -> NoReturn: - log.warning('Shutting down!') + log.warning("Shutting down!") self.stop() def start(self, blocking=False): @@ -261,7 +279,10 @@ class WebServer(HTTPServer): def decrypt(self, data: str): if isinstance(data, bytes): return self.fernet.decrypt(data).decode(DEFAULT_ENCODING) - return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode(DEFAULT_ENCODING) + return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode( + DEFAULT_ENCODING + ) + class RequestHandler(BaseHTTPRequestHandler): def __init__(self, *args, **kwargs): @@ -301,9 +322,11 @@ class RequestHandler(BaseHTTPRequestHandler): def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]: for option in self._routes.keys(): if match_data := option.match(path): - return self._routes[option]['func'], convert_match_to_dict( - match_data.groupdict() - ), self._routes[option]['allowed_methods'] + return ( + self._routes[option]["func"], + convert_match_to_dict(match_data.groupdict()), + self._routes[option]["allowed_methods"], + ) raise NotFound() def get_error_route(self, code: int) -> Callable: @@ -345,7 +368,9 @@ class RequestHandler(BaseHTTPRequestHandler): self.fire_response(request, resp) return True # abort further processing - def process_response_middleware(self, request: Request, response: HttpResponse) -> None: + def process_response_middleware( + self, request: Request, response: HttpResponse + ) -> None: for middleware in self.server.middleware: try: middleware.process_response(request, response) @@ -374,7 +399,10 @@ class RequestHandler(BaseHTTPRequestHandler): self.fire_response(request, self.get_error_route(500)(request)) def is_form_request(self, request: Request) -> bool: - return "Content-Type" in request.headers and request.headers["Content-Type"] == "application/x-www-form-urlencoded" + return ( + "Content-Type" in request.headers + and request.headers["Content-Type"] == "application/x-www-form-urlencoded" + ) def send_error_response(self, request: Request, e: SpiderwebNetworkException): try: @@ -397,7 +425,9 @@ class RequestHandler(BaseHTTPRequestHandler): # replace the potentially valid handler with the error route handler = self.get_error_route(405) - request.query_params = urlparse.parse_qs(request.url.query) if request.url.query else {} + request.query_params = ( + urlparse.parse_qs(request.url.query) if request.url.query else {} + ) if self.is_form_request(request): formdata = urlparse.parse_qs(request.content.decode("utf-8")) diff --git a/spiderweb/middleware/base.py b/spiderweb/middleware/base.py index d0a1de3..4fc5677 100644 --- a/spiderweb/middleware/base.py +++ b/spiderweb/middleware/base.py @@ -17,6 +17,7 @@ class SpiderwebMiddleware: If `process_request` returns a HttpResponse, the request will be short-circuited and the response will be returned immediately. `process_response` will not be called. """ + def __init__(self, server): self.server = server diff --git a/spiderweb/middleware/csrf.py b/spiderweb/middleware/csrf.py index 16b37c2..0c502f7 100644 --- a/spiderweb/middleware/csrf.py +++ b/spiderweb/middleware/csrf.py @@ -11,7 +11,11 @@ class CSRFMiddleware(SpiderwebMiddleware): def process_request(self, request: Request) -> HttpResponse | None: if request.method == "POST": - csrf_token = request.headers.get("X-CSRF-TOKEN") or request.GET.get("csrf_token") or request.POST.get("csrf_token") + csrf_token = ( + request.headers.get("X-CSRF-TOKEN") + or request.GET.get("csrf_token") + or request.POST.get("csrf_token") + ) if self.is_csrf_valid(csrf_token): return None else: @@ -25,12 +29,16 @@ class CSRFMiddleware(SpiderwebMiddleware): request.csrf_token = token def get_csrf_token(self): - return self.server.encrypt(str(datetime.now().isoformat())).decode(self.server.DEFAULT_ENCODING) + return self.server.encrypt(str(datetime.now().isoformat())).decode( + self.server.DEFAULT_ENCODING + ) def is_csrf_valid(self, key): try: decoded = self.server.decrypt(key) - if datetime.now() - timedelta(seconds=self.CSRF_EXPIRY) > datetime.fromisoformat(decoded): + if datetime.now() - timedelta( + seconds=self.CSRF_EXPIRY + ) > datetime.fromisoformat(decoded): return False return True except Exception: diff --git a/spiderweb/utils.py b/spiderweb/utils.py index 360e5a3..0235b45 100644 --- a/spiderweb/utils.py +++ b/spiderweb/utils.py @@ -1,6 +1,6 @@ def import_by_string(name): # https://stackoverflow.com/a/547867 - components = name.split('.') + components = name.split(".") mod = __import__(components[0]) for comp in components[1:]: mod = getattr(mod, comp)