🎨 run black

This commit is contained in:
Joe Kaufeld 2024-08-14 17:28:01 -04:00
parent 7b60e2fd32
commit 8ec39c2803
8 changed files with 69 additions and 26 deletions

View File

@ -1,6 +1,11 @@
from spiderweb import WebServer from spiderweb import WebServer
from spiderweb.exceptions import ServerError from spiderweb.exceptions import ServerError
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse from spiderweb.response import (
HttpResponse,
JsonResponse,
TemplateResponse,
RedirectResponse,
)
app = WebServer( app = WebServer(
@ -11,7 +16,7 @@ app = WebServer(
"example_middleware.RedirectMiddleware", "example_middleware.RedirectMiddleware",
"example_middleware.ExplodingMiddleware", "example_middleware.ExplodingMiddleware",
], ],
append_slash=False # default append_slash=False, # default
) )

View File

@ -9,9 +9,7 @@ class TestMiddleware(SpiderwebMiddleware):
# example of a middleware that sets a flag on the request # example of a middleware that sets a flag on the request
request.spiderweb = True request.spiderweb = True
def process_response( def process_response(self, request: Request, response: HttpResponse) -> None:
self, request: Request, response: HttpResponse
) -> None:
# example of a middleware that sets a header on the resp # example of a middleware that sets a header on the resp
if hasattr(request, "spiderweb"): if hasattr(request, "spiderweb"):
response.headers["X-Spiderweb"] = "true" response.headers["X-Spiderweb"] = "true"

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "spiderweb" name = "spiderweb"
version = "0.5.0" version = "0.6.0"
description = "A small web framework, just big enough to hold your average spider." description = "A small web framework, just big enough to hold your average spider."
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"] authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
readme = "README.md" readme = "README.md"

View File

@ -28,7 +28,9 @@ class BadRequest(SpiderwebNetworkException):
def __init__(self, desc=None): def __init__(self, desc=None):
self.code = 400 self.code = 400
self.msg = "Bad Request" self.msg = "Bad Request"
self.desc = desc if desc else "The request could not be understood by the server" self.desc = (
desc if desc else "The request could not be understood by the server"
)
class Unauthorized(SpiderwebNetworkException): class Unauthorized(SpiderwebNetworkException):
@ -59,7 +61,6 @@ class CSRFError(SpiderwebNetworkException):
self.desc = desc if desc else "CSRF token is invalid" self.desc = desc if desc else "CSRF token is invalid"
class ConfigError(SpiderwebException): class ConfigError(SpiderwebException):
pass pass

View File

