✨ session middleware and databases!
This commit is contained in:
parent
173325731b
commit
9d4dffb358
10
example.py
10
example.py
@ -15,6 +15,7 @@ from spiderweb.response import (
|
||||
app = SpiderwebRouter(
|
||||
templates_dirs=["templates"],
|
||||
middleware=[
|
||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||
"example_middleware.TestMiddleware",
|
||||
"example_middleware.RedirectMiddleware",
|
||||
@ -72,6 +73,15 @@ def form(request: CommentForm):
|
||||
return TemplateResponse(request, "form.html")
|
||||
|
||||
|
||||
@app.route("/session")
|
||||
def session(request):
|
||||
if "test" not in request.SESSION:
|
||||
request.SESSION["test"] = 0
|
||||
else:
|
||||
request.SESSION["test"] += 1
|
||||
return HttpResponse(body=f"Session test: {request.SESSION['test']}")
|
||||
|
||||
|
||||
@app.route("/cookies")
|
||||
def cookies(request):
|
||||
print("request.COOKIES: ", request.COOKIES)
|
||||
|
2
poetry.lock
generated
2
poetry.lock
generated
@ -641,4 +641,4 @@ files = [
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "4fa8ab616be6891780300d4d66e9fb936aeac75bd3448e04073b2835acf9aadd"
|
||||
content-hash = "84633fc94c48c2a05b5ec77367ad29f327be1dc249a6e4cb76b50ebbe14739b5"
|
||||
|
@ -1,6 +1,10 @@
|
||||
from peewee import DatabaseProxy
|
||||
|
||||
DEFAULT_ALLOWED_METHODS = ["GET"]
|
||||
DEFAULT_ENCODING = "ISO-8859-1"
|
||||
__version__ = "0.10.0"
|
||||
|
||||
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
|
||||
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
|
||||
|
||||
DATABASE_PROXY = DatabaseProxy()
|
||||
|
@ -0,0 +1,98 @@
|
||||
from peewee import Model, Field, SchemaManager, DatabaseProxy
|
||||
|
||||
from spiderweb.constants import DATABASE_PROXY
|
||||
|
||||
|
||||
class MigrationsNeeded(ExceptionGroup): ...
|
||||
|
||||
|
||||
class MigrationRequired(Exception): ...
|
||||
|
||||
|
||||
class SpiderwebModel(Model):
|
||||
|
||||
@classmethod
|
||||
def check_for_needed_migration(cls):
|
||||
current_model_fields: dict[str, Field] = cls._meta.fields
|
||||
current_db_fields = {
|
||||
c.name: {
|
||||
"data_type": c.data_type,
|
||||
"null": c.null,
|
||||
"primary_key": c.primary_key,
|
||||
"default": c.default,
|
||||
}
|
||||
for c in cls._meta.database.get_columns(cls._meta.table_name)
|
||||
}
|
||||
problems = []
|
||||
s = SchemaManager(cls, cls._meta.database)
|
||||
ctx = s._create_context()
|
||||
for field_name, field_obj in current_model_fields.items():
|
||||
db_version = current_db_fields.get(field_obj.column_name)
|
||||
if not db_version:
|
||||
problems.append(
|
||||
MigrationRequired(f"Field {field_name} not found in DB.")
|
||||
)
|
||||
continue
|
||||
|
||||
if field_obj.field_type == "VARCHAR":
|
||||
field_obj.max_length = field_obj.max_length or 255
|
||||
if (
|
||||
cls._meta.fields[field_name].ddl_datatype(ctx).sql
|
||||
!= db_version["data_type"]
|
||||
):
|
||||
problems.append(
|
||||
MigrationRequired(
|
||||
f"CharField `{field_name}` has changed the field type."
|
||||
)
|
||||
)
|
||||
else:
|
||||
if (
|
||||
cls._meta.database.get_context_options()["field_types"][
|
||||
field_obj.field_type
|
||||
]
|
||||
!= db_version["data_type"]
|
||||
):
|
||||
problems.append(
|
||||
MigrationRequired(
|
||||
f"Field `{field_name}` has changed the field type."
|
||||
)
|
||||
)
|
||||
if field_obj.null != db_version["null"]:
|
||||
problems.append(
|
||||
MigrationRequired(
|
||||
f"Field `{field_name}` has changed the nullability."
|
||||
)
|
||||
)
|
||||
if field_obj.__class__.__name__ == "BooleanField":
|
||||
if field_obj.default == False and db_version["default"] not in (
|
||||
False,
|
||||
None,
|
||||
0,
|
||||
):
|
||||
problems.append(
|
||||
MigrationRequired(
|
||||
f"BooleanField `{field_name}` has changed the default value."
|
||||
)
|
||||
)
|
||||
elif field_obj.default == True and db_version["default"] not in (
|
||||
True,
|
||||
1,
|
||||
):
|
||||
problems.append(
|
||||
MigrationRequired(
|
||||
f"BooleanField `{field_name}` has changed the default value."
|
||||
)
|
||||
)
|
||||
else:
|
||||
if field_obj.default != db_version["default"]:
|
||||
problems.append(
|
||||
MigrationRequired(
|
||||
f"Field `{field_name}` has changed the default value."
|
||||
)
|
||||
)
|
||||
|
||||
if problems:
|
||||
raise MigrationsNeeded(f"The model {cls} requires migrations.", problems)
|
||||
|
||||
class Meta:
|
||||
database = DATABASE_PROXY
|
@ -1,8 +1,8 @@
|
||||
from pydantic import EmailStr
|
||||
|
||||
from spiderweb.middleware.pydantic import SpiderwebModel
|
||||
from spiderweb.middleware.pydantic import RequestModel
|
||||
|
||||
|
||||
class CommentForm(SpiderwebModel):
|
||||
class CommentForm(RequestModel):
|
||||
email: EmailStr
|
||||
comment: str
|
||||
|
@ -1,6 +1,10 @@
|
||||
class SpiderwebException(Exception):
|
||||
# parent error class; all child exceptions should inherit from this
|
||||
def __str__(self):
|
||||
name = self.__class__.__name__
|
||||
msg = self.args[0] if len(self.args) > 0 else ""
|
||||
if msg:
|
||||
return f"{name}() - {msg}"
|
||||
return f"{self.__class__.__name__}()"
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@ class SpiderwebRequestHandler(WSGIRequestHandler):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class LocalServerMiddleware:
|
||||
class LocalServerMixin:
|
||||
"""Cannot be called on its own. Requires context of SpiderwebRouter."""
|
||||
|
||||
addr: str
|
||||
|
@ -8,9 +8,15 @@ from typing import Optional, Callable
|
||||
from wsgiref.simple_server import WSGIServer
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from peewee import Database, SqliteDatabase
|
||||
|
||||
from spiderweb.middleware import MiddlewareMiddleware
|
||||
from spiderweb.constants import DEFAULT_ENCODING, DEFAULT_ALLOWED_METHODS
|
||||
from spiderweb.middleware import MiddlewareMixin
|
||||
from spiderweb.constants import (
|
||||
DATABASE_PROXY,
|
||||
DEFAULT_ENCODING,
|
||||
DEFAULT_ALLOWED_METHODS,
|
||||
)
|
||||
from spiderweb.db import SpiderwebModel
|
||||
from spiderweb.default_views import * # noqa: F403
|
||||
from spiderweb.exceptions import (
|
||||
ConfigError,
|
||||
@ -19,24 +25,23 @@ from spiderweb.exceptions import (
|
||||
NoResponseError,
|
||||
SpiderwebNetworkException,
|
||||
)
|
||||
from spiderweb.local_server import LocalServerMiddleware
|
||||
from spiderweb.local_server import LocalServerMixin
|
||||
from spiderweb.request import Request
|
||||
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
|
||||
from spiderweb.routes import RoutesMiddleware
|
||||
from spiderweb.secrets import FernetMiddleware
|
||||
from spiderweb.routes import RoutesMixin
|
||||
from spiderweb.secrets import FernetMixin
|
||||
from spiderweb.utils import get_http_status_by_code
|
||||
|
||||
file_logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class SpiderwebRouter(
|
||||
LocalServerMiddleware, MiddlewareMiddleware, RoutesMiddleware, FernetMiddleware
|
||||
):
|
||||
class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixin):
|
||||
def __init__(
|
||||
self,
|
||||
addr: str = None,
|
||||
port: int = None,
|
||||
db: Optional[Database] = None,
|
||||
templates_dirs: list[str] = None,
|
||||
middleware: list[str] = None,
|
||||
append_slash: bool = False,
|
||||
@ -44,6 +49,12 @@ class SpiderwebRouter(
|
||||
routes: list[list[str | Callable | dict]] = None,
|
||||
error_routes: dict[str, Callable] = None,
|
||||
secret_key: str = None,
|
||||
session_max_age=60 * 60 * 24 * 14, # 2 weeks
|
||||
session_cookie_name="swsession",
|
||||
session_cookie_secure=False, # should be true if serving over HTTPS
|
||||
session_cookie_http_only=True,
|
||||
session_cookie_same_site="lax",
|
||||
session_cookie_path="/",
|
||||
log=None,
|
||||
):
|
||||
self._routes = {}
|
||||
@ -59,9 +70,17 @@ class SpiderwebRouter(
|
||||
self.middleware = middleware if middleware else []
|
||||
self.secret_key = secret_key if secret_key else self.generate_key()
|
||||
|
||||
# session middleware
|
||||
self.session_max_age = session_max_age
|
||||
self.session_cookie_name = session_cookie_name
|
||||
self.session_cookie_secure = session_cookie_secure
|
||||
self.session_cookie_http_only = session_cookie_http_only
|
||||
self.session_cookie_same_site = session_cookie_same_site
|
||||
self.session_cookie_path = session_cookie_path
|
||||
|
||||
self.DEFAULT_ENCODING = DEFAULT_ENCODING
|
||||
self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS
|
||||
self.log = log if log else file_logger
|
||||
self.log: logging.Logger = log if log else file_logger
|
||||
|
||||
# for using .start() and .stop()
|
||||
self._thread: Optional[Thread] = None
|
||||
@ -71,6 +90,13 @@ class SpiderwebRouter(
|
||||
self.init_fernet()
|
||||
self.init_middleware()
|
||||
|
||||
self.db = db or SqliteDatabase(self.BASE_DIR / "spiderweb.db")
|
||||
# give the models the db connection
|
||||
DATABASE_PROXY.initialize(self.db)
|
||||
self.db.create_tables(SpiderwebModel.__subclasses__())
|
||||
for model in SpiderwebModel.__subclasses__():
|
||||
model.check_for_needed_migration()
|
||||
|
||||
if self.routes:
|
||||
self.add_routes()
|
||||
|
||||
|
@ -2,13 +2,14 @@ from typing import Callable, ClassVar
|
||||
|
||||
from .base import SpiderwebMiddleware as SpiderwebMiddleware
|
||||
from .csrf import CSRFMiddleware as CSRFMiddleware
|
||||
from .sessions import SessionMiddleware as SessionMiddleware
|
||||
from ..exceptions import ConfigError, UnusedMiddleware
|
||||
from ..request import Request
|
||||
from ..response import HttpResponse
|
||||
from ..utils import import_by_string
|
||||
|
||||
|
||||
class MiddlewareMiddleware:
|
||||
class MiddlewareMixin:
|
||||
"""Cannot be called on its own. Requires context of SpiderwebRouter."""
|
||||
|
||||
middleware: list[ClassVar]
|
||||
|
@ -7,7 +7,7 @@ from spiderweb.request import Request
|
||||
from spiderweb.response import JsonResponse
|
||||
|
||||
|
||||
class SpiderwebModel(BaseModel, Request):
|
||||
class RequestModel(BaseModel, Request):
|
||||
# type hinting shenanigans that allow us to annotate Request objects
|
||||
# with the pydantic models we want to validate them with, but doesn't
|
||||
# break the Request object's ability to be used as a Request object
|
||||
|
117
spiderweb/middleware/sessions.py
Normal file
117
spiderweb/middleware/sessions.py
Normal file
@ -0,0 +1,117 @@
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
|
||||
from peewee import CharField, TextField, DateTimeField, BooleanField
|
||||
|
||||
from spiderweb.middleware import SpiderwebMiddleware
|
||||
from spiderweb.request import Request
|
||||
from spiderweb.response import HttpResponse
|
||||
from spiderweb.db import SpiderwebModel
|
||||
from spiderweb.utils import generate_key, is_jsonable
|
||||
|
||||
|
||||
class Session(SpiderwebModel):
|
||||
session_key = CharField(max_length=64)
|
||||
user_id = CharField(max_length=64, null=True)
|
||||
is_authenticated = BooleanField(default=False)
|
||||
session_data = TextField()
|
||||
created_at = DateTimeField()
|
||||
last_active = DateTimeField()
|
||||
ip_address = CharField(max_length=30)
|
||||
user_agent = TextField()
|
||||
|
||||
|
||||
class SessionMiddleware(SpiderwebMiddleware):
|
||||
def process_request(self, request: Request):
|
||||
existing_session = (
|
||||
Session.select()
|
||||
.where(
|
||||
Session.session_key
|
||||
== request.COOKIES.get(self.server.session_cookie_name),
|
||||
Session.ip_address == request.META.get("client_address"),
|
||||
Session.user_agent == request.headers.get("HTTP_USER_AGENT"),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
new_session = False
|
||||
if not existing_session:
|
||||
new_session = True
|
||||
elif datetime.now() - existing_session.created_at > timedelta(
|
||||
seconds=self.server.session_max_age
|
||||
):
|
||||
existing_session.delete_instance()
|
||||
new_session = True
|
||||
|
||||
if new_session:
|
||||
request.SESSION = {}
|
||||
request._session["id"] = generate_key()
|
||||
request._session["new_session"] = True
|
||||
return
|
||||
|
||||
request.SESSION = json.loads(existing_session.session_data)
|
||||
request._session["id"] = existing_session.session_key
|
||||
existing_session.save()
|
||||
|
||||
def process_response(self, request: Request, response: HttpResponse):
|
||||
cookie_settings = {
|
||||
"max_age": self.server.session_max_age,
|
||||
"same_site": self.server.session_cookie_same_site,
|
||||
"http_only": self.server.session_cookie_http_only,
|
||||
"secure": self.server.session_cookie_secure
|
||||
or request.META.get("HTTPS", False),
|
||||
"path": self.server.session_cookie_path,
|
||||
}
|
||||
|
||||
# if a new session has been requested, ignore everything else and make that happen
|
||||
if request._session["new_session"]:
|
||||
# we generated a new one earlier, so we can use it now
|
||||
session_key = request._session["id"]
|
||||
response.set_cookie(
|
||||
self.server.session_cookie_name,
|
||||
session_key,
|
||||
**cookie_settings,
|
||||
)
|
||||
session = Session(
|
||||
session_key=session_key,
|
||||
session_data=json.dumps(request.SESSION),
|
||||
created_at=datetime.now(),
|
||||
last_active=datetime.now(),
|
||||
ip_address=request.META.get("client_address"),
|
||||
user_agent=request.headers.get("HTTP_USER_AGENT"),
|
||||
)
|
||||
session.save()
|
||||
return
|
||||
|
||||
# Otherwise, we can save the one we already have.
|
||||
session_key = request._session["id"]
|
||||
# update the session expiration time
|
||||
response.set_cookie(
|
||||
self.server.session_cookie_name,
|
||||
session_key,
|
||||
**cookie_settings,
|
||||
)
|
||||
|
||||
session = (
|
||||
Session.select()
|
||||
.where(
|
||||
Session.session_key == session_key,
|
||||
Session.ip_address == request.META.get("client_address"),
|
||||
Session.user_agent == request.headers.get("HTTP_USER_AGENT"),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not session:
|
||||
if not is_jsonable(request.SESSION):
|
||||
raise ValueError("Session data is not JSON serializable.")
|
||||
session = Session(
|
||||
session_key=session_key,
|
||||
session_data=json.dumps(request.SESSION),
|
||||
created_at=datetime.now(),
|
||||
last_active=datetime.now(),
|
||||
ip_address=request.META.get("client_address"),
|
||||
user_agent=request.META.get("HTTP_USER_AGENT"),
|
||||
)
|
||||
else:
|
||||
session.session_data = json.dumps(request.SESSION)
|
||||
session.last_active = datetime.now()
|
||||
session.save()
|
@ -2,6 +2,7 @@ import json
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from spiderweb.constants import DEFAULT_ENCODING
|
||||
from spiderweb.utils import get_client_address
|
||||
|
||||
|
||||
class Request:
|
||||
@ -27,6 +28,9 @@ class Request:
|
||||
self.POST = {}
|
||||
self.META = {}
|
||||
self.COOKIES = {}
|
||||
# only used for the session middleware
|
||||
self.SESSION = {}
|
||||
self._session: dict = {"new_session": False, "id": None}
|
||||
# only used for the pydantic middleware and only on POST requests
|
||||
self.validated_data = {}
|
||||
|
||||
@ -50,6 +54,8 @@ class Request:
|
||||
self.headers[k] = v
|
||||
|
||||
def populate_meta(self) -> None:
|
||||
# all caps fields are from WSGI, lowercase names
|
||||
# are custom
|
||||
fields = [
|
||||
"SERVER_PROTOCOL",
|
||||
"SERVER_SOFTWARE",
|
||||
@ -66,6 +72,7 @@ class Request:
|
||||
]
|
||||
for f in fields:
|
||||
self.META[f] = self.environ.get(f)
|
||||
self.META["client_address"] = get_client_address(self.environ)
|
||||
|
||||
def populate_cookies(self) -> None:
|
||||
if cookies := self.environ.get("HTTP_COOKIE"):
|
||||
|
@ -24,7 +24,7 @@ class DummyRedirectRoute:
|
||||
return RedirectResponse(self.location)
|
||||
|
||||
|
||||
class RoutesMiddleware:
|
||||
class RoutesMixin:
|
||||
"""Cannot be called on its own. Requires context of SpiderwebRouter."""
|
||||
|
||||
# ones that start with underscores are the compiled versions, non-underscores
|
||||
|
@ -3,7 +3,7 @@ from cryptography.fernet import Fernet
|
||||
from spiderweb.constants import DEFAULT_ENCODING
|
||||
|
||||
|
||||
class FernetMiddleware:
|
||||
class FernetMixin:
|
||||
"""Cannot be called on its own. Requires context of SpiderwebRouter."""
|
||||
|
||||
fernet: Fernet
|
||||
|
@ -1,9 +1,16 @@
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from spiderweb.request import Request
|
||||
|
||||
|
||||
VALID_CHARS = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def import_by_string(name):
|
||||
# https://stackoverflow.com/a/547867
|
||||
components = name.split(".")
|
||||
@ -31,8 +38,28 @@ def get_http_status_by_code(code: int) -> Optional[str]:
|
||||
return f"{resp.value} {resp.phrase}"
|
||||
|
||||
|
||||
def is_form_request(request: Request) -> bool:
|
||||
def is_form_request(request: "Request") -> bool:
|
||||
return (
|
||||
"Content-Type" in request.headers
|
||||
and request.headers["Content-Type"] == "application/x-www-form-urlencoded"
|
||||
)
|
||||
|
||||
|
||||
# https://stackoverflow.com/a/7839576
|
||||
def get_client_address(environ: dict) -> str:
|
||||
try:
|
||||
return environ["HTTP_X_FORWARDED_FOR"].split(",")[-1].strip()
|
||||
except KeyError:
|
||||
return environ.get("REMOTE_ADDR", "unknown")
|
||||
|
||||
|
||||
def generate_key(length=64):
|
||||
return "".join(secrets.choice(VALID_CHARS) for _ in range(length))
|
||||
|
||||
|
||||
def is_jsonable(data: str) -> bool:
|
||||
try:
|
||||
json.dumps(data)
|
||||
return True
|
||||
except (TypeError, OverflowError):
|
||||
return False
|
||||
|
Loading…
Reference in New Issue
Block a user