🎨 run black
This commit is contained in:
parent
7b60e2fd32
commit
8ec39c2803
@ -1,6 +1,11 @@
|
||||
from spiderweb import WebServer
|
||||
from spiderweb.exceptions import ServerError
|
||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
||||
from spiderweb.response import (
|
||||
HttpResponse,
|
||||
JsonResponse,
|
||||
TemplateResponse,
|
||||
RedirectResponse,
|
||||
)
|
||||
|
||||
|
||||
app = WebServer(
|
||||
@ -11,7 +16,7 @@ app = WebServer(
|
||||
"example_middleware.RedirectMiddleware",
|
||||
"example_middleware.ExplodingMiddleware",
|
||||
],
|
||||
append_slash=False # default
|
||||
append_slash=False, # default
|
||||
)
|
||||
|
||||
|
||||
|
@ -9,9 +9,7 @@ class TestMiddleware(SpiderwebMiddleware):
|
||||
# example of a middleware that sets a flag on the request
|
||||
request.spiderweb = True
|
||||
|
||||
def process_response(
|
||||
self, request: Request, response: HttpResponse
|
||||
) -> None:
|
||||
def process_response(self, request: Request, response: HttpResponse) -> None:
|
||||
# example of a middleware that sets a header on the resp
|
||||
if hasattr(request, "spiderweb"):
|
||||
response.headers["X-Spiderweb"] = "true"
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "spiderweb"
|
||||
version = "0.5.0"
|
||||
version = "0.6.0"
|
||||
description = "A small web framework, just big enough to hold your average spider."
|
||||
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
|
||||
readme = "README.md"
|
||||
|
@ -28,7 +28,9 @@ class BadRequest(SpiderwebNetworkException):
|
||||
def __init__(self, desc=None):
|
||||
self.code = 400
|
||||
self.msg = "Bad Request"
|
||||
self.desc = desc if desc else "The request could not be understood by the server"
|
||||
self.desc = (
|
||||
desc if desc else "The request could not be understood by the server"
|
||||
)
|
||||
|
||||
|
||||
class Unauthorized(SpiderwebNetworkException):
|
||||
@ -59,7 +61,6 @@ class CSRFError(SpiderwebNetworkException):
|
||||
self.desc = desc if desc else "CSRF token is invalid"
|
||||
|
||||
|
||||
|
||||
class ConfigError(SpiderwebException):
|
||||
pass
|
||||
|
||||
|
@ -23,10 +23,18 @@ from spiderweb.exceptions import (
|
||||
ConfigError,
|
||||
ParseError,
|
||||
GeneralException,
|
||||
NoResponseError, UnusedMiddleware, SpiderwebNetworkException, NotFound,
|
||||
NoResponseError,
|
||||
UnusedMiddleware,
|
||||
SpiderwebNetworkException,
|
||||
NotFound,
|
||||
)
|
||||
from spiderweb.request import Request
|
||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
||||
from spiderweb.response import (
|
||||
HttpResponse,
|
||||
JsonResponse,
|
||||
TemplateResponse,
|
||||
RedirectResponse,
|
||||
)
|
||||
from spiderweb.utils import import_by_string
|
||||
|
||||
|
||||
@ -163,11 +171,20 @@ class WebServer(HTTPServer):
|
||||
updated_path = path + "/"
|
||||
self.check_for_route_duplicates(updated_path)
|
||||
self.check_for_route_duplicates(path)
|
||||
self.handler_class._routes[self.convert_path(path)] = {'func': DummyRedirectRoute(updated_path), 'allowed_methods': allowed_methods}
|
||||
self.handler_class._routes[self.convert_path(updated_path)] = {'func': method, 'allowed_methods': allowed_methods}
|
||||
self.handler_class._routes[self.convert_path(path)] = {
|
||||
"func": DummyRedirectRoute(updated_path),
|
||||
"allowed_methods": allowed_methods,
|
||||
}
|
||||
self.handler_class._routes[self.convert_path(updated_path)] = {
|
||||
"func": method,
|
||||
"allowed_methods": allowed_methods,
|
||||
}
|
||||
else:
|
||||
self.check_for_route_duplicates(path)
|
||||
self.handler_class._routes[self.convert_path(path)] = {'func': method, 'allowed_methods': allowed_methods}
|
||||
self.handler_class._routes[self.convert_path(path)] = {
|
||||
"func": method,
|
||||
"allowed_methods": allowed_methods,
|
||||
}
|
||||
|
||||
def add_error_route(self, code: int, method: Callable):
|
||||
"""Add an error route to the server."""
|
||||
@ -200,7 +217,7 @@ class WebServer(HTTPServer):
|
||||
self.add_route(
|
||||
path,
|
||||
func,
|
||||
allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS
|
||||
allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS,
|
||||
)
|
||||
return func
|
||||
|
||||
@ -210,6 +227,7 @@ class WebServer(HTTPServer):
|
||||
def outer(func):
|
||||
self.add_error_route(code, func)
|
||||
return func
|
||||
|
||||
return outer
|
||||
|
||||
@property
|
||||
@ -230,7 +248,7 @@ class WebServer(HTTPServer):
|
||||
return self.__addr + ":" + str(self.port()) + "/" + path
|
||||
|
||||
def signal_handler(self, sig, frame) -> NoReturn:
|
||||
log.warning('Shutting down!')
|
||||
log.warning("Shutting down!")
|
||||
self.stop()
|
||||
|
||||
def start(self, blocking=False):
|
||||
@ -261,7 +279,10 @@ class WebServer(HTTPServer):
|
||||
def decrypt(self, data: str):
|
||||
if isinstance(data, bytes):
|
||||
return self.fernet.decrypt(data).decode(DEFAULT_ENCODING)
|
||||
return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode(DEFAULT_ENCODING)
|
||||
return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode(
|
||||
DEFAULT_ENCODING
|
||||
)
|
||||
|
||||
|
||||
class RequestHandler(BaseHTTPRequestHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -301,9 +322,11 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||
def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]:
|
||||
for option in self._routes.keys():
|
||||
if match_data := option.match(path):
|
||||
return self._routes[option]['func'], convert_match_to_dict(
|
||||
match_data.groupdict()
|
||||
), self._routes[option]['allowed_methods']
|
||||
return (
|
||||
self._routes[option]["func"],
|
||||
convert_match_to_dict(match_data.groupdict()),
|
||||
self._routes[option]["allowed_methods"],
|
||||
)
|
||||
raise NotFound()
|
||||
|
||||
def get_error_route(self, code: int) -> Callable:
|
||||
@ -345,7 +368,9 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||
self.fire_response(request, resp)
|
||||
return True # abort further processing
|
||||
|
||||
def process_response_middleware(self, request: Request, response: HttpResponse) -> None:
|
||||
def process_response_middleware(
|
||||
self, request: Request, response: HttpResponse
|
||||
) -> None:
|
||||
for middleware in self.server.middleware:
|
||||
try:
|
||||
middleware.process_response(request, response)
|
||||
@ -374,7 +399,10 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||
self.fire_response(request, self.get_error_route(500)(request))
|
||||
|
||||
def is_form_request(self, request: Request) -> bool:
|
||||
return "Content-Type" in request.headers and request.headers["Content-Type"] == "application/x-www-form-urlencoded"
|
||||
return (
|
||||
"Content-Type" in request.headers
|
||||
and request.headers["Content-Type"] == "application/x-www-form-urlencoded"
|
||||
)
|
||||
|
||||
def send_error_response(self, request: Request, e: SpiderwebNetworkException):
|
||||
try:
|
||||
@ -397,7 +425,9 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||
# replace the potentially valid handler with the error route
|
||||
handler = self.get_error_route(405)
|
||||
|
||||
request.query_params = urlparse.parse_qs(request.url.query) if request.url.query else {}
|
||||
request.query_params = (
|
||||
urlparse.parse_qs(request.url.query) if request.url.query else {}
|
||||
)
|
||||
|
||||
if self.is_form_request(request):
|
||||
formdata = urlparse.parse_qs(request.content.decode("utf-8"))
|
||||
|
@ -17,6 +17,7 @@ class SpiderwebMiddleware:
|
||||
If `process_request` returns a HttpResponse, the request will be short-circuited
|
||||
and the response will be returned immediately. `process_response` will not be called.
|
||||
"""
|
||||
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
|
||||
|
@ -11,7 +11,11 @@ class CSRFMiddleware(SpiderwebMiddleware):
|
||||
|
||||
def process_request(self, request: Request) -> HttpResponse | None:
|
||||
if request.method == "POST":
|
||||
csrf_token = request.headers.get("X-CSRF-TOKEN") or request.GET.get("csrf_token") or request.POST.get("csrf_token")
|
||||
csrf_token = (
|
||||
request.headers.get("X-CSRF-TOKEN")
|
||||
or request.GET.get("csrf_token")
|
||||
or request.POST.get("csrf_token")
|
||||
)
|
||||
if self.is_csrf_valid(csrf_token):
|
||||
return None
|
||||
else:
|
||||
@ -25,12 +29,16 @@ class CSRFMiddleware(SpiderwebMiddleware):
|
||||
request.csrf_token = token
|
||||
|
||||
def get_csrf_token(self):
|
||||
return self.server.encrypt(str(datetime.now().isoformat())).decode(self.server.DEFAULT_ENCODING)
|
||||
return self.server.encrypt(str(datetime.now().isoformat())).decode(
|
||||
self.server.DEFAULT_ENCODING
|
||||
)
|
||||
|
||||
def is_csrf_valid(self, key):
|
||||
try:
|
||||
decoded = self.server.decrypt(key)
|
||||
if datetime.now() - timedelta(seconds=self.CSRF_EXPIRY) > datetime.fromisoformat(decoded):
|
||||
if datetime.now() - timedelta(
|
||||
seconds=self.CSRF_EXPIRY
|
||||
) > datetime.fromisoformat(decoded):
|
||||
return False
|
||||
return True
|
||||
except Exception:
|
||||
|
@ -1,6 +1,6 @@
|
||||
def import_by_string(name):
|
||||
# https://stackoverflow.com/a/547867
|
||||
components = name.split('.')
|
||||
components = name.split(".")
|
||||
mod = __import__(components[0])
|
||||
for comp in components[1:]:
|
||||
mod = getattr(mod, comp)
|
||||
|
Loading…
Reference in New Issue
Block a user