107 lines
3.2 KiB
Python
107 lines
3.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Per-channel circuit breaker and per-tenant rate limiter for the gateway.

The designs follow pybreaker (circuit breaker) and aiolimiter (rate limiter),
but neither library is pre-installed in CoPaw, and this module does not import
them. Instead it always uses the minimal in-process implementations below, so
the gateway keeps working on the existing image. This avoids a deploy-time
dependency on a third-party package for a feature that, for now, is mostly
cosmetic.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from collections import deque
|
|
from typing import Awaitable, Callable, TypeVar
|
|
|
|
# Module logger; SimpleBreaker reports trips through it.
log = logging.getLogger(name="chathub.gateway.breaker")

# Generic type parameter. It is not referenced by the code below; presumably
# kept for typed helper signatures (verify before removing).
T = TypeVar("T")
|
|
|
|
|
|
# ============ Circuit Breaker (fail_fast + reset) ============
|
|
|
|
class SimpleBreaker:
    """Minimal circuit breaker.

    State machine:
        CLOSED    -> after ``fail_max`` consecutive failures -> OPEN
        OPEN      -> once ``reset_timeout`` seconds have elapsed -> HALF_OPEN
        HALF_OPEN -> exactly one trial call is let through
        HALF_OPEN -> trial success -> CLOSED, trial failure -> OPEN

    Call :meth:`allow` before each protected call, then report the outcome
    with :meth:`on_success` or :meth:`on_failure`. If a half-open trial never
    reports back (e.g. the caller was cancelled), a new trial is permitted
    after another ``reset_timeout`` seconds, so the breaker cannot get stuck.

    Timing uses ``time.monotonic()``, so wall-clock adjustments (NTP, manual
    changes) cannot shorten or stretch the open period.
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF = "half_open"

    def __init__(self, fail_max: int = 5, reset_timeout: float = 60.0) -> None:
        self.fail_max = fail_max
        self.reset_timeout = reset_timeout
        self.state = SimpleBreaker.CLOSED
        # Times of consecutive failures since the last success. Only the count
        # is used, and it never needs more than ``fail_max`` entries to decide
        # whether to trip, so the length is capped to avoid unbounded growth.
        self._fails: deque[float] = deque(maxlen=max(fail_max, 1))
        self._opened_at: float = 0.0
        # When the outstanding half-open trial was handed out; None means no
        # trial is currently in flight.
        self._trial_started_at: float | None = None
        # Not used by the synchronous methods below; kept for compatibility.
        self._lock = asyncio.Lock()

    def allow(self) -> bool:
        """Return True if a call may proceed now.

        May move OPEN -> HALF_OPEN as a side effect. While HALF_OPEN, only a
        single trial call is allowed at a time.
        """
        if self.state == SimpleBreaker.CLOSED:
            return True
        now = time.monotonic()
        if self.state == SimpleBreaker.OPEN:
            if now - self._opened_at < self.reset_timeout:
                return False
            # Cool-down elapsed: this caller becomes the half-open trial.
            self.state = SimpleBreaker.HALF
            self._trial_started_at = now
            return True
        # HALF_OPEN: reject while a trial is outstanding. If the trial has not
        # reported back within reset_timeout, assume it was lost and hand out
        # a fresh one instead of blocking forever.
        if (
            self._trial_started_at is not None
            and now - self._trial_started_at < self.reset_timeout
        ):
            return False
        self._trial_started_at = now
        return True

    def on_success(self) -> None:
        """Record a successful call: close the breaker and reset the failure streak."""
        self.state = SimpleBreaker.CLOSED
        self._fails.clear()
        self._trial_started_at = None

    def on_failure(self) -> None:
        """Record a failed call; trip to OPEN once the streak reaches ``fail_max``.

        A failed half-open trial re-opens the breaker immediately, because the
        streak is not cleared until a success.
        """
        now = time.monotonic()
        self._fails.append(now)
        self._trial_started_at = None
        if len(self._fails) >= self.fail_max:
            self.state = SimpleBreaker.OPEN
            self._opened_at = now
            log.warning(
                "Circuit breaker OPEN, will half-open in %ss", self.reset_timeout
            )
|
|
|
|
|
|
# Lazily-populated registry: one shared breaker per channel.
_breakers: dict[str, SimpleBreaker] = {}


def get_breaker(channel: str) -> SimpleBreaker:
    """Return the shared breaker for ``channel``, creating it on first use.

    Newly created breakers trip after 5 consecutive failures and become
    eligible to half-open after 60 seconds.
    """
    try:
        return _breakers[channel]
    except KeyError:
        created = _breakers[channel] = SimpleBreaker(fail_max=5, reset_timeout=60.0)
        return created
|
|
|
|
|
|
# ============ Async per-key minimum-interval limiter ============
|
|
|
|
class AsyncLimiter:
    """Per-key async rate limiter enforcing a minimum spacing between grants.

    Each key is granted at most once every ``1 / rps`` seconds. There is no
    burst allowance, so this is not a token bucket. Keys are independent:
    waiting on one key never delays another.

    ``rps`` is clamped to at least 0.001, so the interval is at most 1000s.
    """

    def __init__(self, rps: float) -> None:
        self.rps = rps
        self._min_interval = 1.0 / max(rps, 0.001)
        # key -> monotonic time of the most recently reserved grant slot. The
        # slot may lie in the future while a caller is still waiting for it.
        self._last: dict[str, float] = {}
        # Guards only the reservation step in acquire(); never held across a
        # sleep, so one key's wait cannot stall other keys.
        self._lock = asyncio.Lock()

    async def acquire(self, key: str) -> None:
        """Wait until ``key`` may proceed, then return.

        The caller's slot is reserved under the lock and the wait happens
        after the lock is released. Concurrent callers for *other* keys
        therefore proceed immediately. Callers for the *same* key are queued
        in arrival order, ``1 / rps`` seconds apart.

        Note: if the caller is cancelled while waiting, its reserved slot is
        not given back; later callers for that key still queue behind it.
        """
        async with self._lock:
            now = time.monotonic()
            last = self._last.get(key)
            # The first use of a key is granted at once. Otherwise the caller
            # goes one interval after the latest reserved slot for that key.
            slot = now if last is None else max(now, last + self._min_interval)
            self._last[key] = slot
        delay = slot - time.monotonic()
        if delay > 0:
            await asyncio.sleep(delay)
|
|
|
|
|
|
# Lazily-populated registry: one shared limiter per tenant id.
_limiters: dict[int, AsyncLimiter] = {}


def get_tenant_limiter(tenant_id: int, rps: float = 5.0) -> AsyncLimiter:
    """Return the shared limiter for ``tenant_id``, creating it on first use.

    ``rps`` only matters when the limiter is first created. Later calls get
    the existing limiter back unchanged, whatever ``rps`` they pass.
    """
    if tenant_id in _limiters:
        return _limiters[tenant_id]
    fresh = _limiters[tenant_id] = AsyncLimiter(rps=rps)
    return fresh
|