You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
233 lines
6.7 KiB
Python
233 lines
6.7 KiB
Python
from fastapi import APIRouter, HTTPException
|
|
from fastapi.testclient import TestClient
|
|
from starlette.responses import PlainTextResponse
|
|
|
|
from iti import create_app
|
|
from iti.config import BaseConfig
|
|
from iti.exceptions import BizError
|
|
from iti.limiter import limit
|
|
from iti.responses import ok, raw_response
|
|
|
|
|
|
class RoutesModule:
    """Test module that mounts one demo endpoint per response path.

    The endpoints cover: an explicit ``ok`` envelope, a plain dict left for
    the framework to wrap, a ``raw_response`` opt-out, a non-JSON Starlette
    response, a ``BizError``, an ``HTTPException`` and a rate-limited route.
    """

    # Identifier for this module.
    name = "routes"

    def register_routes(self, app):
        """Build the demo router and include it in *app*."""
        api = APIRouter()

        @api.get("/demo")
        def demo():
            # Explicitly enveloped payload.
            return ok({"value": 1})

        @api.get("/auto")
        def auto():
            # Plain dict; the framework's auto-envelope is expected to wrap it.
            return {"value": 2}

        # ``raw_response`` is applied first (innermost), so the router
        # registers the already-marked handler.
        @api.get("/raw")
        @raw_response
        def raw():
            return {"value": 3}

        @api.get("/text")
        def text():
            # Non-JSON response.
            return PlainTextResponse("ok")

        @api.get("/boom")
        def boom():
            raise BizError("业务失败", code=400)

        @api.get("/http-error")
        def http_error():
            raise HTTPException(status_code=418, detail={"message": "茶壶错误"})

        @api.get("/limited", dependencies=[limit("1 per minute")])
        def limited():
            return ok()

        app.include_router(api)
|
|
|
|
|
|
def make_app(**config_values):
    """Create a test application that serves ``RoutesModule``.

    The config always uses an in-memory SQLite database with ``testing``
    enabled. *config_values* supplies any additional ``BaseConfig`` fields.
    """
    settings = BaseConfig(
        database_url="sqlite+pysqlite:///:memory:",
        testing=True,
        **config_values,
    )
    routes = [RoutesModule()]
    return create_app(modules=routes, config_mapping=settings)
|
|
|
|
|
|
def test_framework_health_routes():
    """The framework's health and readiness probes answer with a bare status."""
    client = TestClient(make_app())

    for probe in ("/health", "/ready"):
        assert client.get(probe).json() == {"status": "ok"}
|
|
|
|
|
|
def test_docs_picker_and_ui_variants_are_available():
    """``/docs`` renders a UI picker, and each ``?ui=`` variant renders its UI."""
    client = TestClient(make_app())

    picker, swagger, scalar, redoc = (
        client.get(url)
        for url in (
            "/docs",
            "/docs?ui=swagger",
            "/docs?ui=scalar",
            "/docs?ui=redoc",
        )
    )

    assert picker.status_code == 200
    for fragment in ("Swagger", "Scalar", "ReDoc", "<svg viewBox"):
        assert fragment in picker.text
    # NOTE(review): presumably guards against a plain-text arrow in the
    # picker markup (the picker is expected to contain an <svg> instead).
    assert "->" not in picker.text

    assert swagger.status_code == 200
    assert "swagger-ui" in swagger.text

    assert scalar.status_code == 200
    assert "@scalar/api-reference" in scalar.text
    assert 'data-url="/openapi.json"' in scalar.text

    assert redoc.status_code == 200
    assert "redoc" in redoc.text.lower()
|
|
|
|
|
|
def test_docs_templates_are_packaged_resources():
    """Docs pages carry the titles defined in the packaged HTML templates."""
    client = TestClient(make_app())

    expected_titles = {
        "/docs": "<title>iTi - API Docs</title>",
        "/docs?ui=scalar": "<title>iTi - Scalar API Reference</title>",
    }
    pages = {url: client.get(url).text for url in expected_titles}

    for url, title in expected_titles.items():
        assert title in pages[url]
|
|
|
|
|
|
def test_docs_picker_only_shows_enabled_ui():
    """With only Swagger enabled, the other UIs disappear.

    Scalar and ReDoc are absent from the picker. Their ``?ui=`` pages still
    answer 200 but no longer serve those UIs.
    """
    client = TestClient(make_app(docs_ui_enabled=["swagger"]))

    picker = client.get("/docs")
    scalar_page = client.get("/docs?ui=scalar")
    redoc_page = client.get("/docs?ui=redoc")

    assert picker.status_code == 200
    assert "Swagger" in picker.text
    for hidden in ("Scalar", "ReDoc"):
        assert hidden not in picker.text

    # Disabled variants still succeed but must not render their own UI.
    assert scalar_page.status_code == 200
    assert "@scalar/api-reference" not in scalar_page.text
    assert redoc_page.status_code == 200
    assert ">ReDoc<" not in redoc_page.text
