🎨 run black

This commit is contained in:
Joe Kaufeld 2024-08-14 17:28:01 -04:00
parent 7b60e2fd32
commit 8ec39c2803
8 changed files with 69 additions and 26 deletions

View File

@ -1,6 +1,11 @@
from spiderweb import WebServer
from spiderweb.exceptions import ServerError
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
from spiderweb.response import (
HttpResponse,
JsonResponse,
TemplateResponse,
RedirectResponse,
)
app = WebServer(
@ -11,7 +16,7 @@ app = WebServer(
"example_middleware.RedirectMiddleware",
"example_middleware.ExplodingMiddleware",
],
append_slash=False # default
append_slash=False, # default
)

View File

@ -9,9 +9,7 @@ class TestMiddleware(SpiderwebMiddleware):
# example of a middleware that sets a flag on the request
request.spiderweb = True
def process_response(
self, request: Request, response: HttpResponse
) -> None:
def process_response(self, request: Request, response: HttpResponse) -> None:
# example of a middleware that sets a header on the resp
if hasattr(request, "spiderweb"):
response.headers["X-Spiderweb"] = "true"

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "spiderweb"
version = "0.5.0"
version = "0.6.0"
description = "A small web framework, just big enough to hold your average spider."
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
readme = "README.md"

View File

@ -28,7 +28,9 @@ class BadRequest(SpiderwebNetworkException):
def __init__(self, desc=None):
self.code = 400
self.msg = "Bad Request"
self.desc = desc if desc else "The request could not be understood by the server"
self.desc = (
desc if desc else "The request could not be understood by the server"
)
class Unauthorized(SpiderwebNetworkException):
@ -59,7 +61,6 @@ class CSRFError(SpiderwebNetworkException):
self.desc = desc if desc else "CSRF token is invalid"
class ConfigError(SpiderwebException):
pass

View File

