AI Infrastructure at Scale

December 9, 2024

AI applications that work at demo scale often fail at production scale. Latency spikes, runaway costs, and reliability problems surface under real traffic. Robust AI infrastructure requires intentional design for scale.

Here’s how to build AI infrastructure that works at scale.

Scale Challenges

What Changes at Scale

scale_challenges:
  latency:
    at_demo: "2 seconds is fine"
    at_scale: "P99 latency matters, queuing delays compound"

  cost:
    at_demo: "$100/month, who cares"
    at_scale: "$50,000/month, need optimization"

  reliability:
    at_demo: "Retry manually"
    at_scale: "Need automatic failover, graceful degradation"

  consistency:
    at_demo: "One user, one model"
    at_scale: "Model updates affect millions, need rollout control"

Architecture Patterns

Gateway Pattern

import asyncio

class AIGateway:
    """Central gateway for all AI API calls."""

    def __init__(self, config: GatewayConfig):
        self.providers = self._init_providers(config)
        self.router = ModelRouter(config.routing_rules)
        self.rate_limiter = RateLimiter(config.rate_limits)
        self.circuit_breaker = CircuitBreaker(config.circuit_config)
        self.cache = ResponseCache(config.cache_config)
        self.metrics = MetricsCollector()

    async def generate(
        self,
        request: GenerateRequest,
        context: RequestContext
    ) -> GenerateResponse:
        # Rate limiting
        await self.rate_limiter.acquire(context.user_id)

        # Check cache
        cache_key = self._cache_key(request)
        cached = await self.cache.get(cache_key)
        if cached:
            self.metrics.record_cache_hit()
            return cached

        # Route to provider
        provider = await self.router.select(request, context)

        # Circuit breaker
        if not self.circuit_breaker.allow(provider.name):
            provider = await self.router.select_fallback(request)

        # Execute with retry
        try:
            response = await self._execute_with_retry(provider, request)
            await self.cache.set(cache_key, response)
            self.circuit_breaker.record_success(provider.name)
            self.metrics.record_success(provider.name)
            return response
        except Exception as e:
            self.circuit_breaker.record_failure(provider.name)
            self.metrics.record_failure(provider.name, e)
            raise

    async def _execute_with_retry(
        self,
        provider: Provider,
        request: GenerateRequest,
        max_retries: int = 3
    ) -> GenerateResponse:
        for attempt in range(max_retries):
            try:
                return await provider.generate(request)
            except RetryableError as e:
                if attempt == max_retries - 1:
                    raise
                await asyncio.sleep(2 ** attempt)
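
In use, the gateway is built once at startup and every model call goes through it. A minimal usage sketch, assuming the GatewayConfig, GenerateRequest, and RequestContext types referenced above; load_gateway_config and the field names passed to the request/context constructors are hypothetical:

gateway = AIGateway(load_gateway_config("gateway.yaml"))  # hypothetical config loader

async def handle_chat(user_id: str, prompt: str) -> str:
    request = GenerateRequest(prompt=prompt, max_tokens=512)   # field names assumed
    context = RequestContext(user_id=user_id, feature="chat")  # field names assumed
    response = await gateway.generate(request, context)
    return response.content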

Async Processing

import asyncio
import time
from datetime import datetime

class AsyncAIProcessor:
    """Process AI requests asynchronously for batch workloads."""

    def __init__(
        self,
        queue: MessageQueue,
        result_store: ResultStore,
        gateway: AIGateway,
        workers: int = 10
    ):
        self.queue = queue
        self.result_store = result_store
        self.gateway = gateway
        self.workers = workers

    async def submit(self, request: AIRequest) -> str:
        """Submit request, return job ID."""
        job_id = generate_job_id()

        await self.queue.publish(
            topic="ai-requests",
            message={
                "job_id": job_id,
                "request": request.dict(),
                "submitted_at": datetime.utcnow().isoformat()
            }
        )

        return job_id

    async def get_result(
        self,
        job_id: str,
        timeout: int = 300
    ) -> AIResponse:
        """Poll for result."""
        start = time.time()
        while time.time() - start < timeout:
            result = await self.result_store.get(job_id)
            if result:
                return AIResponse.parse(result)
            await asyncio.sleep(1)
        raise TimeoutError(f"Job {job_id} not completed in {timeout}s")

    async def worker_loop(self):
        """Worker that processes requests."""
        while True:
            message = await self.queue.consume("ai-requests")
            try:
                result = await self.gateway.generate(
                    GenerateRequest.parse(message["request"])
                )
                await self.result_store.set(
                    message["job_id"],
                    result.dict()
                )
            except Exception as e:
                await self.result_store.set(
                    message["job_id"],
                    {"error": str(e)}
                )
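
Submission and result retrieval run in the request path, while worker_loop runs in separate worker processes. A usage sketch for the submit-and-wait side, assuming the queue, result store, and gateway are injected as in the constructor above:

async def run_batch(processor: AsyncAIProcessor, requests: list[AIRequest]) -> list[AIResponse]:
    # Enqueue everything first so workers can process in parallel
    job_ids = [await processor.submit(req) for req in requests]
    # Then collect results, allowing up to 10 minutes per job
    return [await processor.get_result(job_id, timeout=600) for job_id in job_ids]

Polling is simple but chatty; at higher volume, a completion callback or pub/sub notification avoids the one-second poll loop.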

Caching Layer

from datetime import datetime, timedelta
from typing import Optional

class SemanticCache:
    """Cache responses with semantic similarity matching."""

    def __init__(
        self,
        embedding_model,
        vector_store,
        similarity_threshold: float = 0.95
    ):
        self.embedder = embedding_model
        self.store = vector_store
        self.threshold = similarity_threshold

    async def get(self, request: GenerateRequest) -> Optional[GenerateResponse]:
        # Embed the request
        query_embedding = await self.embedder.embed(
            self._request_to_text(request)
        )

        # Search for similar cached requests
        results = await self.store.search(
            embedding=query_embedding,
            top_k=1
        )

        if results and results[0].score >= self.threshold:
            entry = results[0].metadata
            # Cache hit, but only if the entry hasn't expired
            if entry["expires_at"] > datetime.utcnow():
                return GenerateResponse.parse(entry["response"])

        return None

    async def set(
        self,
        request: GenerateRequest,
        response: GenerateResponse,
        ttl: int = 3600
    ):
        embedding = await self.embedder.embed(
            self._request_to_text(request)
        )

        await self.store.insert(
            embedding=embedding,
            metadata={
                "request": request.dict(),
                "response": response.dict(),
                "expires_at": datetime.utcnow() + timedelta(seconds=ttl)
            }
        )
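
The _request_to_text helper is referenced above but not shown. A minimal version might look like the sketch below; the system_prompt and prompt field names are assumptions about GenerateRequest, and the key decision is which fields count toward "the same question":

    def _request_to_text(self, request: GenerateRequest) -> str:
        # Include only fields that change the answer; ignore user/session metadata
        # so semantically identical prompts from different users can share a hit.
        return f"{request.system_prompt or ''}\n{request.prompt}"

The 0.95 threshold is deliberately conservative: lowering it raises the hit rate but risks serving an answer to a subtly different question.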

Cost Management

Cost Controls

cost_control_strategies:
  budget_limits:
    - Per-user daily limits
    - Per-feature budgets
    - Organization-wide caps
    - Alert thresholds

  optimization:
    - Model routing (use cheap when possible)
    - Caching (avoid repeat work)
    - Batch processing (OpenAI batch API)
    - Prompt optimization (reduce tokens)

  monitoring:
    - Real-time cost tracking
    - Cost attribution by feature
    - Anomaly detection
    - Forecasting

class CostController:
    async def check_budget(
        self,
        context: RequestContext,
        estimated_cost: float
    ) -> BudgetCheck:
        # Check user budget
        user_usage = await self.get_user_usage(context.user_id)
        if user_usage + estimated_cost > self.user_daily_limit:
            return BudgetCheck(
                allowed=False,
                reason="User daily limit exceeded"
            )

        # Check feature budget
        feature_usage = await self.get_feature_usage(context.feature)
        if feature_usage + estimated_cost > self.feature_budget:
            return BudgetCheck(
                allowed=False,
                reason="Feature budget exceeded"
            )

        return BudgetCheck(allowed=True)
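
check_budget needs an estimated_cost before the call happens. One way to produce it is a static price table plus a worst-case token count; the prices below are illustrative placeholders, not current list prices:

# Illustrative USD prices per 1K tokens; maintain a real table per provider/model
PRICE_PER_1K = {
    "claude-3-5-sonnet": {"input": 0.003, "output": 0.015},
    "gpt-4o": {"input": 0.0025, "output": 0.01},
}

def estimate_cost(model: str, prompt_tokens: int, max_output_tokens: int) -> float:
    prices = PRICE_PER_1K[model]
    # Assume the worst case: the model spends its entire output budget
    return (
        prompt_tokens / 1000 * prices["input"]
        + max_output_tokens / 1000 * prices["output"]
    )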

Reliability Patterns

Multi-Provider Failover

failover_configuration:
  primary:
    provider: "anthropic"
    model: "claude-3-5-sonnet"
    timeout: 30s

  fallbacks:
    - provider: "openai"
      model: "gpt-4o"
      timeout: 30s
    - provider: "self-hosted"
      model: "llama-3-70b"
      timeout: 60s

  circuit_breaker:
    failure_threshold: 5
    recovery_time: 60s
    half_open_requests: 3
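
The gateway above constructs its CircuitBreaker from config; one way that class could implement the thresholds in this configuration is sketched below. This illustrates the open / half-open / closed cycle and is not a drop-in for any particular library:

import time

class CircuitBreaker:
    """Open after repeated failures, then admit a few trial requests."""

    def __init__(self, failure_threshold: int = 5, recovery_time: float = 60.0,
                 half_open_requests: int = 3):
        self.failure_threshold = failure_threshold
        self.recovery_time = recovery_time
        self.half_open_requests = half_open_requests
        self._failures = {}    # provider -> consecutive failure count
        self._opened_at = {}   # provider -> time the breaker opened
        self._trials = {}      # provider -> trial requests admitted while half-open

    def allow(self, provider: str) -> bool:
        opened = self._opened_at.get(provider)
        if opened is None:
            return True
        if time.monotonic() - opened < self.recovery_time:
            return False  # still open
        # Half-open: admit a limited number of trial requests
        if self._trials.get(provider, 0) < self.half_open_requests:
            self._trials[provider] = self._trials.get(provider, 0) + 1
            return True
        return False

    def record_success(self, provider: str):
        # A successful call closes the breaker again
        self._failures[provider] = 0
        self._opened_at.pop(provider, None)
        self._trials.pop(provider, None)

    def record_failure(self, provider: str):
        self._failures[provider] = self._failures.get(provider, 0) + 1
        if self._failures[provider] >= self.failure_threshold:
            self._opened_at[provider] = time.monotonic()
            self._trials[provider] = 0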

Graceful Degradation

async def generate_with_degradation(
    self,
    request: GenerateRequest,
    context: RequestContext
) -> GenerateResponse:
    try:
        # Try primary (best quality)
        return await self.primary.generate(request)
    except PrimaryUnavailable:
        # Fall back to faster, cheaper model
        logger.warning("Primary unavailable, using fallback")
        try:
            return await self.fallback.generate(request)
        except Exception:
            logger.warning("All providers unavailable, degrading")

    # Return cached or static response
    cached = await self.cache.get_stale(request)
    if cached:
        return cached.with_warning("Using cached response")

    # Last resort: acknowledge limitation
    return GenerateResponse(
        content="I'm experiencing high demand. Please try again shortly.",
        degraded=True
    )

Observability

observability_stack:
  metrics:
    - Request latency (p50, p95, p99)
    - Token usage by model
    - Cost per request/user/feature
    - Cache hit rate
    - Error rate by provider
    - Queue depth (async)

  logging:
    - Request/response pairs (sampled)
    - Errors with context
    - Cost events
    - Routing decisions

  tracing:
    - End-to-end request traces
    - Provider latency breakdown
    - Cache/queue time
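
A sketch of what the metrics side could look like with prometheus_client, wrapping the gateway call; the metric names, label sets, the model field on the request, and the usage fields on the response are assumptions for illustration:

from prometheus_client import Counter, Histogram

# Label sets mirror the metrics listed above
REQUEST_LATENCY = Histogram(
    "ai_request_latency_seconds", "End-to-end AI request latency", ["model"]
)
TOKENS_USED = Counter("ai_tokens_total", "Tokens consumed", ["model", "direction"])
ERRORS = Counter("ai_errors_total", "Failed AI requests", ["model", "error_type"])

async def observed_generate(gateway, request, context):
    try:
        with REQUEST_LATENCY.labels(request.model).time():
            response = await gateway.generate(request, context)
    except Exception as e:
        ERRORS.labels(request.model, type(e).__name__).inc()
        raise
    # usage fields assumed on GenerateResponse
    TOKENS_USED.labels(request.model, "input").inc(response.usage.input_tokens)
    TOKENS_USED.labels(request.model, "output").inc(response.usage.output_tokens)
    return response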

Key Takeaways

Scale requires infrastructure: a gateway, caching, async processing, cost controls, failover, and observability. Build it proactively, before traffic forces the issue.