Files

107 lines
3.2 KiB
Python

# -*- coding: utf-8 -*-
"""Per-tenant + per-platform circuit breaker (pybreaker) and rate limiter (aiolimiter).
pybreaker and aiolimiter are NOT pre-installed in CoPaw, so instead of depending
on them this module provides small in-process implementations, used
unconditionally (neither package is imported here). This keeps the gateway
working on the existing image and avoids a deploy-time dependency on a
third-party package for a feature that, for now, is mostly cosmetic.
"""
from __future__ import annotations
import asyncio
import logging
import time
from collections import deque
from typing import Awaitable, Callable, TypeVar
# Logger used to report breaker state transitions.
log = logging.getLogger("chathub.gateway.breaker")
# Generic TypeVar. NOTE(review): not referenced by the code visible here.
T = TypeVar("T")
# ============ Circuit Breaker (fail_fast + reset) ============


class SimpleBreaker:
    """Minimal in-process circuit breaker.

    State machine:
        CLOSED    -> fail_max failures with no success in between -> OPEN
        OPEN      -> after reset_timeout seconds, the next allow() -> HALF_OPEN
        HALF_OPEN -> exactly one trial call is admitted at a time
        HALF_OPEN -> trial success -> CLOSED, trial failure -> OPEN

    The API is split into a gate (``allow``) and outcome reports
    (``on_success`` / ``on_failure``). If a HALF_OPEN trial never reports
    back (e.g. the caller was cancelled), a new trial is admitted once
    ``reset_timeout`` seconds have passed since the stuck trial started, so
    the breaker cannot wedge itself shut.

    All timing uses ``time.monotonic()`` so wall-clock adjustments cannot
    shorten or extend the OPEN period.
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF = "half_open"

    def __init__(self, fail_max: int = 5, reset_timeout: float = 60.0) -> None:
        """Create a CLOSED breaker.

        Args:
            fail_max: Failures (with no success in between) that trip the breaker.
            reset_timeout: Seconds to stay OPEN before admitting a trial call.
        """
        self.fail_max = fail_max
        self.reset_timeout = reset_timeout
        self.state = SimpleBreaker.CLOSED
        # Monotonic timestamps of failures recorded since the last success.
        self._fails: deque[float] = deque()
        # Monotonic time at which the breaker last tripped to OPEN.
        self._opened_at: float = 0.0
        # Monotonic start time of the in-flight HALF_OPEN trial, or None.
        self._trial_started_at: float | None = None
        # Not used by the methods below; kept so the attribute stays available.
        self._lock = asyncio.Lock()

    def allow(self) -> bool:
        """Return True if a call may proceed now, False to fail fast."""
        if self.state == SimpleBreaker.CLOSED:
            return True
        now = time.monotonic()
        if self.state == SimpleBreaker.OPEN:
            if now - self._opened_at < self.reset_timeout:
                return False
            # Cool-down elapsed: this caller becomes the single trial call.
            self.state = SimpleBreaker.HALF
            self._trial_started_at = now
            return True
        # HALF_OPEN: reject everyone else while a trial is in flight, unless
        # that trial looks abandoned (no outcome within reset_timeout).
        started = self._trial_started_at
        if started is not None and now - started < self.reset_timeout:
            return False
        self._trial_started_at = now
        return True

    def on_success(self) -> None:
        """Record a successful call: close the breaker and forget failures."""
        self.state = SimpleBreaker.CLOSED
        self._fails.clear()
        self._trial_started_at = None

    def on_failure(self) -> None:
        """Record a failed call; trip to OPEN at the threshold or on a failed trial."""
        now = time.monotonic()
        self._fails.append(now)
        self._trial_started_at = None
        if self.state == SimpleBreaker.HALF or len(self._fails) >= self.fail_max:
            self.state = SimpleBreaker.OPEN
            self._opened_at = now
            log.warning("Circuit breaker OPEN, will half-open in %ss", self.reset_timeout)
# channel -> breaker; entries are created lazily by get_breaker().
_breakers: dict[str, SimpleBreaker] = {}


def get_breaker(
    channel: str, fail_max: int = 5, reset_timeout: float = 60.0
) -> SimpleBreaker:
    """Return the shared circuit breaker for *channel*, creating it on first use.

    Args:
        channel: Registry key; one breaker exists per distinct value.
        fail_max: Failure threshold, applied only when the breaker is created.
        reset_timeout: OPEN -> HALF_OPEN delay in seconds, applied only on creation.

    Returns:
        The same SimpleBreaker instance for every call with the same *channel*.
        Later calls with different settings do not reconfigure an existing breaker.
    """
    breaker = _breakers.get(channel)
    if breaker is None:
        breaker = SimpleBreaker(fail_max=fail_max, reset_timeout=reset_timeout)
        _breakers[channel] = breaker
    return breaker
# ============ Async per-key min-interval rate limiter ============
class AsyncLimiter:
    """Per-key async rate limiter that spaces calls at least ``1 / rps`` apart.

    This is a simple minimum-interval spacer, not a token bucket: there is no
    burst allowance. Each key may proceed at most once per ``1 / rps``
    seconds; extra callers are delayed. Different keys are independent and
    never delay one another.

    Args:
        rps: Allowed requests per second per key. Values below 0.001
            (including zero or negative) are clamped to 0.001.
    """

    def __init__(self, rps: float) -> None:
        self.rps = rps
        self._min_interval = 1.0 / max(rps, 0.001)
        # key -> monotonic time of the most recently granted/reserved slot.
        self._last: dict[str, float] = {}
        # Guards reads/updates of ``_last``; never held across a sleep.
        self._lock = asyncio.Lock()

    async def acquire(self, key: str) -> None:
        """Wait until *key* may proceed under the configured rate.

        A slot is reserved while holding the lock and the caller then sleeps
        outside it, so a throttled key does not block other keys. A caller
        that is cancelled while sleeping still consumes its reserved slot.
        """
        async with self._lock:
            now = time.monotonic()
            last = self._last.get(key)
            # The first call for a key never waits.
            slot = now if last is None else max(now, last + self._min_interval)
            self._last[key] = slot
        delay = slot - now
        if delay > 0:
            await asyncio.sleep(delay)
# tenant_id -> limiter; populated lazily by get_tenant_limiter().
_limiters: dict[int, AsyncLimiter] = {}


def get_tenant_limiter(tenant_id: int, rps: float = 5.0) -> AsyncLimiter:
    """Look up (or lazily build) the rate limiter owned by *tenant_id*.

    *rps* is honoured only when the limiter is first built. Subsequent calls
    return the cached instance unchanged, whatever *rps* they pass.
    """
    if tenant_id in _limiters:
        return _limiters[tenant_id]
    limiter = AsyncLimiter(rps=rps)
    _limiters[tenant_id] = limiter
    return limiter