@ -23,10 +23,18 @@ from spiderweb.exceptions import (
ConfigError,
ParseError,
GeneralException,
NoResponseError, UnusedMiddleware, SpiderwebNetworkException, NotFound,
NoResponseError,
UnusedMiddleware,
SpiderwebNetworkException,
NotFound,
)
from spiderweb.request import Request
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
from spiderweb.response import (
HttpResponse,
JsonResponse,
TemplateResponse,
RedirectResponse,
)
from spiderweb.utils import import_by_string
@ -163,11 +171,20 @@ class WebServer(HTTPServer):
updated_path = path + "/"
self.check_for_route_duplicates(updated_path)
self.check_for_route_duplicates(path)
self.handler_class._routes[self.convert_path(path)] = {'func': DummyRedirectRoute(updated_path), 'allowed_methods': allowed_methods}
self.handler_class._routes[self.convert_path(updated_path)] = {'func': method, 'allowed_methods': allowed_methods}
self.handler_class._routes[self.convert_path(path)] = {
"func": DummyRedirectRoute(updated_path),
"allowed_methods": allowed_methods,
}
self.handler_class._routes[self.convert_path(updated_path)] = {
"func": method,
"allowed_methods": allowed_methods,
}
else:
self.check_for_route_duplicates(path)
self.handler_class._routes[self.convert_path(path)] = {'func': method, 'allowed_methods': allowed_methods}
self.handler_class._routes[self.convert_path(path)] = {
"func": method,
"allowed_methods": allowed_methods,
}
def add_error_route(self, code: int, method: Callable):
"""Add an error route to the server."""
@ -200,7 +217,7 @@ class WebServer(HTTPServer):
self.add_route(
path,
func,
allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS
allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS,
)
return func
@ -210,6 +227,7 @@ class WebServer(HTTPServer):
def outer(func):
self.add_error_route(code, func)
return func
return outer
@property
@ -230,7 +248,7 @@ class WebServer(HTTPServer):
return self.__addr + ":" + str(self.port()) + "/" + path
def signal_handler(self, sig, frame) -> NoReturn:
log.warning('Shutting down!')
log.warning("Shutting down!")
self.stop()
def start(self, blocking=False):
@ -261,7 +279,10 @@ class WebServer(HTTPServer):
def decrypt(self, data: str):
if isinstance(data, bytes):
return self.fernet.decrypt(data).decode(DEFAULT_ENCODING)
return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode(DEFAULT_ENCODING)
return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode(
DEFAULT_ENCODING
)
class RequestHandler(BaseHTTPRequestHandler):
def __init__(self, *args, **kwargs):
@ -301,9 +322,11 @@ class RequestHandler(BaseHTTPRequestHandler):
def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]:
for option in self._routes.keys():
if match_data := option.match(path):
return self._routes[option]['func'], convert_match_to_dict(
match_data.groupdict()
), self._routes[option]['allowed_methods']
return (
self._routes[option]["func"],
convert_match_to_dict(match_data.groupdict()),
self._routes[option]["allowed_methods"],
)
raise NotFound()
def get_error_route(self, code: int) -> Callable:
@ -345,7 +368,9 @@ class RequestHandler(BaseHTTPRequestHandler):
self.fire_response(request, resp)
return True # abort further processing
def process_response_middleware(self, request: Request, response: HttpResponse) -> None:
def process_response_middleware(
self, request: Request, response: HttpResponse
) -> None:
for middleware in self.server.middleware:
try:
middleware.process_response(request, response)
@ -374,7 +399,10 @@ class RequestHandler(BaseHTTPRequestHandler):
self.fire_response(request, self.get_error_route(500)(request))
def is_form_request(self, request: Request) -> bool:
return "Content-Type" in request.headers and request.headers["Content-Type"] == "application/x-www-form-urlencoded"
return (
"Content-Type" in request.headers
and request.headers["Content-Type"] == "application/x-www-form-urlencoded"
)
def send_error_response(self, request: Request, e: SpiderwebNetworkException):
try:
@ -397,7 +425,9 @@ class RequestHandler(BaseHTTPRequestHandler):
# replace the potentially valid handler with the error route
handler = self.get_error_route(405)
request.query_params = urlparse.parse_qs(request.url.query) if request.url.query else {}
request.query_params = (
urlparse.parse_qs(request.url.query) if request.url.query else {}
)
if self.is_form_request(request):
formdata = urlparse.parse_qs(request.content.decode("utf-8"))

View File

@ -17,6 +17,7 @@ class SpiderwebMiddleware:
If `process_request` returns a HttpResponse, the request will be short-circuited
and the response will be returned immediately. `process_response` will not be called.
"""
def __init__(self, server):
self.server = server

View File

@ -11,7 +11,11 @@ class CSRFMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse | None:
if request.method == "POST":
csrf_token = request.headers.get("X-CSRF-TOKEN") or request.GET.get("csrf_token") or request.POST.get("csrf_token")
csrf_token = (
request.headers.get("X-CSRF-TOKEN")
or request.GET.get("csrf_token")
or request.POST.get("csrf_token")
)
if self.is_csrf_valid(csrf_token):
return None
else:
@ -25,12 +29,16 @@ class CSRFMiddleware(SpiderwebMiddleware):
request.csrf_token = token
def get_csrf_token(self):
return self.server.encrypt(str(datetime.now().isoformat())).decode(self.server.DEFAULT_ENCODING)
return self.server.encrypt(str(datetime.now().isoformat())).decode(
self.server.DEFAULT_ENCODING
)
def is_csrf_valid(self, key):
try:
decoded = self.server.decrypt(key)
if datetime.now() - timedelta(seconds=self.CSRF_EXPIRY) > datetime.fromisoformat(decoded):
if datetime.now() - timedelta(
seconds=self.CSRF_EXPIRY
) > datetime.fromisoformat(decoded):
return False
return True
except Exception:

View File

@ -1,6 +1,6 @@
def import_by_string(name):
# https://stackoverflow.com/a/547867
components = name.split('.')
components = name.split(".")
mod = __import__(components[0])
for comp in components[1:]:
mod = getattr(mod, comp)