|
|
|
|
|
|
def test_openapi_tags_are_grouped_by_prefix():
    """Dotted tags are grouped by their prefix in the OpenAPI schema.

    ``system.user`` goes into group ``system`` with display name ``user``.
    An undotted tag such as ``exchange`` forms its own group and gets no
    ``x-displayName``.
    """

    class GroupedRoutesModule:
        name = "grouped-routes"

        def register_routes(self, app):
            system = APIRouter(prefix="/system", tags=["system.user"])
            common = APIRouter(prefix="/common", tags=["common.file"])
            plain = APIRouter(prefix="/plain", tags=["exchange"])

            @system.get("/users")
            def users():
                return ok([])

            @common.get("/files")
            def files():
                return ok([])

            @plain.get("/exports")
            def exports():
                return ok([])

            for router in (system, common, plain):
                app.include_router(router)

    config = BaseConfig(database_url="sqlite+pysqlite:///:memory:", testing=True)
    app = create_app(modules=[GroupedRoutesModule()], config_mapping=config)
    schema = TestClient(app).get("/openapi.json").json()

    # The expected order matches the order the routers were included.
    expected_groups = [
        {"name": "system", "tags": ["system.user"]},
        {"name": "common", "tags": ["common.file"]},
        {"name": "exchange", "tags": ["exchange"]},
    ]
    expected_tags = [
        {"name": "system.user", "x-displayName": "user"},
        {"name": "common.file", "x-displayName": "file"},
        {"name": "exchange"},
    ]
    assert schema["x-tagGroups"] == expected_groups
    assert schema["tags"] == expected_tags
|
|
|
|
|
|
def test_envelope_and_error_handlers():
    """``ok`` payloads are enveloped, and ``BizError`` becomes an envelope too.

    Business errors keep HTTP status 200. The failure is reported in the
    body through ``code`` and ``message``.
    """
    client = TestClient(make_app())

    assert client.get("/demo").json() == {
        "data": {"value": 1},
        "code": 200,
        "message": "成功",
    }

    response = client.get("/boom")
    assert response.status_code == 200
    body = response.json()
    assert body["message"] == "业务失败"
    # HTTP status is 200 for every envelope, so only the business code shows
    # that BizError(code=400) was actually translated by the error handler.
    assert body["code"] == 400
|
|
|
|
|
|
def test_http_errors_use_envelope():
    """Routing errors and ``HTTPException`` come back as HTTP-200 envelopes.

    Each envelope carries the original HTTP status in ``code``.
    """
    client = TestClient(make_app())

    # Unregistered path -> 404 envelope.
    missing = client.get("/")
    assert missing.status_code == 200
    assert missing.json() == {"data": None, "code": 404, "message": "Not Found"}

    # Wrong method on an existing route -> 405 envelope.
    wrong_method = client.post("/demo")
    assert wrong_method.status_code == 200
    payload = wrong_method.json()
    assert payload["code"] == 405
    assert payload["message"] == "Method Not Allowed"

    # HTTPException with a dict detail: the detail becomes ``data``, and its
    # "message" entry becomes the envelope message.
    teapot = client.get("/http-error")
    assert teapot.status_code == 200
    assert teapot.json() == {
        "data": {"message": "茶壶错误"},
        "code": 418,
        "message": "茶壶错误",
    }
|
|
|
|
|
|
def test_auto_envelope_wraps_plain_json_and_raw_can_skip():
    """Plain JSON returns are auto-enveloped; raw and non-JSON pass through.

    Handlers marked with ``@raw_response`` and non-JSON responses are left
    untouched.
    """
    # NOTE(review): "/raw-path" is configured here but never requested.
    raw_paths = ["/health", "/ready", "/raw-path"]
    client = TestClient(make_app(raw_response_paths=raw_paths))

    wrapped = client.get("/auto").json()
    assert wrapped == {"data": {"value": 2}, "code": 200, "message": "成功"}

    assert client.get("/raw").json() == {"value": 3}
    assert client.get("/text").text == "ok"
|
|
|
|
|
|
def test_rate_limit_dependency():
    """``limit("1 per minute")`` allows one call, then rejects with code 429.

    Rejections are HTTP-200 envelopes, so the HTTP status alone cannot tell
    an allowed call from a rejected one. The business ``code`` is checked
    for both requests.
    """
    client = TestClient(make_app(ratelimit_enabled=True))

    first = client.get("/limited")
    assert first.status_code == 200
    # The first call must genuinely succeed, not merely return HTTP 200.
    assert first.json()["code"] == 200

    second = client.get("/limited")
    assert second.status_code == 200
    assert second.json()["code"] == 429
|
|
|
|
|
|
def test_request_log_includes_business_code(monkeypatch):
    """The request log line reports the envelope's business code.

    The code appears both as a positional logging argument and in
    ``extra["response_code"]``.
    """
    calls = []

    def fake_info(message, *args, **kwargs):
        calls.append((message, args, kwargs))

    monkeypatch.setattr("iti.app.logger.info", fake_info)
    client = TestClient(make_app())

    for path, expected in (("/auto", 200), ("/boom", 400)):
        assert client.get(path).json()["code"] == expected
        _, log_args, log_kwargs = calls[-1]
        # NOTE(review): relies on the business code being the 4th positional
        # logging argument of the most recent logger.info call.
        assert log_args[3] == expected
        assert log_kwargs["extra"]["response_code"] == expected
|