diff --git a/example.py b/example.py index 508d8d4..53f7a37 100644 --- a/example.py +++ b/example.py @@ -16,12 +16,14 @@ app = WebServer( "example_middleware.RedirectMiddleware", "example_middleware.ExplodingMiddleware", ], + staticfiles_dirs=["static_files"], append_slash=False, # default ) @app.route("/") def index(request): + print(app.BASE_DIR) return TemplateResponse(request, "test.html", context={"value": "TEST!"}) diff --git a/pyproject.toml b/pyproject.toml index 1687b6b..2c5a640 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "spiderweb" -version = "0.6.0" +version = "0.7.0" description = "A small web framework, just big enough to hold your average spider." authors = ["Joe Kaufeld "] readme = "README.md" diff --git a/spiderweb/constants.py b/spiderweb/constants.py new file mode 100644 index 0000000..8f1a223 --- /dev/null +++ b/spiderweb/constants.py @@ -0,0 +1,2 @@ +DEFAULT_ALLOWED_METHODS = ["GET"] +DEFAULT_ENCODING = "ISO-8859-1" diff --git a/spiderweb/exceptions.py b/spiderweb/exceptions.py index dd96ee0..3acc7f1 100644 --- a/spiderweb/exceptions.py +++ b/spiderweb/exceptions.py @@ -1,7 +1,7 @@ class SpiderwebException(Exception): # parent error class; all child exceptions should inherit from this def __str__(self): - return f"{self.__class__.__name__}({self.code}, {self.msg})" + return f"{self.__class__.__name__}()" class SpiderwebNetworkException(SpiderwebException): @@ -12,6 +12,8 @@ class SpiderwebNetworkException(SpiderwebException): self.msg = msg self.desc = desc + def __str__(self): + return f"{self.__class__.__name__}({self.code}, {self.msg})" class APIError(SpiderwebNetworkException): pass diff --git a/spiderweb/main.py b/spiderweb/main.py index e130686..94046ed 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -2,8 +2,10 @@ # https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55 # # Extensively modified by @itsthejoker -from datetime import datetime, timedelta +import inspect +import os import re +import pathlib import signal import time import traceback @@ -16,6 +18,7 @@ from typing import Callable, Any, NoReturn from cryptography.fernet import Fernet from jinja2 import Environment, FileSystemLoader +from spiderweb.constants import DEFAULT_ENCODING, DEFAULT_ALLOWED_METHODS from spiderweb.converters import * # noqa: F403 from spiderweb.default_responses import * # noqa: F403 from spiderweb.exceptions import ( @@ -33,17 +36,13 @@ from spiderweb.response import ( HttpResponse, JsonResponse, TemplateResponse, - RedirectResponse, + RedirectResponse, FileResponse, ) -from spiderweb.utils import import_by_string - +from spiderweb.utils import import_by_string, is_safe_path log = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) -DEFAULT_ALLOWED_METHODS = ["GET"] -DEFAULT_ENCODING = "utf-8" - def route(path): def outer(func): @@ -55,6 +54,16 @@ def route(path): return outer +def send_file(request, filename: str) -> HttpResponse: + for folder in request.server.staticfiles_dirs: + requested_path = request.server.BASE_DIR / folder / filename + if os.path.exists(requested_path): + if not is_safe_path(requested_path): + raise NotFound + return FileResponse(filename=requested_path) + raise NotFound + + class DummyRedirectRoute: def __init__(self, location): self.location = location @@ -80,6 +89,7 @@ class WebServer(HTTPServer): templates_dirs: list[str] = None, middleware: list[str] = None, append_slash: bool = False, + staticfiles_dirs: list[str] = None, secret_key: str = None, ): """ @@ -95,6 +105,7 @@ class WebServer(HTTPServer): port = port if port else 8000 self.append_slash = append_slash self.templates_dirs = templates_dirs + self.staticfiles_dirs = staticfiles_dirs self.middleware = middleware if middleware else [] self.secret_key = secret_key if secret_key else self._create_secret_key() self.fernet = Fernet(self.key) @@ -102,6 +113,8 @@ class WebServer(HTTPServer): self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS self._thread = None + self.BASE_DIR = self.get_caller_filepath() + if self.middleware: middleware_by_reference = [] for m in self.middleware: @@ -115,6 +128,7 @@ class WebServer(HTTPServer): self.env = Environment(loader=FileSystemLoader(self.templates_dirs)) else: self.env = None + server_address = (addr, port) self.__addr = addr @@ -131,11 +145,25 @@ class WebServer(HTTPServer): for route in method._routes: self.add_route(route, method) + if self.staticfiles_dirs: + for static_dir in self.staticfiles_dirs: + static_dir = pathlib.Path(static_dir) + if not pathlib.Path(self.BASE_DIR / static_dir).exists(): + log.error(f"Static files directory '{str(static_dir)}' does not exist.") + raise ConfigError + self.add_route(r"/static/", send_file) + try: super().__init__(server_address, self.handler_class) except OSError: raise GeneralException("Port already in use.") + def get_caller_filepath(self): + """Figure out who called us and return their path.""" + stack = inspect.stack() + caller_frame = stack[1] + return pathlib.Path(caller_frame.filename).parent.parent + def convert_path(self, path: str): """Convert a path to a regex.""" parts = path.split("/") @@ -162,11 +190,13 @@ class WebServer(HTTPServer): if self.convert_path(path) in self.handler_class._routes: raise ConfigError(f"Route '{path}' already exists.") - def add_route(self, path: str, method: Callable, allowed_methods: list[str]): + def add_route(self, path: str, method: Callable, allowed_methods: None|list[str] = None): """Add a route to the server.""" if not hasattr(self.handler_class, "_routes"): setattr(self.handler_class, "_routes", {}) + allowed_methods = allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS + if self.append_slash and not path.endswith("/"): updated_path = path + "/" self.check_for_route_duplicates(updated_path) @@ -300,6 +330,7 @@ class RequestHandler(BaseHTTPRequestHandler): method=self.command, headers=self.headers, path=self.path, + server=self.server, ) # I can't help the naming convention of these because that's what @@ -319,6 +350,8 @@ class RequestHandler(BaseHTTPRequestHandler): request.content = content self.handle_request(request) + + def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]: for option in self._routes.keys(): if match_data := option.match(path): @@ -335,19 +368,18 @@ class RequestHandler(BaseHTTPRequestHandler): return http500 return view - def _fire_response(self, resp: HttpResponse): - self.send_response(resp.status_code) - content = resp.render() + def _fire_response(self, status: int=200, content: str=None, headers: dict[str, str | int]=None): + self.send_response(status) self.send_header("Content-Length", str(len(content))) - if resp.headers: - for key, value in resp.headers.items(): + if headers: + for key, value in headers.items(): self.send_header(key, value) self.end_headers() self.wfile.write(bytes(content, DEFAULT_ENCODING)) def fire_response(self, request: Request, resp: HttpResponse): try: - self._fire_response(resp) + self._fire_response(status=resp.status_code, content=resp.render(), headers=resp.headers) except APIError: raise except ConnectionAbortedError as e: diff --git a/spiderweb/request.py b/spiderweb/request.py index 83e256f..44aec1d 100644 --- a/spiderweb/request.py +++ b/spiderweb/request.py @@ -11,6 +11,7 @@ class Request: path=None, url=None, query_params=None, + server=None ): self.content: str = content self.body: str = body @@ -19,6 +20,7 @@ class Request: self.path: str = path self.url = url self.query_params = query_params + self.server = server self.GET = {} self.POST = {} diff --git a/spiderweb/response.py b/spiderweb/response.py index e4e701e..01cf493 100644 --- a/spiderweb/response.py +++ b/spiderweb/response.py @@ -1,11 +1,16 @@ import datetime import json from typing import Any +import mimetypes +from spiderweb.constants import DEFAULT_ENCODING from spiderweb.exceptions import GeneralException from spiderweb.request import Request +mimetypes.init() + + class HttpResponse: def __init__( self, @@ -20,7 +25,8 @@ class HttpResponse: self.context = context if context else {} self.status_code = status_code self.headers = headers if headers else {} - self.headers["Content-Type"] = "text/html; charset=utf-8" + if not self.headers.get("Content-Type"): + self.headers["Content-Type"] = "text/html; charset=utf-8" self.headers["Server"] = "Spiderweb" self.headers["Date"] = datetime.datetime.now(tz=datetime.UTC).strftime( "%a, %d %b %Y %H:%M:%S GMT" @@ -33,6 +39,19 @@ class HttpResponse: return str(self.body) +class FileResponse(HttpResponse): + def __init__(self, filename, *args, **kwargs): + super().__init__(*args, **kwargs) + self.filename = filename + self.content_type = mimetypes.guess_type(self.filename)[0] + self.headers["Content-Type"] = self.content_type + + def render(self) -> str: + with open(self.filename, 'rb') as f: + self.body = f.read().decode(DEFAULT_ENCODING) + return self.body + + class JsonResponse(HttpResponse): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/spiderweb/utils.py b/spiderweb/utils.py index 0235b45..559b231 100644 --- a/spiderweb/utils.py +++ b/spiderweb/utils.py @@ -5,3 +5,8 @@ def import_by_string(name): for comp in components[1:]: mod = getattr(mod, comp) return mod + + +def is_safe_path(path: str) -> bool: + # this cannot possibly catch all issues + return not ".." in str(path) diff --git a/static_files/aaaaaa.gif b/static_files/aaaaaa.gif new file mode 100644 index 0000000..7e96c95 Binary files /dev/null and b/static_files/aaaaaa.gif differ diff --git a/templates/test.html b/templates/test.html index 1806f16..db08db0 100644 --- a/templates/test.html +++ b/templates/test.html @@ -12,4 +12,7 @@ The value of request.spiderweb is {{ request.spiderweb }}. If this is True, middleware is working.

+

+ AAAAAAAAAA +

{% endblock %}