CORS! #1

Merged
jkaufeld merged 9 commits from origins into main 2024-09-02 00:39:35 -04:00
26 changed files with 745 additions and 84 deletions

View File

@ -19,3 +19,30 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
---
Substantial portions of spiderweb/middleware/cors.py and docs/middleware/cors.md
are from django-cors-headers and are subject to the following license:
MIT License
Copyright (c) 2017 Otto Yiu (https://ottoyiu.com) and other contributors.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,5 +1,22 @@
# spiderweb # spiderweb
<p align="center">
<img
src="https://img.shields.io/pypi/v/spiderweb-framework.svg?style=for-the-badge"
alt="PyPI release version for Spiderweb"
/>
<a href="https://gitmoji.dev">
<img
src="https://img.shields.io/badge/gitmoji-%20😜%20😍-FFDD67.svg?style=for-the-badge"
alt="Gitmoji"
/>
</a>
<img
src="https://img.shields.io/badge/code%20style-black-000000.svg?style=for-the-badge"
alt="Code style: Black"
/>
</p>
As a professional web developer focusing on arcane uses of Django for arcane purposes, it occurred to me a little while ago that I didn't actually know how a web framework _worked_. As a professional web developer focusing on arcane uses of Django for arcane purposes, it occurred to me a little while ago that I didn't actually know how a web framework _worked_.
So I built one. So I built one.

View File

@ -8,13 +8,15 @@ This is `spiderweb`, a WSGI-compatible web framework that's just big enough to h
- Learn a lot - Learn a lot
- Create an unholy blend of Django and Flask - Create an unholy blend of Django and Flask
- Not look at any existing code. Go off of vibes alone and try to solve all the problems I could think of in my own way - Not look at any existing code[^1]. Go off of vibes alone and try to solve all the problems I could think of in my own way
> [!WARNING] > [!WARNING]
> This is a learning project. It should not be used for production without heavy auditing. It's not secure. It's not fast. It's not well-tested. It's not well-documented. It's not well-anything. It's a learning project. > This is a learning project. It should not be used for production without heavy auditing. It's not secure. It's not fast. It's not well-tested. It's not well-documented. It's not well-anything. It's a learning project.
> >
> That being said, it's fun and it works, so I'm counting that as a win. > That being said, it's fun and it works, so I'm counting that as a win.
> [!TIP|style:flat]
> To jump in with both feet, [head over to the quickstart!](quickstart.md)
## Design & Usage Decisions ## Design & Usage Decisions
@ -90,6 +92,7 @@ Simply having these declared in a place that Django can find them is enough, and
Spiderweb takes a middle ground approach: it allows you to declare framework-first arguments on the SpiderwebRouter object, and if you need to pass along other data to other parts of the system (like custom middleware), you can do so by passing in any keyword argument you'd like to the constructor. Spiderweb takes a middle ground approach: it allows you to declare framework-first arguments on the SpiderwebRouter object, and if you need to pass along other data to other parts of the system (like custom middleware), you can do so by passing in any keyword argument you'd like to the constructor.
```python ```python
from spiderweb import SpiderwebRouter
from peewee import SqliteDatabase from peewee import SqliteDatabase
app = SpiderwebRouter( app = SpiderwebRouter(
@ -112,7 +115,6 @@ Here's a non-exhaustive list of things this can do:
- URLs with variables in them a lá Django - URLs with variables in them a lá Django
- Full middleware implementation - Full middleware implementation
- Limit routes by HTTP verbs - Limit routes by HTTP verbs
- (Only GET and POST are implemented right now)
- Custom error routes - Custom error routes
- Built-in dev server - Built-in dev server
- Gunicorn support - Gunicorn support
@ -120,13 +122,11 @@ Here's a non-exhaustive list of things this can do:
- Static files support - Static files support
- Cookies (reading and setting) - Cookies (reading and setting)
- Optional append_slash (with automatic redirects!) - Optional append_slash (with automatic redirects!)
- ~~CSRF middleware implementation~~ (it's there, but it's crappy and unsafe. This might be beyond my skillset.) - CSRF middleware
- CORS middleware
- Optional POST data validation middleware with Pydantic - Optional POST data validation middleware with Pydantic
- Database support (using Peewee, but you can use whatever you want as long as there's a Peewee driver for it)
- Session middleware with built-in session store - Session middleware with built-in session store
- Database support (using Peewee, but you can use whatever you want as long as there's a Peewee driver for it)
- Tests (currently a little over 80% coverage) - Tests (currently a little over 80% coverage)
## What's left to build? [^1]: I mostly succeeded. The way that I'm approaching this is that I did my level best, then looked at (and copied) existing solutions where necessary. At the time of this writing, I did all of it solo except for the CORS middleware. [Read more about it here.](middleware/cors.md)
- Fix CSRF middleware
- Add more HTTP verbs

View File

@ -3,7 +3,7 @@
> the web framework just big enough for a spider > the web framework just big enough for a spider
[GitHub](https://github.com/itsthejoker/spiderweb/) [GitHub](https://github.com/itsthejoker/spiderweb/)
[Get Started](#spiderweb) [Get Started](/README)
![color](#222) ![color](#222)

View File

@ -5,5 +5,6 @@
- [overview](middleware/overview.md) - [overview](middleware/overview.md)
- [session](middleware/sessions.md) - [session](middleware/sessions.md)
- [csrf](middleware/csrf.md) - [csrf](middleware/csrf.md)
- [cors](middleware/cors.md)
- [pydantic](middleware/pydantic.md) - [pydantic](middleware/pydantic.md)
- [writing your own](middleware/custom_middleware.md) - [writing your own](middleware/custom_middleware.md)

View File

@ -10,3 +10,10 @@
> [!NOTE] > [!NOTE]
> An alert of type 'note' using global style 'callout'. > An alert of type 'note' using global style 'callout'.
> [!TIP|style:flat|label:My own heading|iconVisibility:hidden]
> An alert of type 'tip' using alert specific style 'flat' which overrides global style 'callout'.
> In addition, this alert uses an own heading and hides specific icon.
> [!NOTE|icon:fa-solid fa-notes]
> A custom icon!

View File

@ -48,6 +48,7 @@
</script> </script>
<!-- Docsify v4 --> <!-- Docsify v4 -->
<script src="//cdn.jsdelivr.net/npm/docsify@4"></script> <script src="//cdn.jsdelivr.net/npm/docsify@4"></script>
<script src="//cdn.jsdelivr.net/npm/docsify/lib/plugins/external-script.min.js"></script>
<script src="//cdn.jsdelivr.net/npm/prismjs@1/components/prism-python.min.js"></script> <script src="//cdn.jsdelivr.net/npm/prismjs@1/components/prism-python.min.js"></script>
<!-- admonitions --> <!-- admonitions -->
<script src="https://unpkg.com/docsify-plugin-flexible-alerts"></script> <script src="https://unpkg.com/docsify-plugin-flexible-alerts"></script>
@ -57,5 +58,11 @@
<script src="https://cdn.jsdelivr.net/npm/docsify-tabs@1"></script> <script src="https://cdn.jsdelivr.net/npm/docsify-tabs@1"></script>
<!-- search --> <!-- search -->
<script src="//cdn.jsdelivr.net/npm/docsify/lib/plugins/search.min.js"></script> <script src="//cdn.jsdelivr.net/npm/docsify/lib/plugins/search.min.js"></script>
<!-- footnotes -->
<script src="//cdn.jsdelivr.net/npm/@sy-records/docsify-footnotes/lib/index.min.js"></script>
<!-- click to copy in code blocks -->
<script src="//cdn.jsdelivr.net/npm/docsify-copy-code/dist/docsify-copy-code.min.js"></script>
<script src="https://kit.fontawesome.com/940400877f.js" crossorigin="anonymous"></script>
<script defer data-domain="itsthejoker.github.io/spiderweb" src="https://plausible.io/js/script.js"></script>
</body> </body>
</html> </html>

177
docs/middleware/cors.md Normal file
View File

@ -0,0 +1,177 @@
# cors middleware
```python
from spiderweb import SpiderwebRouter
app = SpiderwebRouter(
middleware=["spiderweb.middleware.cors.CorsMiddleware"],
)
```
CORS, or Cross-Origin Resource Sharing, is an incredibly important piece of how different parts of the web communicate. As such, there is a CORS handler built into Spiderweb.
> [!TIP]
> The CorsMiddleware should be placed as high as possible in the middleware list, as it needs as much control as possible over requests and responses.
This implementation is lovingly ~~ripped~~ ~~lifted~~ borrowed from [Django CORS Headers](https://github.com/adamchainz/django-cors-headers/), an industry-standard implementation for handing CORS that has existed for over a decade. It is essentially and functionally the same. The below doc is ~~copy-and-pasted~~ also borrowed from Django CORS Headers, with updates where needed. (They just already do a great job of explaining these things.)
The available configurations are listed below, and you must set at least one of three following settings:
- `cors_allowed_origins`
- `cors_allowed_origin_regexes`
- `cors_allow_all_origins`
## cors_allowed_origins
A list of origins that are authorized to make cross-site HTTP requests. The origins in this setting will be allowed, and the requesting origin will be echoed back to the client in the access-control-allow-origin header. Defaults to `[]`.
An Origin is defined as a URI scheme + hostname + port, or one of the special values 'null' or 'file://'. Default ports (HTTPS = 443, HTTP = 80) are optional.
```python
app = SpiderwebRouter(
cors_allowed_origins=[
"https://example.com",
"https://sub.example.com",
"http://localhost:8080",
"http://127.0.0.1:9000",
]
)
```
## cors_allowed_origin_regexes
A list of strings representing regexes that match Origins that are authorized to make cross-site HTTP requests. Defaults to `[]`. Useful when `cors_allowed_origins` is impractical, such as when you have a large number of subdomains.
```python
app = SpiderwebRouter(
cors_allowed_origin_regexes=[
r"^https://\w+\.example\.com$",
]
)
```
## cors_allow_all_origins
If `True`, all origins will be allowed. Other settings restricting allowed origins will be ignored. Defaults to `False`.
Setting this to `True` can be _dangerous_, as it allows any website to make cross-origin requests to yours. Generally you'll want to restrict the list of allowed origins with `cors_allowed_origins` or `cors_allowed_origin_regexes`.
```python
app = SpiderwebRouter(
cors_allow_all_origins=True
)
```
# Optional settings
All the following settings have sensible defaults, but are available if you want to tweak them for your use case. For most cases, you'll just want to leave these alone.
## cors_urls_regex
A regex which restricts the URL's for which the CORS headers will be sent. Defaults to `r'^.*$'`, i.e. match all URL's. Useful when you only need CORS on a part of your site, e.g. an API at /api/.
```python
app = SpiderwebRouter(
cors_urls_regex=r"^/api/.*$"
)
```
## cors_allow_methods
A list of HTTP verbs that are allowed for the actual request. Defaults to:
```python
DEFAULT_CORS_ALLOW_METHODS = (
"DELETE",
"GET",
"OPTIONS",
"PATCH",
"POST",
"PUT",
)
```
The default can be imported from `spiderweb.constants` so you can just extend it with custom methods. This allows you to keep up to date with any future changes. For example:
```python
from spiderweb.constants import DEFAULT_CORS_ALLOW_METHODS as default_methods
app = SpiderwebRouter(
cors_allow_methods=(
*default_methods,
"POKE",
)
)
```
## cors_allow_headers
The list of non-standard HTTP headers that you permit in requests from the browser. Sets the `Access-Control-Allow-Headers` header in responses to preflight requests. Defaults to:
```python
CORS_ALLOW_HEADERS = (
"accept",
"authorization",
"content-type",
"user-agent",
"x-csrftoken",
"x-requested-with",
)
```
The default can be imported from `spiderweb.constants` so you can extend it with your custom headers. This allows you to keep up to date with any future changes. For example:
```python
from spiderweb.constants import DEFAULT_CORS_ALLOW_HEADERS as default_headers
app = SpiderwebRouter(
cors_allow_headers=(
*default_headers,
"my-custom-header",
)
)
```
## cors_expose_headers
The list of extra HTTP headers to expose to the browser, in addition to the default [safelisted headers](https://developer.mozilla.org/en-US/docs/Glossary/CORS-safelisted_response_header). If non-empty, these are declared in the [`access-control-expose-headers` header](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers). Defaults to `[]`.
## cors_preflight_max_age
The number of seconds (integer) the browser can cache the preflight response. This sets the [`access-control-max-age` header](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age) in preflight responses. If this is 0 (or any falsey value), no max age header will be sent. Defaults to `86400` (one day).
Note: Browsers send [preflight requests](https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request) before certain “non-simple” requests, to check they will be allowed. Read more about it in the [CORS MDN article](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#preflighted_requests).
## cors_allow_credentials
If `True`, cookies will be allowed to be included in cross-site HTTP requests. This sets the [`Access-Control-Allow-Credentials` header](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/access-control-allow-credentials) in preflight and normal responses. Defaults to `False`.
> [!NOTE]
> The session cookie, by default, uses `Lax` as the security setting, which will prevent the session cookie from being sent cross-domain. If you want to use `cors_allow_credentials`, you will need to change `session_cookie_same_site` to `none` to bypass the security restriction.
## cors_allow_private_network
If `True`, allow requests from sites on “public” IP to this server on a “private” IP. In such cases, browsers send an extra CORS header `access-control-request-private-network`, for which `OPTIONS` responses must contain `access-control-allow-private-network: true`. Defaults to `False`.
Refer to:
- [Local Network Access](https://wicg.github.io/local-network-access/), the W3C Community Draft specification.
- [Private Network Access: introducing preflights](https://developer.chrome.com/blog/private-network-access-preflight/), a blog post from the Google Chrome team.
# A note about CSRF
Most sites will need to take advantage of the Cross-Site Request Forgery protection built into Spiderweb. CORS and CSRF are separate, and Spiderweb wants you to be explicit about how the domains that you work with fit together. If you need to exempt sites from the [`Referer`](https://en.wikipedia.org/wiki/HTTP_referer#Etymology) checking that Spiderweb performs does on secure requests, you can use the `csrf_trusted_origins` setting. For example:
```python
from spiderweb.constants import DEFAULT_CORS_ALLOW_HEADERS as default_headers
app = SpiderwebRouter(
cors_allowed_origins=[
"https://read-only.example.com",
"https://read-and-write.example.com",
],
csrf_trusted_origins=[
"https://read-and-write.example.com",
]
)
```

View File

@ -11,9 +11,6 @@ app = SpiderwebRouter(
) )
``` ```
> [!DANGER]
> The CSRFMiddleware is incomplete at best and dangerous at worst. I am not a security expert, and my implementation is [very susceptible to the thing it is meant to prevent](https://en.wikipedia.org/wiki/Cross-site_request_forgery). While this is an big issue (and moderately hilarious), the middleware is still provided to you in its unfinished state. Be aware.
Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more. Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more.
> [!TIP] > [!TIP]

View File

@ -1,5 +1,3 @@
from spiderweb import HttpResponse
# writing your own middleware # writing your own middleware
Sometimes you want to run the same code on every request or every response (or both!). Lots of processing happens in the middleware layer, and if you want to write your own, all you have to do is write a quick class and put it in a place that Spiderweb can find it. A piece of middleware only needs two things to be successful: Sometimes you want to run the same code on every request or every response (or both!). Lots of processing happens in the middleware layer, and if you want to write your own, all you have to do is write a quick class and put it in a place that Spiderweb can find it. A piece of middleware only needs two things to be successful:
@ -57,6 +55,54 @@ Unlike `process_request`, returning a value here doesn't change anything. We're
This is a helper function that is available for you to override; it's not often used by middleware, but there are some ([like the pydantic middleware](pydantic.md)) that call `on_error` when there is a validation failure. This is a helper function that is available for you to override; it's not often used by middleware, but there are some ([like the pydantic middleware](pydantic.md)) that call `on_error` when there is a validation failure.
## checks
If you want to have runtime verifications that ensure that everything is running smoothly, you can take advantage of Spiderweb's `checks` feature.
> [!TIP]
> If you just want to run startup checks, you can also tie this in with the `UnusedMiddleware` exception, as it'll trigger after the checks run.
A startup check looks like this:
```python
from spiderweb.exceptions import ConfigError
from spiderweb.server_checks import ServerCheck
class MyCheck(ServerCheck):
# You don't have to extract the message out into a top-level
# variable, but it does make testing your middleware easier.
MYMESSAGE = "Something has gone wrong!"
# The function must be called `check` and it takes no args.
def check(self):
if self.server.extra_args.get("mykeyword") != "propervalue":
# Note that we are returning an exception instead of
# raising it. All config errors are collected and then
# raised as a single group of all the errors that
# happened on startup.
# If everything looks good, don't return anything.
return ConfigError(self.MYMESSAGE)
```
> [!TIP]
> You should have one check class per actual check that you want to run, as it will make identifying issues much easier.
You can have as many checks as you'd like, and the base Spiderweb instance is available at `self.server`. All checks must return an exception (**not** raising it!), as they will all be raised at the same time as part of an ExceptionGroup called `StartupErrors`.
To enable your checks, link them to your middleware like this:
```python
class MyMiddleware(SpiderwebMiddleware):
checks = [MyCheck, ADifferentCheck]
def process_request(self, request):
...
```
List as many checks as you need there, and the server will run all of them during startup.
## UnusedMiddleware ## UnusedMiddleware
```python ```python

View File

@ -80,7 +80,7 @@ This is an example view. There are a few things to note here:
> See [declaring routes](routes.md) for more information. > See [declaring routes](routes.md) for more information.
> [!TIP] > [!NOTE]
> Every view must accept a `request` object as its first argument. This object contains all the information about the incoming request, including headers, cookies, and more. > Every view must accept a `request` object as its first argument. This object contains all the information about the incoming request, including headers, cookies, and more.
> >
> There's more that we can pass in, but for now, we'll keep it simple. > There's more that we can pass in, but for now, we'll keep it simple.

View File

@ -15,6 +15,7 @@ from spiderweb.response import (
app = SpiderwebRouter( app = SpiderwebRouter(
templates_dirs=["templates"], templates_dirs=["templates"],
middleware=[ middleware=[
"spiderweb.middleware.cors.CorsMiddleware",
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
"example_middleware.TestMiddleware", "example_middleware.TestMiddleware",

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "spiderweb-framework" name = "spiderweb-framework"
version = "0.11.0" version = "0.12.0"
description = "A small web framework, just big enough for a spider." description = "A small web framework, just big enough for a spider."
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"] authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
readme = "README.md" readme = "README.md"

View File

@ -2,9 +2,26 @@ from peewee import DatabaseProxy
DEFAULT_ALLOWED_METHODS = ["GET"] DEFAULT_ALLOWED_METHODS = ["GET"]
DEFAULT_ENCODING = "UTF-8" DEFAULT_ENCODING = "UTF-8"
__version__ = "0.11.0" __version__ = "0.12.0"
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
DATABASE_PROXY = DatabaseProxy() DATABASE_PROXY = DatabaseProxy()
DEFAULT_CORS_ALLOW_METHODS = (
"DELETE",
"GET",
"OPTIONS",
"PATCH",
"POST",
"PUT",
)
DEFAULT_CORS_ALLOW_HEADERS = (
"accept",
"authorization",
"content-type",
"user-agent",
"x-csrftoken",
"x-requested-with",
)

View File

@ -86,3 +86,7 @@ class UnusedMiddleware(SpiderwebException):
class NoResponseError(SpiderwebException): class NoResponseError(SpiderwebException):
pass pass
class StartupErrors(ExceptionGroup):
pass

View File

@ -1,16 +1,22 @@
import inspect import inspect
import logging import logging
import pathlib import pathlib
import re
import traceback import traceback
import urllib.parse as urlparse import urllib.parse as urlparse
from logging import Logger
from threading import Thread from threading import Thread
from typing import Optional, Callable from typing import Optional, Callable, Sequence, LiteralString, Literal
from wsgiref.simple_server import WSGIServer from wsgiref.simple_server import WSGIServer
from jinja2 import BaseLoader, Environment, FileSystemLoader from jinja2 import BaseLoader, Environment, FileSystemLoader
from peewee import Database, SqliteDatabase from peewee import Database, SqliteDatabase
from spiderweb.middleware import MiddlewareMixin from spiderweb.middleware import MiddlewareMixin
from spiderweb.constants import (
DEFAULT_CORS_ALLOW_METHODS,
DEFAULT_CORS_ALLOW_HEADERS,
)
from spiderweb.constants import ( from spiderweb.constants import (
DATABASE_PROXY, DATABASE_PROXY,
DEFAULT_ENCODING, DEFAULT_ENCODING,
@ -30,7 +36,7 @@ from spiderweb.request import Request
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
from spiderweb.routes import RoutesMixin from spiderweb.routes import RoutesMixin
from spiderweb.secrets import FernetMixin from spiderweb.secrets import FernetMixin
from spiderweb.utils import get_http_status_by_code from spiderweb.utils import get_http_status_by_code, convert_url_to_regex
console_logger = logging.getLogger(__name__) console_logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -42,21 +48,33 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
*, *,
addr: str = None, addr: str = None,
port: int = None, port: int = None,
allowed_hosts: Sequence[str | re.Pattern] = None,
cors_allowed_origins: Sequence[str] = None,
cors_allowed_origins_regexes: Sequence[str] = None,
cors_allow_all_origins: bool = False,
cors_urls_regex: str | re.Pattern[str] = r"^.*$",
cors_allow_methods: Sequence[str] = None,
cors_allow_headers: Sequence[str] = None,
cors_expose_headers: Sequence[str] = None,
cors_preflight_max_age: int = 86400,
cors_allow_credentials: bool = False,
cors_allow_private_network: bool = False,
csrf_trusted_origins: Sequence[str] = None,
db: Optional[Database] = None, db: Optional[Database] = None,
templates_dirs: list[str] = None, templates_dirs: Sequence[str] = None,
middleware: list[str] = None, middleware: Sequence[str] = None,
append_slash: bool = False, append_slash: bool = False,
staticfiles_dirs: list[str] = None, staticfiles_dirs: Sequence[str] = None,
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None, routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
error_routes: dict[int, Callable] = None, error_routes: dict[int, Callable] = None,
secret_key: str = None, secret_key: str = None,
session_max_age=60 * 60 * 24 * 14, # 2 weeks session_max_age: int = 60 * 60 * 24 * 14, # 2 weeks
session_cookie_name="swsession", session_cookie_name: str = "swsession",
session_cookie_secure=False, # should be true if serving over HTTPS session_cookie_secure: bool = False, # should be true if serving over HTTPS
session_cookie_http_only=True, session_cookie_http_only: bool = True,
session_cookie_same_site="lax", session_cookie_same_site: Literal["strict", "lax", "none"] = "lax",
session_cookie_path="/", session_cookie_path: str = "/",
log=None, log: Logger = None,
**kwargs, **kwargs,
): ):
self._routes = {} self._routes = {}
@ -69,9 +87,27 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.append_slash = append_slash self.append_slash = append_slash
self.templates_dirs = templates_dirs self.templates_dirs = templates_dirs
self.staticfiles_dirs = staticfiles_dirs self.staticfiles_dirs = staticfiles_dirs
self._middleware: list[str] = middleware if middleware else [] self._middleware: list[str] = middleware or []
self.middleware: list[Callable] = [] self.middleware: list[Callable] = []
self.secret_key = secret_key if secret_key else self.generate_key() self.secret_key = secret_key if secret_key else self.generate_key()
self._allowed_hosts = allowed_hosts or ["*"]
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
self.cors_allowed_origins = cors_allowed_origins or []
self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or []
self.cors_allow_all_origins = cors_allow_all_origins
self.cors_urls_regex = cors_urls_regex
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
self.cors_allow_headers = cors_allow_headers or DEFAULT_CORS_ALLOW_HEADERS
self.cors_expose_headers = cors_expose_headers or []
self.cors_preflight_max_age = cors_preflight_max_age
self.cors_allow_credentials = cors_allow_credentials
self.cors_allow_private_network = cors_allow_private_network
self._csrf_trusted_origins = csrf_trusted_origins or []
self.csrf_trusted_origins = [
convert_url_to_regex(i) for i in self._csrf_trusted_origins
]
self.extra_data = kwargs self.extra_data = kwargs
@ -134,12 +170,18 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
try: try:
status = get_http_status_by_code(resp.status_code) status = get_http_status_by_code(resp.status_code)
cookies = [] cookies = []
if "Set-Cookie" in resp.headers: varies = []
cookies = resp.headers["Set-Cookie"] if "set-cookie" in resp.headers:
del resp.headers["Set-Cookie"] cookies = resp.headers["set-cookie"]
del resp.headers["set-cookie"]
if "vary" in resp.headers:
varies = resp.headers["vary"]
del resp.headers["vary"]
headers = list(resp.headers.items()) headers = list(resp.headers.items())
for c in cookies: for c in cookies:
headers.append(("Set-Cookie", c)) headers.append(("Set-Cookie", c))
for v in varies:
headers.append(("Vary", v))
start_response(status, headers) start_response(status, headers)
@ -180,7 +222,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
): ):
try: try:
status = get_http_status_by_code(500) status = get_http_status_by_code(500)
headers = [("Content-type", "text/plain; charset=utf-8")] headers = [("Content-Type", "text/plain; charset=utf-8")]
start_response(status, headers) start_response(status, headers)
@ -217,6 +259,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
start_response, request, self.get_error_route(500)(request) start_response, request, self.get_error_route(500)(request)
) )
def check_valid_host(self, request) -> bool:
host = request.headers.get("http_host")
if not host:
return False
for option in self.allowed_hosts:
if re.match(option, host):
return True
return False
def __call__(self, environ, start_response, *args, **kwargs): def __call__(self, environ, start_response, *args, **kwargs):
"""Entry point for WSGI apps.""" """Entry point for WSGI apps."""
request = self.get_request(environ) request = self.get_request(environ)
@ -233,6 +284,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
# replace the potentially valid handler with the error route # replace the potentially valid handler with the error route
handler = self.get_error_route(405) handler = self.get_error_route(405)
if not self.check_valid_host(request):
handler = self.get_error_route(403)
if request.is_form_request(): if request.is_form_request():
form_data = urlparse.parse_qs(request.content) form_data = urlparse.parse_qs(request.content)
for key, value in form_data.items(): for key, value in form_data.items():

View File

@ -1,9 +1,11 @@
from typing import Callable, ClassVar from typing import Callable, ClassVar
import sys
from .base import SpiderwebMiddleware as SpiderwebMiddleware from .base import SpiderwebMiddleware as SpiderwebMiddleware
from .cors import CorsMiddleware as CorsMiddleware
from .csrf import CSRFMiddleware as CSRFMiddleware from .csrf import CSRFMiddleware as CSRFMiddleware
from .sessions import SessionMiddleware as SessionMiddleware from .sessions import SessionMiddleware as SessionMiddleware
from ..exceptions import ConfigError, UnusedMiddleware from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors
from ..request import Request from ..request import Request
from ..response import HttpResponse from ..response import HttpResponse
from ..utils import import_by_string from ..utils import import_by_string
@ -27,10 +29,19 @@ class MiddlewareMixin:
self.middleware = middleware_by_reference self.middleware = middleware_by_reference
def run_middleware_checks(self): def run_middleware_checks(self):
errors = []
for middleware in self.middleware: for middleware in self.middleware:
if hasattr(middleware, "checks"): if hasattr(middleware, "checks"):
for check in middleware.checks: for check in middleware.checks:
check(server=self).check() if issue := check(server=self).check():
errors.append(issue)
if errors:
# just show the messages
sys.tracebacklimit = 0
raise StartupErrors(
"Problems were identified during startup — cannot continue.", errors
)
def process_request_middleware(self, request: Request) -> None | bool: def process_request_middleware(self, request: Request) -> None | bool:
for middleware in self.middleware: for middleware in self.middleware:

View File

@ -0,0 +1,158 @@
import re
from urllib.parse import urlsplit, SplitResult
from spiderweb.exceptions import ConfigError
from spiderweb.request import Request
from spiderweb.response import HttpResponse
from spiderweb.middleware import SpiderwebMiddleware
from spiderweb.server_checks import ServerCheck
ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
class VerifyValidCorsSetting(ServerCheck):
INVALID_BASE_CONFIG = (
"To enable CORS, one of the three primary configurations must be set:"
" `cors_allowed_origins`, `cors_allowed_origin_regexes`, or"
" `cors_allow_all_origins`.",
)
def check(self):
# - `cors_allowed_origins`
# - `cors_allowed_origin_regexes`
# - `cors_allow_all_origins`
if (
not self.server.cors_allowed_origins
and not self.server.cors.allowed_origin_regexes
and not self.server.cors_allow_all_origins
):
return ConfigError(self.INVALID_BASE_CONFIG)
class CorsMiddleware(SpiderwebMiddleware):
# heavily 'based' on https://github.com/adamchainz/django-cors-headers,
# which is provided under the MIT license. This is essentially a direct
# port, since django-cors-headers is battle-tested code that has been
# around for a long time and it works well. Shoutouts to Otto, Adam, and
# crew for helping make this a complete non-issue in Django for a very long
# time.
checks = [VerifyValidCorsSetting]
def is_enabled(self, request: Request):
return bool(re.match(self.server.cors_urls_regex, request.path))
def add_response_headers(self, request: Request, response: HttpResponse):
enabled = getattr(request, "_cors_enabled", None)
if enabled is None:
enabled = self.is_enabled(request)
if not enabled:
return response
if "vary" in response.headers:
response.headers["vary"].append("origin")
else:
response.headers["vary"] = ["origin"]
origin = request.headers.get("origin")
if not origin:
return response
try:
url = urlsplit(origin)
except ValueError:
return response
if (
not self.server.cors_allow_all_origins
and not self.origin_found_in_allow_lists(origin, url)
):
return response
if (
self.server.cors_allow_all_origins
and not self.server.cors_allow_credentials
):
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
else:
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
if self.server.cors_allow_credentials:
response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
if len(self.server.cors_expose_headers):
response.headers[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
self.server.cors_expose_headers
)
if request.method == "OPTIONS":
response.headers[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(
self.server.cors_allow_headers
)
response.headers[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(
self.server.cors_allow_methods
)
if self.server.cors_preflight_max_age:
response.headers[ACCESS_CONTROL_MAX_AGE] = str(
self.server.cors_preflight_max_age
)
if (
self.server.cors_allow_private_network
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
):
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
return response
def origin_found_in_allow_lists(self, origin: str, url: SplitResult) -> bool:
return (
(origin == "null" and origin in self.server.cors_allowed_origins)
or self._url_in_allowlist(url)
or self.regex_domain_match(origin)
)
def _url_in_allowlist(self, url: SplitResult) -> bool:
origins = [urlsplit(o) for o in self.server.cors_allowed_origins]
return any(
origin.scheme == url.scheme and origin.netloc == url.netloc
for origin in origins
)
def regex_domain_match(self, origin: str) -> bool:
return any(
re.match(domain_pattern, origin)
for domain_pattern in self.server.cors_allowed_origin_regexes
)
def process_request(self, request: Request) -> HttpResponse | None:
# Identify and handle a preflight request
# origin = request.META.get("HTTP_ORIGIN")
request._cors_enabled = self.is_enabled(request)
if (
request._cors_enabled
and request.method == "OPTIONS"
and "access-control-request-method" in request.headers
):
# this should be 204, but according to mozilla, not all browsers
# parse that correctly. See [204] comment below.
resp = HttpResponse(
"",
status_code=200,
headers={"content-type": "text/plain", "content-length": 0},
)
self.add_response_headers(request, resp)
return resp
def process_response(
self, request: Request, response: HttpResponse
) -> None:
self.add_response_headers(request, response)
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code

View File

@ -1,4 +1,7 @@
import re
from re import Pattern
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional
from spiderweb.exceptions import CSRFError, ConfigError from spiderweb.exceptions import CSRFError, ConfigError
from spiderweb.middleware import SpiderwebMiddleware from spiderweb.middleware import SpiderwebMiddleware
@ -7,49 +10,96 @@ from spiderweb.response import HttpResponse
from spiderweb.server_checks import ServerCheck from spiderweb.server_checks import ServerCheck
class SessionCheck(ServerCheck): class CheckForSessionMiddleware(ServerCheck):
SESSION_MIDDLEWARE_NOT_FOUND = ( SESSION_MIDDLEWARE_NOT_FOUND = (
"Session middleware is not enabled. It must be listed above" "Session middleware is not enabled. It must be listed above"
"CSRFMiddleware in the middleware list." "CSRFMiddleware in the middleware list."
) )
def check(self) -> Optional[Exception]:
if (
"spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware
):
return ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
class VerifyCorrectMiddlewarePlacement(ServerCheck):
SESSION_MIDDLEWARE_BELOW_CSRF = ( SESSION_MIDDLEWARE_BELOW_CSRF = (
"SessionMiddleware is enabled, but it must be listed above" "SessionMiddleware is enabled, but it must be listed above"
"CSRFMiddleware in the middleware list." "CSRFMiddleware in the middleware list."
) )
def check(self): def check(self) -> Optional[Exception]:
if ( if (
"spiderweb.middleware.sessions.SessionMiddleware" "spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware not in self.server._middleware
): ):
raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND) # this is handled by CheckForSessionMiddleware
return
if self.server._middleware.index( if self.server._middleware.index(
"spiderweb.middleware.sessions.SessionMiddleware" "spiderweb.middleware.sessions.SessionMiddleware"
) > self.server._middleware.index( ) > self.server._middleware.index("spiderweb.middleware.csrf.CSRFMiddleware"):
"spiderweb.middleware.csrf.CSRFMiddleware" return ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
):
raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
class VerifyCorrectFormatForTrustedOrigins(ServerCheck):
CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR = (
"The csrf_trusted_origins setting must be a list of strings."
)
def check(self) -> Optional[Exception]:
if not isinstance(self.server.csrf_trusted_origins, list):
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
for item in self.server.csrf_trusted_origins:
if not isinstance(item, Pattern):
# It's a pattern here because we've already manipulated it
# by the time this check runs
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
class CSRFMiddleware(SpiderwebMiddleware): class CSRFMiddleware(SpiderwebMiddleware):
checks = [SessionCheck] checks = [
CheckForSessionMiddleware,
VerifyCorrectMiddlewarePlacement,
VerifyCorrectFormatForTrustedOrigins,
]
CSRF_EXPIRY = 60 * 60 # 1 hour CSRF_EXPIRY = 60 * 60 # 1 hour
def is_trusted_origin(self, request) -> bool:
origin = request.headers.get("http_origin")
referrer = request.headers.get("http_referer") or request.headers.get("http_referrer")
host = request.headers.get("http_host")
if not origin and not (host == referrer):
return False
if not origin and (host == referrer):
origin = host
for re_origin in self.server.csrf_trusted_origins:
if re.match(re_origin, origin):
return True
return False
def process_request(self, request: Request) -> HttpResponse | None: def process_request(self, request: Request) -> HttpResponse | None:
if request.method == "POST": if request.method == "POST":
if hasattr(request.handler, "csrf_exempt"): if hasattr(request.handler, "csrf_exempt"):
if request.handler.csrf_exempt is True: if request.handler.csrf_exempt is True:
return return
csrf_token = ( csrf_token = (
request.headers.get("X-CSRF-TOKEN") request.headers.get("X-CSRF-TOKEN")
or request.GET.get("csrf_token") or request.GET.get("csrf_token")
or request.POST.get("csrf_token") or request.POST.get("csrf_token")
) )
if not self.is_trusted_origin(request):
if self.is_csrf_valid(request, csrf_token): if self.is_csrf_valid(request, csrf_token):
return None return None
else: else:

View File

@ -2,7 +2,7 @@ import json
from urllib.parse import urlparse from urllib.parse import urlparse
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.utils import get_client_address from spiderweb.utils import get_client_address, Headers
class Request: class Request:
@ -38,20 +38,22 @@ class Request:
self.populate_meta() self.populate_meta()
self.populate_cookies() self.populate_cookies()
content_length = int(self.headers.get("CONTENT_LENGTH") or 0) content_length = int(self.headers.get("content_length") or 0)
if content_length: if content_length:
self.content = ( self.content = (
self.environ["wsgi.input"].read(content_length).decode(DEFAULT_ENCODING) self.environ["wsgi.input"].read(content_length).decode(DEFAULT_ENCODING)
) )
def populate_headers(self) -> None: def populate_headers(self) -> None:
self.headers |= { data = self.headers
"CONTENT_TYPE": self.environ.get("CONTENT_TYPE"), data |= {
"CONTENT_LENGTH": self.environ.get("CONTENT_LENGTH"), "content_type": self.environ.get("CONTENT_TYPE"),
"content_length": self.environ.get("CONTENT_LENGTH"),
} }
for k, v in self.environ.items(): for k, v in self.environ.items():
if k.startswith("HTTP_"): if k.startswith("HTTP_"):
self.headers[k] = v data[k] = v
self.headers = Headers(**{k.lower(): v for k, v in data.items()})
def populate_meta(self) -> None: def populate_meta(self) -> None:
# all caps fields are from WSGI, lowercase names # all caps fields are from WSGI, lowercase names
@ -72,6 +74,9 @@ class Request:
] ]
for f in fields: for f in fields:
self.META[f] = self.environ.get(f) self.META[f] = self.environ.get(f)
for f in self.environ.keys():
if f.startswith("HTTP_"):
self.META[f] = self.environ[f]
self.META["client_address"] = get_client_address(self.environ) self.META["client_address"] = get_client_address(self.environ)
def populate_cookies(self) -> None: def populate_cookies(self) -> None:
@ -86,6 +91,6 @@ class Request:
def is_form_request(self) -> bool: def is_form_request(self) -> bool:
return ( return (
"CONTENT_TYPE" in self.headers "content_type" in self.headers
and self.headers["CONTENT_TYPE"] == "application/x-www-form-urlencoded" and self.headers["content_type"] == "application/x-www-form-urlencoded"
) )

View File

@ -10,6 +10,8 @@ from wsgiref.util import FileWrapper
from spiderweb.constants import REGEX_COOKIE_NAME from spiderweb.constants import REGEX_COOKIE_NAME
from spiderweb.exceptions import GeneralException from spiderweb.exceptions import GeneralException
from spiderweb.request import Request from spiderweb.request import Request
from spiderweb.utils import Headers
mimetypes.init() mimetypes.init()
@ -28,10 +30,11 @@ class HttpResponse:
self.context = context if context else {} self.context = context if context else {}
self.status_code = status_code self.status_code = status_code
self.headers = headers if headers else {} self.headers = headers if headers else {}
if not self.headers.get("Content-Type"): self.headers = Headers(**{k.lower(): v for k, v in self.headers.items()})
self.headers["Content-Type"] = "text/html; charset=utf-8" if not self.headers.get("content-type"):
self.headers["Server"] = "Spiderweb" self.headers["content-type"] = "text/html; charset=utf-8"
self.headers["Date"] = datetime.datetime.now(tz=datetime.UTC).strftime( self.headers["server"] = "Spiderweb"
self.headers["date"] = datetime.datetime.now(tz=datetime.UTC).strftime(
"%a, %d %b %Y %H:%M:%S GMT" "%a, %d %b %Y %H:%M:%S GMT"
) )
@ -89,10 +92,10 @@ class HttpResponse:
attrs = [urllib.parse.quote_plus(value)] + attrs attrs = [urllib.parse.quote_plus(value)] + attrs
cookie = f"{name}={'; '.join(attrs)}" cookie = f"{name}={'; '.join(attrs)}"
if "Set-Cookie" in self.headers: if "set-cookie" in self.headers:
self.headers["Set-Cookie"].append(cookie) self.headers["set-cookie"].append(cookie)
else: else:
self.headers["Set-Cookie"] = [cookie] self.headers["set-cookie"] = [cookie]
def render(self) -> str: def render(self) -> str:
return str(self.body) return str(self.body)
@ -103,7 +106,7 @@ class FileResponse(HttpResponse):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.filename = filename self.filename = filename
self.content_type = mimetypes.guess_type(self.filename)[0] self.content_type = mimetypes.guess_type(self.filename)[0]
self.headers["Content-Type"] = self.content_type self.headers["content-type"] = self.content_type
def render(self) -> list[bytes]: def render(self) -> list[bytes]:
with open(self.filename, "rb") as f: with open(self.filename, "rb") as f:
@ -114,7 +117,7 @@ class FileResponse(HttpResponse):
class JsonResponse(HttpResponse): class JsonResponse(HttpResponse):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.headers["Content-Type"] = "application/json" self.headers["content-type"] = "application/json"
def render(self) -> str: def render(self) -> str:
return json.dumps(self.data) return json.dumps(self.data)
@ -124,7 +127,7 @@ class RedirectResponse(HttpResponse):
def __init__(self, location: str, *args, **kwargs): def __init__(self, location: str, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.status_code = 302 self.status_code = 302
self.headers["Location"] = location self.headers["location"] = location
class TemplateResponse(HttpResponse): class TemplateResponse(HttpResponse):

View File

@ -1,5 +1,5 @@
import re import re
from typing import Callable, Any, Optional from typing import Callable, Any, Optional, Sequence
from spiderweb.constants import DEFAULT_ALLOWED_METHODS from spiderweb.constants import DEFAULT_ALLOWED_METHODS
from spiderweb.converters import * # noqa: F403 from spiderweb.converters import * # noqa: F403
@ -30,7 +30,7 @@ class RoutesMixin:
# ones that start with underscores are the compiled versions, non-underscores # ones that start with underscores are the compiled versions, non-underscores
# are the user-supplied versions # are the user-supplied versions
_routes: dict _routes: dict
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,) routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]]
_error_routes: dict _error_routes: dict
error_routes: dict[int, Callable] error_routes: dict[int, Callable]
append_slash: bool append_slash: bool

View File

@ -4,12 +4,16 @@ from datetime import timedelta
import pytest import pytest
from peewee import SqliteDatabase from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.sessions import Session from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf from spiderweb.middleware import csrf
from spiderweb.tests.helpers import setup from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf from spiderweb.tests.views_for_tests import (
form_view_with_csrf,
form_csrf_exempt,
form_view_without_csrf,
)
# app = SpiderwebRouter( # app = SpiderwebRouter(
@ -99,18 +103,21 @@ def test_exploding_middleware():
def test_csrf_middleware_without_session_middleware(): def test_csrf_middleware_without_session_middleware():
_, environ, start_response = setup() _, environ, start_response = setup()
with pytest.raises(ConfigError) as e: with pytest.raises(StartupErrors) as e:
SpiderwebRouter( SpiderwebRouter(
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"], middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
db=SqliteDatabase("spiderweb-tests.db"), db=SqliteDatabase("spiderweb-tests.db"),
) )
exceptiongroup = e.value.args[1]
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND assert (
exceptiongroup[0].args[0]
== csrf.CheckForSessionMiddleware.SESSION_MIDDLEWARE_NOT_FOUND
)
def test_csrf_middleware_above_session_middleware(): def test_csrf_middleware_above_session_middleware():
_, environ, start_response = setup() _, environ, start_response = setup()
with pytest.raises(ConfigError) as e: with pytest.raises(StartupErrors) as e:
SpiderwebRouter( SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
@ -118,8 +125,11 @@ def test_csrf_middleware_above_session_middleware():
], ],
db=SqliteDatabase("spiderweb-tests.db"), db=SqliteDatabase("spiderweb-tests.db"),
) )
exceptiongroup = e.value.args[1]
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF assert (
exceptiongroup[0].args[0]
== csrf.VerifyCorrectMiddlewarePlacement.SESSION_MIDDLEWARE_BELOW_CSRF
)
def test_csrf_middleware(): def test_csrf_middleware():
@ -211,6 +221,7 @@ def test_csrf_expired_token():
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}" f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
) )
environ["REQUEST_METHOD"] = "POST" environ["REQUEST_METHOD"] = "POST"
environ["HTTP_ORIGIN"] = "example.com"
environ["HTTP_X_CSRF_TOKEN"] = token environ["HTTP_X_CSRF_TOKEN"] = token
environ["CONTENT_LENGTH"] = len(formdata) environ["CONTENT_LENGTH"] = len(formdata)
@ -254,3 +265,44 @@ def test_csrf_exempt():
environ["PATH_INFO"] = "/2" environ["PATH_INFO"] = "/2"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING) resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp2 assert "CSRF token is invalid" in resp2
def test_csrf_trusted_origins():
_, environ, start_response = setup()
app = SpiderwebRouter(
middleware=[
"spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware",
],
csrf_trusted_origins=[
"example.com",
],
db=SqliteDatabase("spiderweb-tests.db"),
)
app.add_route("/", form_view_without_csrf, ["GET", "POST"])
environ["HTTP_USER_AGENT"] = "hi"
environ["REMOTE_ADDR"] = "1.1.1.1"
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
environ["REQUEST_METHOD"] = "POST"
formdata = "name=bob"
environ["CONTENT_LENGTH"] = len(formdata)
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
environ["HTTP_ORIGIN"] = "notvalid.com"
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
environ["HTTP_ORIGIN"] = "example.com"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert resp2 == '{"name": "bob"}'

View File

@ -71,7 +71,7 @@ def test_redirect_response():
return RedirectResponse(location="/redirected") return RedirectResponse(location="/redirected")
assert app(environ, start_response) == [b"None"] assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/redirected" assert start_response.get_headers()["location"] == "/redirected"
def test_add_route_at_server_start(): def test_add_route_at_server_start():
@ -91,7 +91,7 @@ def test_add_route_at_server_start():
) )
assert app(environ, start_response) == [b"None"] assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/redirected" assert start_response.get_headers()["location"] == "/redirected"
def test_redirect_on_append_slash(): def test_redirect_on_append_slash():
@ -104,7 +104,7 @@ def test_redirect_on_append_slash():
environ["PATH_INFO"] = f"/hello" environ["PATH_INFO"] = f"/hello"
assert app(environ, start_response) == [b"None"] assert app(environ, start_response) == [b"None"]
assert start_response.get_headers()["Location"] == "/hello/" assert start_response.get_headers()["location"] == "/hello/"
@given(st.text()) @given(st.text())

View File

@ -1,4 +1,5 @@
import json import json
import re
import secrets import secrets
import string import string
from http import HTTPStatus from http import HTTPStatus
@ -63,3 +64,26 @@ def is_jsonable(data: str) -> bool:
return True return True
except (TypeError, OverflowError): except (TypeError, OverflowError):
return False return False
class Headers(dict):
# special dict that forces lowercase for all keys
def __getitem__(self, key):
return super().__getitem__(key.lower())
def __setitem__(self, key, value):
return super().__setitem__(key.lower(), value)
def get(self, key, default=None):
return super().get(key.lower(), default)
def setdefault(self, key, default=None):
return super().setdefault(key.lower(), default)
def convert_url_to_regex(url: str | re.Pattern) -> re.Pattern:
if isinstance(url, re.Pattern):
return url
url = url.replace(".", "\\.")
url = url.replace("*", ".+")
return re.compile(url)

View File

@ -15,4 +15,7 @@
<p> <p>
<img src="/static/aaaaaa.gif" alt="AAAAAAAAAA"> <img src="/static/aaaaaa.gif" alt="AAAAAAAAAA">
</p> </p>
<p>
{{ request.META }}
</p>
{% endblock %} {% endblock %}