AI Infrastructure at Scale

December 18, 2023

Moving an AI prototype to production at scale exposes a new set of infrastructure challenges: API rate limits, latency requirements, cost management, and reliability. The patterns that work for a demo don’t work for millions of users.

Here’s how to build AI infrastructure that scales.

Scale Challenges

What Changes at Scale

scale_challenges:
  rate_limits:
    problem: API providers limit requests
    at_scale: Hit limits frequently
    impact: Failed requests, degraded experience

  latency:
    problem: LLMs are slow (seconds, not milliseconds)
    at_scale: Queuing adds more delay
    impact: User experience suffers

  cost:
    problem: Pay per token
    at_scale: Costs become significant
    impact: Margins erode

  reliability:
    problem: API dependencies can fail
    at_scale: Failures affect many users
    impact: Availability issues

Architecture Patterns

Request Queuing

import asyncio
import time
import uuid
from collections import deque
from dataclasses import dataclass

@dataclass
class QueuedRequest:
    id: str
    prompt: str
    priority: int
    future: asyncio.Future
    created_at: float

class AIRequestQueue:
    def __init__(self, llm, rate_limit: int = 100, window_seconds: int = 60):
        self.llm = llm  # client exposing an async generate(prompt) method
        self.queue = deque()  # pending requests, ordered by priority
        self.rate_limit = rate_limit
        self.window_seconds = window_seconds
        self.request_times = deque()  # dispatch timestamps inside the current window

    async def submit(self, prompt: str, priority: int = 0) -> str:
        future = asyncio.get_running_loop().create_future()
        request = QueuedRequest(
            id=str(uuid.uuid4()),
            prompt=prompt,
            priority=priority,
            future=future,
            created_at=time.time()
        )

        # Insert by priority (higher priority first)
        inserted = False
        for i, existing in enumerate(self.queue):
            if priority > existing.priority:
                self.queue.insert(i, request)
                inserted = True
                break
        if not inserted:
            self.queue.append(request)

        return await future

    async def process_loop(self):
        while True:
            if not self.queue:
                await asyncio.sleep(0.1)
                continue

            # Check rate limit
            current_time = time.time()
            while self.request_times and self.request_times[0] < current_time - self.window_seconds:
                self.request_times.popleft()

            if len(self.request_times) >= self.rate_limit:
                wait_time = self.request_times[0] + self.window_seconds - current_time
                await asyncio.sleep(wait_time)
                continue

            # Process next request
            request = self.queue.popleft()
            self.request_times.append(current_time)

            try:
                result = await self.llm.generate(request.prompt)
                request.future.set_result(result)
            except Exception as e:
                request.future.set_exception(e)
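
A minimal sketch of wiring the queue into an application. The llm client here is hypothetical; the queue only assumes it exposes an async generate(prompt) method.

async def main():
    llm = MyLLMClient()  # hypothetical client with an async generate(prompt) method
    queue = AIRequestQueue(llm, rate_limit=100, window_seconds=60)

    # Run the dispatcher in the background for the lifetime of the app
    worker = asyncio.create_task(queue.process_loop())

    # Callers simply await submit(); priority=1 jumps ahead of default-priority work
    answer = await queue.submit("Summarize this incident report", priority=1)
    print(answer)

    worker.cancel()

asyncio.run(main())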

Load Balancing Across Providers

import time

class MultiProviderLLM:
    """Load balance across multiple LLM providers."""

    def __init__(self, providers: list):
        self.providers = providers
        # Rolling per-provider stats used for health-aware routing
        self.provider_stats = {p.name: {"success": 0, "errors": 0, "latency": []} for p in providers}

    async def generate(self, prompt: str, **kwargs) -> str:
        # Select provider based on health and load
        provider = self._select_provider()

        try:
            start = time.time()
            result = await provider.generate(prompt, **kwargs)
            self._record_success(provider, time.time() - start)
            return result
        except Exception as e:
            self._record_error(provider)
            # Fallback to next provider
            for fallback in self.providers:
                if fallback != provider:
                    try:
                        return await fallback.generate(prompt, **kwargs)
                    except Exception:
                        continue
            raise e

    def _select_provider(self):
        # Weighted selection based on success rate and recent latency
        scores = {}
        for provider in self.providers:
            stats = self.provider_stats[provider.name]
            total = stats["success"] + stats["errors"]
            if total == 0:
                scores[provider] = 1.0  # no history yet: give the provider a chance
            else:
                success_rate = stats["success"] / total
                recent = stats["latency"][-100:]
                avg_latency = sum(recent) / max(len(recent), 1)
                scores[provider] = success_rate / (avg_latency + 1)

        return max(scores, key=scores.get)

    def _record_success(self, provider, latency: float):
        stats = self.provider_stats[provider.name]
        stats["success"] += 1
        stats["latency"].append(latency)

    def _record_error(self, provider):
        self.provider_stats[provider.name]["errors"] += 1
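
The router only assumes each provider object exposes a name attribute and an async generate() method. A sketch of an adapter for whatever vendor SDK you use; the complete() call is a stand-in, not a real SDK method.

class ProviderAdapter:
    """Wraps a vendor SDK behind the small interface MultiProviderLLM expects."""

    def __init__(self, name: str, sdk_client):
        self.name = name
        self._client = sdk_client  # vendor-specific client, injected by the caller

    async def generate(self, prompt: str, **kwargs) -> str:
        # Delegate to the vendor SDK and return plain text
        return await self._client.complete(prompt, **kwargs)  # hypothetical SDK call

# router = MultiProviderLLM([
#     ProviderAdapter("openai", openai_client),
#     ProviderAdapter("anthropic", anthropic_client),
# ])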

Caching Layer

import hashlib

from cachetools import LRUCache
from sentence_transformers import SentenceTransformer

class ScalableAICache:
    """Multi-level caching for AI responses."""

    def __init__(self, redis_client, vector_store, local_cache_size=10000):
        self.redis = redis_client
        self.vector_store = vector_store  # semantic index for near-duplicate prompts
        self.local = LRUCache(maxsize=local_cache_size)  # in-process exact-match cache
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')

    async def get_or_generate(self, prompt: str, generator, ttl=3600):
        # Level 1: Exact match in local cache
        cache_key = self._hash_key(prompt)
        if cache_key in self.local:
            return self.local[cache_key]

        # Level 2: Exact match in Redis
        cached = await self.redis.get(f"ai:exact:{cache_key}")
        if cached:
            result = cached.decode()
            self.local[cache_key] = result
            return result

        # Level 3: Semantic similarity search
        similar = await self._find_similar(prompt)
        if similar:
            return similar

        # Generate new response
        result = await generator(prompt)

        # Store in all cache levels
        self.local[cache_key] = result
        await self.redis.setex(f"ai:exact:{cache_key}", ttl, result)
        await self._store_for_semantic_search(prompt, result)

        return result

    async def _find_similar(self, prompt: str, threshold=0.95):
        embedding = self.encoder.encode(prompt)
        # Search the vector index for a previously answered, near-identical prompt
        results = await self.vector_store.query(embedding, k=1)
        if results and results[0].score >= threshold:
            return results[0].response
        return None

    async def _store_for_semantic_search(self, prompt: str, response: str):
        # Index the prompt embedding so near-duplicate prompts can reuse this response
        # (vector store interface assumed to mirror query(): add(embedding, response))
        embedding = self.encoder.encode(prompt)
        await self.vector_store.add(embedding, response)

    def _hash_key(self, prompt: str) -> str:
        return hashlib.sha256(prompt.encode()).hexdigest()
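
A sketch of the cache in front of a model call. It assumes a redis.asyncio client and a vector store with the query()/add() interface used above; neither is tied to a specific product.

import redis.asyncio as redis

async def answer(prompt: str, cache: ScalableAICache, llm) -> str:
    # generator is only invoked on a full miss (no exact or semantic hit)
    return await cache.get_or_generate(prompt, llm.generate, ttl=3600)

# cache = ScalableAICache(redis.from_url("redis://localhost"), my_vector_store)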

Reliability Patterns

Circuit Breaker

from enum import Enum
import time

class CircuitBreakerOpen(Exception):
    """Raised when a call is rejected because the circuit is open."""

class CircuitState(Enum):
    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"

class CircuitBreaker:
    def __init__(self, failure_threshold=5, recovery_timeout=60):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failures = 0
        self.state = CircuitState.CLOSED
        self.last_failure_time = None

    async def call(self, func, *args, **kwargs):
        if self.state == CircuitState.OPEN:
            if time.time() - self.last_failure_time > self.recovery_timeout:
                self.state = CircuitState.HALF_OPEN
            else:
                raise CircuitBreakerOpen("Circuit is open")

        try:
            result = await func(*args, **kwargs)
            self._on_success()
            return result
        except Exception as e:
            self._on_failure()
            raise

    def _on_success(self):
        self.failures = 0
        self.state = CircuitState.CLOSED

    def _on_failure(self):
        self.failures += 1
        self.last_failure_time = time.time()
        # Any failure while half-open, or too many consecutive failures, opens the circuit
        if self.state == CircuitState.HALF_OPEN or self.failures >= self.failure_threshold:
            self.state = CircuitState.OPEN
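
A short usage sketch: give each provider its own breaker so one failing vendor does not block calls to the others. The providers list is the same one used by the router above.

breakers = {p.name: CircuitBreaker(failure_threshold=5, recovery_timeout=60) for p in providers}

async def guarded_generate(provider, prompt: str) -> str:
    # Fails fast with CircuitBreakerOpen while the provider's circuit is open
    return await breakers[provider.name].call(provider.generate, prompt)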

Graceful Degradation

import logging

logger = logging.getLogger(__name__)

class ResilientAIService:
    """AI service with fallbacks for degraded operation."""

    def __init__(self, primary_llm, fallback_llm, cache):
        self.primary = primary_llm
        self.fallback = fallback_llm
        self.cache = cache
        self.circuit_breaker = CircuitBreaker()

    async def generate(self, prompt: str) -> dict:
        # Try cache first
        cached = await self.cache.get(prompt)
        if cached:
            return {"response": cached, "source": "cache"}

        # Try primary with circuit breaker
        try:
            response = await self.circuit_breaker.call(
                self.primary.generate, prompt
            )
            await self.cache.set(prompt, response)
            return {"response": response, "source": "primary"}
        except CircuitBreakerOpen:
            pass
        except Exception as e:
            logger.warning(f"Primary failed: {e}")

        # Fallback to secondary model
        try:
            response = await self.fallback.generate(prompt)
            return {"response": response, "source": "fallback"}
        except Exception as e:
            logger.error(f"Fallback failed: {e}")

        # Final fallback: static response
        return {
            "response": "I'm experiencing issues. Please try again.",
            "source": "static",
            "error": True
        }
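
In the request handler, the source field makes degradation visible to monitoring without changing the response shape for callers. A sketch:

async def handle_chat(prompt: str, service: ResilientAIService) -> dict:
    result = await service.generate(prompt)

    # Count degraded responses so alerting can catch a stuck-open primary
    if result["source"] in ("fallback", "static"):
        logger.info("degraded response served from %s", result["source"])

    return result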

Cost Management

Budget Controls

from datetime import datetime

class BudgetController:
    """Control AI spending with budgets and alerts."""

    def __init__(self, redis_client, daily_budget: float):
        self.redis = redis_client
        self.daily_budget = daily_budget

    async def check_and_record(self, estimated_cost: float) -> bool:
        key = f"ai_budget:{datetime.now().strftime('%Y%m%d')}"

        current = float(await self.redis.get(key) or 0)

        if current + estimated_cost > self.daily_budget:
            await self._alert_budget_exceeded()
            return False

        await self.redis.incrbyfloat(key, estimated_cost)
        await self.redis.expire(key, 86400 * 2)

        if current + estimated_cost > self.daily_budget * 0.8:
            await self._alert_budget_warning()

        return True

    async def _alert_budget_exceeded(self):
        # Send an alert to the ops team; further spend is blocked for the day
        pass

    async def _alert_budget_warning(self):
        # Notify the ops team that spending has crossed 80% of the daily budget
        pass
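
A sketch of the budget check in the request path, with a rough per-token cost estimate. The prices and token counts are placeholders, not real provider rates.

def estimate_cost(input_tokens: int, output_tokens: int,
                  price_in_per_1k: float = 0.001, price_out_per_1k: float = 0.002) -> float:
    # Placeholder per-1K-token prices; substitute your provider's actual rates
    return (input_tokens / 1000) * price_in_per_1k + (output_tokens / 1000) * price_out_per_1k

async def generate_with_budget(prompt: str, llm, budget: BudgetController) -> str:
    cost = estimate_cost(input_tokens=len(prompt) // 4, output_tokens=500)  # crude token estimate
    if not await budget.check_and_record(cost):
        raise RuntimeError("Daily AI budget exceeded")  # or degrade to a cheaper model
    return await llm.generate(prompt)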

Observability

Comprehensive Metrics

ai_infrastructure_metrics:
  system:
    - Request rate by provider
    - Latency percentiles
    - Error rate by type
    - Queue depth

  business:
    - Cost per request
    - Daily spending
    - Cache hit rate
    - Fallback usage rate

  quality:
    - Response quality scores
    - User feedback
    - Retry rates
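
Most of these map directly onto standard counters, histograms, and gauges. A sketch using prometheus_client; the metric names are illustrative, not a prescribed schema.

from prometheus_client import Counter, Gauge, Histogram

AI_REQUESTS = Counter("ai_requests_total", "AI requests", ["provider", "status"])
AI_LATENCY = Histogram("ai_request_latency_seconds", "AI request latency", ["provider"])
AI_COST = Counter("ai_cost_dollars_total", "Estimated AI spend in dollars", ["provider"])
QUEUE_DEPTH = Gauge("ai_queue_depth", "Pending requests in the AI queue")  # set from the queue loop
CACHE_HITS = Counter("ai_cache_hits_total", "AI cache hits", ["level"])  # local / redis / semantic

def record_request(provider: str, status: str, latency_s: float, cost: float):
    AI_REQUESTS.labels(provider=provider, status=status).inc()
    AI_LATENCY.labels(provider=provider).observe(latency_s)
    AI_COST.labels(provider=provider).inc(cost)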

Key Takeaways

AI at scale is a distributed systems problem. Apply the same principles: queue and prioritize requests, balance load across providers, cache aggressively, degrade gracefully, and keep cost and quality observable.