@ -23,10 +23,18 @@ from spiderweb.exceptions import (
ConfigError, ConfigError,
ParseError, ParseError,
GeneralException, GeneralException,
NoResponseError, UnusedMiddleware, SpiderwebNetworkException, NotFound, NoResponseError,
UnusedMiddleware,
SpiderwebNetworkException,
NotFound,
) )
from spiderweb.request import Request from spiderweb.request import Request
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse from spiderweb.response import (
HttpResponse,
JsonResponse,
TemplateResponse,
RedirectResponse,
)
from spiderweb.utils import import_by_string from spiderweb.utils import import_by_string
@ -163,11 +171,20 @@ class WebServer(HTTPServer):
updated_path = path + "/" updated_path = path + "/"
self.check_for_route_duplicates(updated_path) self.check_for_route_duplicates(updated_path)
self.check_for_route_duplicates(path) self.check_for_route_duplicates(path)
self.handler_class._routes[self.convert_path(path)] = {'func': DummyRedirectRoute(updated_path), 'allowed_methods': allowed_methods} self.handler_class._routes[self.convert_path(path)] = {
self.handler_class._routes[self.convert_path(updated_path)] = {'func': method, 'allowed_methods': allowed_methods} "func": DummyRedirectRoute(updated_path),
"allowed_methods": allowed_methods,
}
self.handler_class._routes[self.convert_path(updated_path)] = {
"func": method,
"allowed_methods": allowed_methods,
}
else: else:
self.check_for_route_duplicates(path) self.check_for_route_duplicates(path)
self.handler_class._routes[self.convert_path(path)] = {'func': method, 'allowed_methods': allowed_methods} self.handler_class._routes[self.convert_path(path)] = {
"func": method,
"allowed_methods": allowed_methods,
}
def add_error_route(self, code: int, method: Callable): def add_error_route(self, code: int, method: Callable):
"""Add an error route to the server.""" """Add an error route to the server."""
@ -200,7 +217,7 @@ class WebServer(HTTPServer):
self.add_route( self.add_route(
path, path,
func, func,
allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS,
) )
return func return func
@ -210,6 +227,7 @@ class WebServer(HTTPServer):
def outer(func): def outer(func):
self.add_error_route(code, func) self.add_error_route(code, func)
return func return func
return outer return outer
@property @property
@ -230,7 +248,7 @@ class WebServer(HTTPServer):
return self.__addr + ":" + str(self.port()) + "/" + path return self.__addr + ":" + str(self.port()) + "/" + path
def signal_handler(self, sig, frame) -> NoReturn: def signal_handler(self, sig, frame) -> NoReturn:
log.warning('Shutting down!') log.warning("Shutting down!")
self.stop() self.stop()
def start(self, blocking=False): def start(self, blocking=False):
@ -261,7 +279,10 @@ class WebServer(HTTPServer):
def decrypt(self, data: str): def decrypt(self, data: str):
if isinstance(data, bytes): if isinstance(data, bytes):
return self.fernet.decrypt(data).decode(DEFAULT_ENCODING) return self.fernet.decrypt(data).decode(DEFAULT_ENCODING)
return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode(DEFAULT_ENCODING) return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode(
DEFAULT_ENCODING
)
class RequestHandler(BaseHTTPRequestHandler): class RequestHandler(BaseHTTPRequestHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -301,9 +322,11 @@ class RequestHandler(BaseHTTPRequestHandler):
def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]: def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]:
for option in self._routes.keys(): for option in self._routes.keys():
if match_data := option.match(path): if match_data := option.match(path):
return self._routes[option]['func'], convert_match_to_dict( return (
match_data.groupdict() self._routes[option]["func"],
), self._routes[option]['allowed_methods'] convert_match_to_dict(match_data.groupdict()),
self._routes[option]["allowed_methods"],
)
raise NotFound() raise NotFound()
def get_error_route(self, code: int) -> Callable: def get_error_route(self, code: int) -> Callable:
@ -345,7 +368,9 @@ class RequestHandler(BaseHTTPRequestHandler):
self.fire_response(request, resp) self.fire_response(request, resp)
return True # abort further processing return True # abort further processing
def process_response_middleware(self, request: Request, response: HttpResponse) -> None: def process_response_middleware(
self, request: Request, response: HttpResponse
) -> None:
for middleware in self.server.middleware: for middleware in self.server.middleware:
try: try:
middleware.process_response(request, response) middleware.process_response(request, response)
@ -374,7 +399,10 @@ class RequestHandler(BaseHTTPRequestHandler):
self.fire_response(request, self.get_error_route(500)(request)) self.fire_response(request, self.get_error_route(500)(request))
def is_form_request(self, request: Request) -> bool: def is_form_request(self, request: Request) -> bool:
return "Content-Type" in request.headers and request.headers["Content-Type"] == "application/x-www-form-urlencoded" return (
"Content-Type" in request.headers
and request.headers["Content-Type"] == "application/x-www-form-urlencoded"
)
def send_error_response(self, request: Request, e: SpiderwebNetworkException): def send_error_response(self, request: Request, e: SpiderwebNetworkException):
try: try:
@ -397,7 +425,9 @@ class RequestHandler(BaseHTTPRequestHandler):
# replace the potentially valid handler with the error route # replace the potentially valid handler with the error route
handler = self.get_error_route(405) handler = self.get_error_route(405)
request.query_params = urlparse.parse_qs(request.url.query) if request.url.query else {} request.query_params = (
urlparse.parse_qs(request.url.query) if request.url.query else {}
)
if self.is_form_request(request): if self.is_form_request(request):
formdata = urlparse.parse_qs(request.content.decode("utf-8")) formdata = urlparse.parse_qs(request.content.decode("utf-8"))

View File

@ -17,6 +17,7 @@ class SpiderwebMiddleware:
If `process_request` returns a HttpResponse, the request will be short-circuited If `process_request` returns a HttpResponse, the request will be short-circuited
and the response will be returned immediately. `process_response` will not be called. and the response will be returned immediately. `process_response` will not be called.
""" """
def __init__(self, server): def __init__(self, server):
self.server = server self.server = server

View File

@ -11,7 +11,11 @@ class CSRFMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse | None: def process_request(self, request: Request) -> HttpResponse | None:
if request.method == "POST": if request.method == "POST":
csrf_token = request.headers.get("X-CSRF-TOKEN") or request.GET.get("csrf_token") or request.POST.get("csrf_token") csrf_token = (
request.headers.get("X-CSRF-TOKEN")
or request.GET.get("csrf_token")
or request.POST.get("csrf_token")
)
if self.is_csrf_valid(csrf_token): if self.is_csrf_valid(csrf_token):
return None return None
else: else:
@ -25,12 +29,16 @@ class CSRFMiddleware(SpiderwebMiddleware):
request.csrf_token = token request.csrf_token = token
def get_csrf_token(self): def get_csrf_token(self):
return self.server.encrypt(str(datetime.now().isoformat())).decode(self.server.DEFAULT_ENCODING) return self.server.encrypt(str(datetime.now().isoformat())).decode(
self.server.DEFAULT_ENCODING
)
def is_csrf_valid(self, key): def is_csrf_valid(self, key):
try: try:
decoded = self.server.decrypt(key) decoded = self.server.decrypt(key)
if datetime.now() - timedelta(seconds=self.CSRF_EXPIRY) > datetime.fromisoformat(decoded): if datetime.now() - timedelta(
seconds=self.CSRF_EXPIRY
) > datetime.fromisoformat(decoded):
return False return False
return True return True
except Exception: except Exception:

View File

@ -1,6 +1,6 @@
def import_by_string(name): def import_by_string(name):
# https://stackoverflow.com/a/547867 # https://stackoverflow.com/a/547867
components = name.split('.') components = name.split(".")
mod = __import__(components[0]) mod = __import__(components[0])
for comp in components[1:]: for comp in components[1:]:
mod = getattr(mod, comp) mod = getattr(mod, comp)