-
-
Notifications
You must be signed in to change notification settings - Fork 9.1k
Expand file tree
/
Copy pathtest_openapi_cache_root_path.py
More file actions
75 lines (56 loc) · 2.29 KB
/
test_openapi_cache_root_path.py
File metadata and controls
75 lines (56 loc) · 2.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from fastapi import FastAPI
from fastapi.testclient import TestClient
def test_root_path_does_not_persist_across_requests():
    """Check that a spoofed ``root_path`` does not carry over to a later request.

    The schema is first fetched through a client created with
    ``root_path="/evil-api"`` and that URL is expected in ``servers``. It is
    then fetched through a client created without any ``root_path`` and the
    same URL is expected to be absent.
    """
    app = FastAPI()

    @app.get("/")
    def read_root():  # pragma: no cover
        return {"ok": True}

    # Attacker request: the client injects a spoofed root_path.
    spoofed_schema = TestClient(app, root_path="/evil-api").get("/openapi.json").json()
    spoofed_urls = [entry.get("url") for entry in spoofed_schema.get("servers", [])]
    assert "/evil-api" in spoofed_urls

    # Legitimate follow-up request: a fresh client with no root_path.
    clean_schema = TestClient(app).get("/openapi.json").json()
    leaked_entries = [
        entry
        for entry in clean_schema.get("servers", [])
        if entry.get("url") == "/evil-api"
    ]
    assert not leaked_entries
def test_multiple_different_root_paths_do_not_accumulate():
    """Check that several distinct ``root_path`` values do not pile up in ``servers``.

    The schema is requested once per prefix, each time through a separate
    client created with that prefix as ``root_path``. A final request through
    a client without ``root_path`` must list none of the prefixes.
    """
    app = FastAPI()

    @app.get("/")
    def read_root():  # pragma: no cover
        return {"ok": True}

    prefixes = ("/path-a", "/path-b", "/path-c")

    # Only the side effect of each request matters here; responses are discarded.
    for prefix in prefixes:
        TestClient(app, root_path=prefix).get("/openapi.json")

    # A clean request should not have any of them.
    schema = TestClient(app).get("/openapi.json").json()
    servers = [entry.get("url") for entry in schema.get("servers", [])]
    for prefix in prefixes:
        assert prefix not in servers, (
            f"root_path '{prefix}' leaked into clean request: {servers}"
        )
def test_legitimate_root_path_still_appears():
    """Check that the ``root_path`` of the requesting client is listed in ``servers``.

    The schema is fetched through a client created with
    ``root_path="/api/v1"`` and that URL is expected among the server URLs.
    """
    app = FastAPI()

    @app.get("/")
    def read_root():  # pragma: no cover
        return {"ok": True}

    schema = TestClient(app, root_path="/api/v1").get("/openapi.json").json()
    server_urls = [entry.get("url") for entry in schema.get("servers", [])]
    assert "/api/v1" in server_urls
def test_configured_servers_not_mutated():
    """Check that the ``servers`` list passed to ``FastAPI`` is left unchanged.

    After the schema is requested through a client created with
    ``root_path="/evil"``, the caller's list must still equal its initial
    single-entry content.
    """
    configured_servers = [{"url": "https://prod.example.com"}]
    app = FastAPI(servers=configured_servers)

    @app.get("/")
    def read_root():  # pragma: no cover
        return {"ok": True}

    # Request with a rogue root_path; the response itself is not inspected.
    TestClient(app, root_path="/evil").get("/openapi.json")

    # The original servers list must be untouched.
    expected_servers = [{"url": "https://prod.example.com"}]
    assert configured_servers == expected_servers