diff --git a/example.py b/example.py
index 508d8d4..53f7a37 100644
--- a/example.py
+++ b/example.py
@@ -16,12 +16,14 @@ app = WebServer(
"example_middleware.RedirectMiddleware",
"example_middleware.ExplodingMiddleware",
],
+ staticfiles_dirs=["static_files"],
append_slash=False, # default
)
@app.route("/")
def index(request):
+ print(app.BASE_DIR)
return TemplateResponse(request, "test.html", context={"value": "TEST!"})
diff --git a/pyproject.toml b/pyproject.toml
index 1687b6b..2c5a640 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "spiderweb"
-version = "0.6.0"
+version = "0.7.0"
description = "A small web framework, just big enough to hold your average spider."
authors = ["Joe Kaufeld "]
readme = "README.md"
diff --git a/spiderweb/constants.py b/spiderweb/constants.py
new file mode 100644
index 0000000..8f1a223
--- /dev/null
+++ b/spiderweb/constants.py
@@ -0,0 +1,2 @@
+DEFAULT_ALLOWED_METHODS = ["GET"]
+DEFAULT_ENCODING = "ISO-8859-1"
diff --git a/spiderweb/exceptions.py b/spiderweb/exceptions.py
index dd96ee0..3acc7f1 100644
--- a/spiderweb/exceptions.py
+++ b/spiderweb/exceptions.py
@@ -1,7 +1,7 @@
class SpiderwebException(Exception):
# parent error class; all child exceptions should inherit from this
def __str__(self):
- return f"{self.__class__.__name__}({self.code}, {self.msg})"
+ return f"{self.__class__.__name__}()"
class SpiderwebNetworkException(SpiderwebException):
@@ -12,6 +12,8 @@ class SpiderwebNetworkException(SpiderwebException):
self.msg = msg
self.desc = desc
+ def __str__(self):
+ return f"{self.__class__.__name__}({self.code}, {self.msg})"
class APIError(SpiderwebNetworkException):
pass
diff --git a/spiderweb/main.py b/spiderweb/main.py
index e130686..94046ed 100644
--- a/spiderweb/main.py
+++ b/spiderweb/main.py
@@ -2,8 +2,10 @@
# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55
#
# Extensively modified by @itsthejoker
-from datetime import datetime, timedelta
+import inspect
+import os
import re
+import pathlib
import signal
import time
import traceback
@@ -16,6 +18,7 @@ from typing import Callable, Any, NoReturn
from cryptography.fernet import Fernet
from jinja2 import Environment, FileSystemLoader
+from spiderweb.constants import DEFAULT_ENCODING, DEFAULT_ALLOWED_METHODS
from spiderweb.converters import * # noqa: F403
from spiderweb.default_responses import * # noqa: F403
from spiderweb.exceptions import (
@@ -33,17 +36,13 @@ from spiderweb.response import (
HttpResponse,
JsonResponse,
TemplateResponse,
- RedirectResponse,
+ RedirectResponse, FileResponse,
)
-from spiderweb.utils import import_by_string
-
+from spiderweb.utils import import_by_string, is_safe_path
log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
-DEFAULT_ALLOWED_METHODS = ["GET"]
-DEFAULT_ENCODING = "utf-8"
-
def route(path):
def outer(func):
@@ -55,6 +54,16 @@ def route(path):
return outer
+def send_file(request, filename: str) -> HttpResponse:
+ for folder in request.server.staticfiles_dirs:
+ requested_path = request.server.BASE_DIR / folder / filename
+ if os.path.exists(requested_path):
+ if not is_safe_path(requested_path):
+ raise NotFound
+ return FileResponse(filename=requested_path)
+ raise NotFound
+
+
class DummyRedirectRoute:
def __init__(self, location):
self.location = location
@@ -80,6 +89,7 @@ class WebServer(HTTPServer):
templates_dirs: list[str] = None,
middleware: list[str] = None,
append_slash: bool = False,
+ staticfiles_dirs: list[str] = None,
secret_key: str = None,
):
"""
@@ -95,6 +105,7 @@ class WebServer(HTTPServer):
port = port if port else 8000
self.append_slash = append_slash
self.templates_dirs = templates_dirs
+ self.staticfiles_dirs = staticfiles_dirs
self.middleware = middleware if middleware else []
self.secret_key = secret_key if secret_key else self._create_secret_key()
self.fernet = Fernet(self.key)
@@ -102,6 +113,8 @@ class WebServer(HTTPServer):
self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS
self._thread = None
+ self.BASE_DIR = self.get_caller_filepath()
+
if self.middleware:
middleware_by_reference = []
for m in self.middleware:
@@ -115,6 +128,7 @@ class WebServer(HTTPServer):
self.env = Environment(loader=FileSystemLoader(self.templates_dirs))
else:
self.env = None
+
server_address = (addr, port)
self.__addr = addr
@@ -131,11 +145,25 @@ class WebServer(HTTPServer):
for route in method._routes:
self.add_route(route, method)
+ if self.staticfiles_dirs:
+ for static_dir in self.staticfiles_dirs:
+ static_dir = pathlib.Path(static_dir)
+ if not pathlib.Path(self.BASE_DIR / static_dir).exists():
+ log.error(f"Static files directory '{str(static_dir)}' does not exist.")
+ raise ConfigError
+ self.add_route(r"/static/", send_file)
+
try:
super().__init__(server_address, self.handler_class)
except OSError:
raise GeneralException("Port already in use.")
+ def get_caller_filepath(self):
+ """Figure out who called us and return their path."""
+ stack = inspect.stack()
+ caller_frame = stack[1]
+ return pathlib.Path(caller_frame.filename).parent.parent
+
def convert_path(self, path: str):
"""Convert a path to a regex."""
parts = path.split("/")
@@ -162,11 +190,13 @@ class WebServer(HTTPServer):
if self.convert_path(path) in self.handler_class._routes:
raise ConfigError(f"Route '{path}' already exists.")
- def add_route(self, path: str, method: Callable, allowed_methods: list[str]):
+ def add_route(self, path: str, method: Callable, allowed_methods: None|list[str] = None):
"""Add a route to the server."""
if not hasattr(self.handler_class, "_routes"):
setattr(self.handler_class, "_routes", {})
+ allowed_methods = allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS
+
if self.append_slash and not path.endswith("/"):
updated_path = path + "/"
self.check_for_route_duplicates(updated_path)
@@ -300,6 +330,7 @@ class RequestHandler(BaseHTTPRequestHandler):
method=self.command,
headers=self.headers,
path=self.path,
+ server=self.server,
)
# I can't help the naming convention of these because that's what
@@ -319,6 +350,8 @@ class RequestHandler(BaseHTTPRequestHandler):
request.content = content
self.handle_request(request)
+
+
def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]:
for option in self._routes.keys():
if match_data := option.match(path):
@@ -335,19 +368,18 @@ class RequestHandler(BaseHTTPRequestHandler):
return http500
return view
- def _fire_response(self, resp: HttpResponse):
- self.send_response(resp.status_code)
- content = resp.render()
+ def _fire_response(self, status: int=200, content: str=None, headers: dict[str, str | int]=None):
+ self.send_response(status)
self.send_header("Content-Length", str(len(content)))
- if resp.headers:
- for key, value in resp.headers.items():
+ if headers:
+ for key, value in headers.items():
self.send_header(key, value)
self.end_headers()
self.wfile.write(bytes(content, DEFAULT_ENCODING))
def fire_response(self, request: Request, resp: HttpResponse):
try:
- self._fire_response(resp)
+ self._fire_response(status=resp.status_code, content=resp.render(), headers=resp.headers)
except APIError:
raise
except ConnectionAbortedError as e:
diff --git a/spiderweb/request.py b/spiderweb/request.py
index 83e256f..44aec1d 100644
--- a/spiderweb/request.py
+++ b/spiderweb/request.py
@@ -11,6 +11,7 @@ class Request:
path=None,
url=None,
query_params=None,
+ server=None
):
self.content: str = content
self.body: str = body
@@ -19,6 +20,7 @@ class Request:
self.path: str = path
self.url = url
self.query_params = query_params
+ self.server = server
self.GET = {}
self.POST = {}
diff --git a/spiderweb/response.py b/spiderweb/response.py
index e4e701e..01cf493 100644
--- a/spiderweb/response.py
+++ b/spiderweb/response.py
@@ -1,11 +1,16 @@
import datetime
import json
from typing import Any
+import mimetypes
+from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.exceptions import GeneralException
from spiderweb.request import Request
+mimetypes.init()
+
+
class HttpResponse:
def __init__(
self,
@@ -20,7 +25,8 @@ class HttpResponse:
self.context = context if context else {}
self.status_code = status_code
self.headers = headers if headers else {}
- self.headers["Content-Type"] = "text/html; charset=utf-8"
+ if not self.headers.get("Content-Type"):
+ self.headers["Content-Type"] = "text/html; charset=utf-8"
self.headers["Server"] = "Spiderweb"
self.headers["Date"] = datetime.datetime.now(tz=datetime.UTC).strftime(
"%a, %d %b %Y %H:%M:%S GMT"
@@ -33,6 +39,19 @@ class HttpResponse:
return str(self.body)
+class FileResponse(HttpResponse):
+ def __init__(self, filename, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.filename = filename
+ self.content_type = mimetypes.guess_type(self.filename)[0]
+ self.headers["Content-Type"] = self.content_type
+
+ def render(self) -> str:
+ with open(self.filename, 'rb') as f:
+ self.body = f.read().decode(DEFAULT_ENCODING)
+ return self.body
+
+
class JsonResponse(HttpResponse):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/spiderweb/utils.py b/spiderweb/utils.py
index 0235b45..559b231 100644
--- a/spiderweb/utils.py
+++ b/spiderweb/utils.py
@@ -5,3 +5,8 @@ def import_by_string(name):
for comp in components[1:]:
mod = getattr(mod, comp)
return mod
+
+
+def is_safe_path(path: str) -> bool:
+ # this cannot possibly catch all issues
+ return not ".." in str(path)
diff --git a/static_files/aaaaaa.gif b/static_files/aaaaaa.gif
new file mode 100644
index 0000000..7e96c95
Binary files /dev/null and b/static_files/aaaaaa.gif differ
diff --git a/templates/test.html b/templates/test.html
index 1806f16..db08db0 100644
--- a/templates/test.html
+++ b/templates/test.html
@@ -12,4 +12,7 @@
The value of request.spiderweb
is {{ request.spiderweb }}. If this is True,
middleware is working.
+
+
+
{% endblock %}