You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
204 lines
6.8 KiB
Python
204 lines
6.8 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
import uuid
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from flask import current_app, g, has_app_context, has_request_context, request
|
|
|
|
from .config import ServiceConfig
|
|
from .errors import ServiceHTTPError, ServiceUnavailableError
|
|
|
|
|
|
class ServiceClient:
    """Synchronous HTTP JSON service client.

    Wraps one ``httpx.Client`` bound to ``config.base_url`` and adds:

    * default JSON headers, bearer-token auth and ``X-Trace-Id`` propagation;
    * retries with linear backoff (``config.retry.backoff * attempt`` seconds)
      for transport errors and for statuses in ``config.retry.statuses``;
    * a consecutive-failure circuit breaker driven by
      ``config.circuit_breaker`` (5xx responses and transport failures count);
    * one INFO log line per attempt when a Flask app context is active.

    Can be used as a context manager; leaving the ``with`` block calls
    :meth:`close`.
    """

    def __init__(
        self,
        config: ServiceConfig,
        *,
        transport: httpx.BaseTransport | None = None,
    ) -> None:
        """Create the underlying ``httpx.Client``.

        Args:
            config: Service settings. This class reads ``name``, ``base_url``,
                ``token``, ``timeout.{connect,read,write,pool}``,
                ``retry.{attempts,methods,statuses,backoff}`` and
                ``circuit_breaker.{enabled,fail_max,reset_timeout}``.
            transport: Optional transport passed straight to ``httpx.Client``.
        """
        self.config = config
        timeout = httpx.Timeout(
            connect=config.timeout.connect,
            read=config.timeout.read,
            write=config.timeout.write,
            pool=config.timeout.pool,
        )
        self._client = httpx.Client(
            base_url=config.base_url,
            timeout=timeout,
            transport=transport,
        )
        # Circuit-breaker state: consecutive counted failures, and the
        # time.monotonic() value at which the breaker opened (None = closed).
        self._fail_count = 0
        self._opened_at: float | None = None

    def __enter__(self) -> ServiceClient:
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the client when the ``with`` block exits."""
        self.close()

    def get(self, endpoint: str, **kwargs: Any) -> Any:
        """Shorthand for ``request("GET", endpoint, **kwargs)``."""
        return self.request("GET", endpoint, **kwargs)

    def post(self, endpoint: str, **kwargs: Any) -> Any:
        """Shorthand for ``request("POST", endpoint, **kwargs)``."""
        return self.request("POST", endpoint, **kwargs)

    def put(self, endpoint: str, **kwargs: Any) -> Any:
        """Shorthand for ``request("PUT", endpoint, **kwargs)``."""
        return self.request("PUT", endpoint, **kwargs)

    def delete(self, endpoint: str, **kwargs: Any) -> Any:
        """Shorthand for ``request("DELETE", endpoint, **kwargs)``."""
        return self.request("DELETE", endpoint, **kwargs)

    def request(
        self,
        method: str,
        endpoint: str,
        *,
        path: dict[str, Any] | None = None,
        path_params: dict[str, Any] | None = None,
        path_values: dict[str, Any] | None = None,
        path_: dict[str, Any] | None = None,
        path_args: dict[str, Any] | None = None,
        path_map: dict[str, Any] | None = None,
        params: dict[str, Any] | None = None,
        json: Any = None,
        headers: dict[str, str] | None = None,
        retry: bool | None = None,
        expect_json: bool = True,
    ) -> Any:
        """Send one logical request, retrying according to ``config.retry``.

        Args:
            method: HTTP method; upper-cased before use.
            endpoint: Path relative to ``config.base_url``; may contain
                ``str.format`` placeholders such as ``"/users/{user_id}"``.
            path, path_params, path_values, path_, path_args, path_map:
                Interchangeable aliases for the placeholder values. The first
                non-empty one, in this order, is used; the rest are ignored.
            params: Query-string parameters.
            json: Request body, sent as JSON by httpx.
            headers: Extra headers; they take precedence over the defaults.
            retry: ``False`` disables retries; ``True`` enables them for any
                method; ``None`` enables them only for methods listed in
                ``config.retry.methods``.
            expect_json: When False, return the ``httpx.Response`` itself.

        Returns:
            The decoded JSON body, ``None`` for an empty body, or the raw
            response when ``expect_json`` is False.

        Raises:
            ServiceUnavailableError: the circuit breaker is open, or the last
                attempt failed with an httpx transport error (incl. timeouts).
            ServiceHTTPError: the final response had a status >= 400.
        """
        method = method.upper()
        url = self._format_url(
            endpoint, path, path_params, path_values, path_, path_args, path_map
        )
        self._raise_if_open()

        attempts = self._attempts_for(method, retry)
        # Build headers once so every retry carries the same X-Trace-Id; outside
        # a Flask context _trace_id() would otherwise mint a new id per attempt.
        request_headers = self._headers(headers)
        last_error: Exception | None = None

        for attempt in range(1, attempts + 1):
            start = time.monotonic()
            try:
                response = self._client.request(
                    method,
                    url,
                    params=params,
                    json=json,
                    headers=request_headers,
                )
            except httpx.TransportError as exc:  # also covers httpx.TimeoutException
                last_error = exc
                self._log_call(
                    method, url, "transport_error", self._elapsed_ms(start), attempt
                )
                if attempt < attempts:
                    self._backoff(attempt)
                    continue
                self._record_failure()
                raise ServiceUnavailableError(
                    f"service {self.config.name} unavailable: {exc}"
                ) from exc

            self._log_call(
                method, url, response.status_code, self._elapsed_ms(start), attempt
            )

            if response.status_code >= 400:
                if self._should_retry_status(
                    method, response.status_code, attempt, attempts
                ):
                    self._backoff(attempt)
                    continue
                # Client errors (4xx) mean the service answered; only server
                # errors should push the breaker towards opening.
                if self._is_breaker_failure(response.status_code):
                    self._record_failure()
                raise ServiceHTTPError(
                    self.config.name, response.status_code, response.text
                )

            self._record_success()
            if not expect_json:
                return response
            if not response.content:
                return None
            return response.json()

        # Not reached in practice: _attempts_for() returns >= 1 and the final
        # attempt always returns or raises. Kept as an explicit safety net.
        raise ServiceUnavailableError(
            f"service {self.config.name} unavailable: {last_error}"
        )

    def close(self) -> None:
        """Close the underlying ``httpx.Client``."""
        self._client.close()

    @staticmethod
    def _format_url(endpoint: str, *candidates: dict[str, Any] | None) -> str:
        """Fill ``endpoint`` placeholders from the first non-empty mapping.

        NOTE(review): values are interpolated with ``str.format`` and are not
        percent-encoded, so a value containing ``/``, ``?`` or ``#`` changes the
        URL structure. Confirm callers never pass untrusted values here.
        """
        values = next((candidate for candidate in candidates if candidate), {})
        return endpoint.format(**values)

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Whole milliseconds elapsed since the ``time.monotonic()`` value ``start``."""
        return int((time.monotonic() - start) * 1000)

    def _backoff(self, attempt: int) -> None:
        """Sleep before the next attempt: ``retry.backoff * attempt`` seconds."""
        time.sleep(self.config.retry.backoff * attempt)

    def _headers(self, headers: dict[str, str] | None) -> dict[str, str]:
        """Return caller headers plus defaults for any keys the caller omitted.

        Defaults: JSON ``Accept``/``Content-Type``, a bearer ``Authorization``
        header when ``config.token`` is set, and ``X-Trace-Id``.
        """
        result = dict(headers or {})
        result.setdefault("Accept", "application/json")
        result.setdefault("Content-Type", "application/json")
        if self.config.token:
            result.setdefault("Authorization", f"Bearer {self.config.token}")
        trace_id = self._trace_id()
        result.setdefault("X-Trace-Id", trace_id)
        return result

    def _trace_id(self) -> str:
        """Return the trace id to forward on outgoing calls.

        Order of preference: the incoming Flask request's ``X-Trace-Id``
        header; ``g.trace_id`` in an app context (created and stored there on
        first use); otherwise a fresh random uuid4 hex string.
        """
        if has_request_context():
            header_trace = request.headers.get("X-Trace-Id")
            if header_trace:
                return header_trace
        if has_app_context():
            trace_id = getattr(g, "trace_id", None)
            if trace_id:
                return trace_id
            g.trace_id = uuid.uuid4().hex
            return g.trace_id
        return uuid.uuid4().hex

    def _attempts_for(self, method: str, retry: bool | None) -> int:
        """Return how many attempts (>= 1) a request may make.

        ``retry=False`` -> 1; ``retry=True`` -> ``retry.attempts``;
        ``retry=None`` -> ``retry.attempts`` only for ``retry.methods``.
        The result is clamped to at least 1 so a misconfigured attempt count
        still sends the request once.
        """
        if retry is False:
            return 1
        if retry is True or method in self.config.retry.methods:
            return max(1, self.config.retry.attempts)
        return 1

    def _should_retry_status(
        self, method: str, status_code: int, attempt: int, attempts: int
    ) -> bool:
        """Return True if an error status should be retried.

        ``method`` is accepted for signature compatibility but not consulted:
        whether this call may retry at all was already decided by
        ``_attempts_for`` (one attempt means no retries). Re-checking
        ``retry.methods`` here would make ``retry=True`` retry transport
        errors but not retryable statuses.
        """
        return attempt < attempts and status_code in self.config.retry.statuses

    @staticmethod
    def _is_breaker_failure(status_code: int) -> bool:
        """Return True if an HTTP error status should count towards the breaker."""
        return status_code >= 500

    def _raise_if_open(self) -> None:
        """Raise while the breaker is open; close it once the timeout elapsed.

        After ``reset_timeout`` seconds the breaker is fully closed and the
        failure count reset, so ``fail_max`` further counted failures are
        needed to reopen it (there is no single-probe half-open state).
        """
        breaker = self.config.circuit_breaker
        if not breaker.enabled or self._opened_at is None:
            return
        elapsed = time.monotonic() - self._opened_at
        if elapsed < breaker.reset_timeout:
            raise ServiceUnavailableError(
                f"service {self.config.name} circuit breaker is open"
            )
        self._opened_at = None
        self._fail_count = 0

    def _record_success(self) -> None:
        """Reset the breaker after a successful (< 400) response."""
        self._fail_count = 0
        self._opened_at = None

    def _record_failure(self) -> None:
        """Count a failure and open the breaker once ``fail_max`` is reached."""
        breaker = self.config.circuit_breaker
        if not breaker.enabled:
            return
        self._fail_count += 1
        if self._fail_count >= breaker.fail_max:
            self._opened_at = time.monotonic()

    def _log_call(
        self,
        method: str,
        url: str,
        status: int | str,
        elapsed_ms: int,
        attempt: int,
    ) -> None:
        """Log one attempt at INFO via ``current_app.logger``; no-op without an app context."""
        if not has_app_context():
            return
        current_app.logger.info(
            "service_call service=%s method=%s path=%s status=%s elapsed_ms=%s attempt=%s",
            self.config.name,
            method,
            url,
            status,
            elapsed_ms,
            attempt,
        )
|