🎨 run black
This commit is contained in:
parent
7b60e2fd32
commit
8ec39c2803
@ -1,6 +1,11 @@
|
|||||||
from spiderweb import WebServer
|
from spiderweb import WebServer
|
||||||
from spiderweb.exceptions import ServerError
|
from spiderweb.exceptions import ServerError
|
||||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
from spiderweb.response import (
|
||||||
|
HttpResponse,
|
||||||
|
JsonResponse,
|
||||||
|
TemplateResponse,
|
||||||
|
RedirectResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
app = WebServer(
|
app = WebServer(
|
||||||
@ -11,7 +16,7 @@ app = WebServer(
|
|||||||
"example_middleware.RedirectMiddleware",
|
"example_middleware.RedirectMiddleware",
|
||||||
"example_middleware.ExplodingMiddleware",
|
"example_middleware.ExplodingMiddleware",
|
||||||
],
|
],
|
||||||
append_slash=False # default
|
append_slash=False, # default
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,9 +9,7 @@ class TestMiddleware(SpiderwebMiddleware):
|
|||||||
# example of a middleware that sets a flag on the request
|
# example of a middleware that sets a flag on the request
|
||||||
request.spiderweb = True
|
request.spiderweb = True
|
||||||
|
|
||||||
def process_response(
|
def process_response(self, request: Request, response: HttpResponse) -> None:
|
||||||
self, request: Request, response: HttpResponse
|
|
||||||
) -> None:
|
|
||||||
# example of a middleware that sets a header on the resp
|
# example of a middleware that sets a header on the resp
|
||||||
if hasattr(request, "spiderweb"):
|
if hasattr(request, "spiderweb"):
|
||||||
response.headers["X-Spiderweb"] = "true"
|
response.headers["X-Spiderweb"] = "true"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "spiderweb"
|
name = "spiderweb"
|
||||||
version = "0.5.0"
|
version = "0.6.0"
|
||||||
description = "A small web framework, just big enough to hold your average spider."
|
description = "A small web framework, just big enough to hold your average spider."
|
||||||
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
|
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -28,7 +28,9 @@ class BadRequest(SpiderwebNetworkException):
|
|||||||
def __init__(self, desc=None):
|
def __init__(self, desc=None):
|
||||||
self.code = 400
|
self.code = 400
|
||||||
self.msg = "Bad Request"
|
self.msg = "Bad Request"
|
||||||
self.desc = desc if desc else "The request could not be understood by the server"
|
self.desc = (
|
||||||
|
desc if desc else "The request could not be understood by the server"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Unauthorized(SpiderwebNetworkException):
|
class Unauthorized(SpiderwebNetworkException):
|
||||||
@ -59,7 +61,6 @@ class CSRFError(SpiderwebNetworkException):
|
|||||||
self.desc = desc if desc else "CSRF token is invalid"
|
self.desc = desc if desc else "CSRF token is invalid"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigError(SpiderwebException):
|
class ConfigError(SpiderwebException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -23,10 +23,18 @@ from spiderweb.exceptions import (
|
|||||||
ConfigError,
|
ConfigError,
|
||||||
ParseError,
|
ParseError,
|
||||||
GeneralException,
|
GeneralException,
|
||||||
NoResponseError, UnusedMiddleware, SpiderwebNetworkException, NotFound,
|
NoResponseError,
|
||||||
|
UnusedMiddleware,
|
||||||
|
SpiderwebNetworkException,
|
||||||
|
NotFound,
|
||||||
)
|
)
|
||||||
from spiderweb.request import Request
|
from spiderweb.request import Request
|
||||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
from spiderweb.response import (
|
||||||
|
HttpResponse,
|
||||||
|
JsonResponse,
|
||||||
|
TemplateResponse,
|
||||||
|
RedirectResponse,
|
||||||
|
)
|
||||||
from spiderweb.utils import import_by_string
|
from spiderweb.utils import import_by_string
|
||||||
|
|
||||||
|
|
||||||
@ -163,11 +171,20 @@ class WebServer(HTTPServer):
|
|||||||
updated_path = path + "/"
|
updated_path = path + "/"
|
||||||
self.check_for_route_duplicates(updated_path)
|
self.check_for_route_duplicates(updated_path)
|
||||||
self.check_for_route_duplicates(path)
|
self.check_for_route_duplicates(path)
|
||||||
self.handler_class._routes[self.convert_path(path)] = {'func': DummyRedirectRoute(updated_path), 'allowed_methods': allowed_methods}
|
self.handler_class._routes[self.convert_path(path)] = {
|
||||||
self.handler_class._routes[self.convert_path(updated_path)] = {'func': method, 'allowed_methods': allowed_methods}
|
"func": DummyRedirectRoute(updated_path),
|
||||||
|
"allowed_methods": allowed_methods,
|
||||||
|
}
|
||||||
|
self.handler_class._routes[self.convert_path(updated_path)] = {
|
||||||
|
"func": method,
|
||||||
|
"allowed_methods": allowed_methods,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
self.check_for_route_duplicates(path)
|
self.check_for_route_duplicates(path)
|
||||||
self.handler_class._routes[self.convert_path(path)] = {'func': method, 'allowed_methods': allowed_methods}
|
self.handler_class._routes[self.convert_path(path)] = {
|
||||||
|
"func": method,
|
||||||
|
"allowed_methods": allowed_methods,
|
||||||
|
}
|
||||||
|
|
||||||
def add_error_route(self, code: int, method: Callable):
|
def add_error_route(self, code: int, method: Callable):
|
||||||
"""Add an error route to the server."""
|
"""Add an error route to the server."""
|
||||||
@ -200,7 +217,7 @@ class WebServer(HTTPServer):
|
|||||||
self.add_route(
|
self.add_route(
|
||||||
path,
|
path,
|
||||||
func,
|
func,
|
||||||
allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS
|
allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS,
|
||||||
)
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@ -210,6 +227,7 @@ class WebServer(HTTPServer):
|
|||||||
def outer(func):
|
def outer(func):
|
||||||
self.add_error_route(code, func)
|
self.add_error_route(code, func)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return outer
|
return outer
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -230,7 +248,7 @@ class WebServer(HTTPServer):
|
|||||||
return self.__addr + ":" + str(self.port()) + "/" + path
|
return self.__addr + ":" + str(self.port()) + "/" + path
|
||||||
|
|
||||||
def signal_handler(self, sig, frame) -> NoReturn:
|
def signal_handler(self, sig, frame) -> NoReturn:
|
||||||
log.warning('Shutting down!')
|
log.warning("Shutting down!")
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
def start(self, blocking=False):
|
def start(self, blocking=False):
|
||||||
@ -261,7 +279,10 @@ class WebServer(HTTPServer):
|
|||||||
def decrypt(self, data: str):
|
def decrypt(self, data: str):
|
||||||
if isinstance(data, bytes):
|
if isinstance(data, bytes):
|
||||||
return self.fernet.decrypt(data).decode(DEFAULT_ENCODING)
|
return self.fernet.decrypt(data).decode(DEFAULT_ENCODING)
|
||||||
return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode(DEFAULT_ENCODING)
|
return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode(
|
||||||
|
DEFAULT_ENCODING
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RequestHandler(BaseHTTPRequestHandler):
|
class RequestHandler(BaseHTTPRequestHandler):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@ -301,9 +322,11 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||||||
def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]:
|
def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]:
|
||||||
for option in self._routes.keys():
|
for option in self._routes.keys():
|
||||||
if match_data := option.match(path):
|
if match_data := option.match(path):
|
||||||
return self._routes[option]['func'], convert_match_to_dict(
|
return (
|
||||||
match_data.groupdict()
|
self._routes[option]["func"],
|
||||||
), self._routes[option]['allowed_methods']
|
convert_match_to_dict(match_data.groupdict()),
|
||||||
|
self._routes[option]["allowed_methods"],
|
||||||
|
)
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
|
|
||||||
def get_error_route(self, code: int) -> Callable:
|
def get_error_route(self, code: int) -> Callable:
|
||||||
@ -345,7 +368,9 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||||||
self.fire_response(request, resp)
|
self.fire_response(request, resp)
|
||||||
return True # abort further processing
|
return True # abort further processing
|
||||||
|
|
||||||
def process_response_middleware(self, request: Request, response: HttpResponse) -> None:
|
def process_response_middleware(
|
||||||
|
self, request: Request, response: HttpResponse
|
||||||
|
) -> None:
|
||||||
for middleware in self.server.middleware:
|
for middleware in self.server.middleware:
|
||||||
try:
|
try:
|
||||||
middleware.process_response(request, response)
|
middleware.process_response(request, response)
|
||||||
@ -374,7 +399,10 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||||||
self.fire_response(request, self.get_error_route(500)(request))
|
self.fire_response(request, self.get_error_route(500)(request))
|
||||||
|
|
||||||
def is_form_request(self, request: Request) -> bool:
|
def is_form_request(self, request: Request) -> bool:
|
||||||
return "Content-Type" in request.headers and request.headers["Content-Type"] == "application/x-www-form-urlencoded"
|
return (
|
||||||
|
"Content-Type" in request.headers
|
||||||
|
and request.headers["Content-Type"] == "application/x-www-form-urlencoded"
|
||||||
|
)
|
||||||
|
|
||||||
def send_error_response(self, request: Request, e: SpiderwebNetworkException):
|
def send_error_response(self, request: Request, e: SpiderwebNetworkException):
|
||||||
try:
|
try:
|
||||||
@ -397,7 +425,9 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||||||
# replace the potentially valid handler with the error route
|
# replace the potentially valid handler with the error route
|
||||||
handler = self.get_error_route(405)
|
handler = self.get_error_route(405)
|
||||||
|
|
||||||
request.query_params = urlparse.parse_qs(request.url.query) if request.url.query else {}
|
request.query_params = (
|
||||||
|
urlparse.parse_qs(request.url.query) if request.url.query else {}
|
||||||
|
)
|
||||||
|
|
||||||
if self.is_form_request(request):
|
if self.is_form_request(request):
|
||||||
formdata = urlparse.parse_qs(request.content.decode("utf-8"))
|
formdata = urlparse.parse_qs(request.content.decode("utf-8"))
|
||||||
|
@ -17,6 +17,7 @@ class SpiderwebMiddleware:
|
|||||||
If `process_request` returns a HttpResponse, the request will be short-circuited
|
If `process_request` returns a HttpResponse, the request will be short-circuited
|
||||||
and the response will be returned immediately. `process_response` will not be called.
|
and the response will be returned immediately. `process_response` will not be called.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, server):
|
def __init__(self, server):
|
||||||
self.server = server
|
self.server = server
|
||||||
|
|
||||||
|
@ -11,7 +11,11 @@ class CSRFMiddleware(SpiderwebMiddleware):
|
|||||||
|
|
||||||
def process_request(self, request: Request) -> HttpResponse | None:
|
def process_request(self, request: Request) -> HttpResponse | None:
|
||||||
if request.method == "POST":
|
if request.method == "POST":
|
||||||
csrf_token = request.headers.get("X-CSRF-TOKEN") or request.GET.get("csrf_token") or request.POST.get("csrf_token")
|
csrf_token = (
|
||||||
|
request.headers.get("X-CSRF-TOKEN")
|
||||||
|
or request.GET.get("csrf_token")
|
||||||
|
or request.POST.get("csrf_token")
|
||||||
|
)
|
||||||
if self.is_csrf_valid(csrf_token):
|
if self.is_csrf_valid(csrf_token):
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
@ -25,12 +29,16 @@ class CSRFMiddleware(SpiderwebMiddleware):
|
|||||||
request.csrf_token = token
|
request.csrf_token = token
|
||||||
|
|
||||||
def get_csrf_token(self):
|
def get_csrf_token(self):
|
||||||
return self.server.encrypt(str(datetime.now().isoformat())).decode(self.server.DEFAULT_ENCODING)
|
return self.server.encrypt(str(datetime.now().isoformat())).decode(
|
||||||
|
self.server.DEFAULT_ENCODING
|
||||||
|
)
|
||||||
|
|
||||||
def is_csrf_valid(self, key):
|
def is_csrf_valid(self, key):
|
||||||
try:
|
try:
|
||||||
decoded = self.server.decrypt(key)
|
decoded = self.server.decrypt(key)
|
||||||
if datetime.now() - timedelta(seconds=self.CSRF_EXPIRY) > datetime.fromisoformat(decoded):
|
if datetime.now() - timedelta(
|
||||||
|
seconds=self.CSRF_EXPIRY
|
||||||
|
) > datetime.fromisoformat(decoded):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
def import_by_string(name):
|
def import_by_string(name):
|
||||||
# https://stackoverflow.com/a/547867
|
# https://stackoverflow.com/a/547867
|
||||||
components = name.split('.')
|
components = name.split(".")
|
||||||
mod = __import__(components[0])
|
mod = __import__(components[0])
|
||||||
for comp in components[1:]:
|
for comp in components[1:]:
|
||||||
mod = getattr(mod, comp)
|
mod = getattr(mod, comp)
|
||||||
|
Loading…
Reference in New Issue
Block a user