AI Safety in Production Systems

November 11, 2024

AI safety isn’t just a research topic—it’s a production engineering concern. Prompt injection, harmful outputs, data leakage, and misuse are real risks in deployed systems. Practical safety measures are essential.

Here’s how to implement AI safety in production.

The Safety Landscape

Production Safety Risks

ai_safety_risks:
  prompt_injection:
    description: "User input manipulates model behavior"
    impact: "Data leakage, unauthorized actions, system compromise"
    prevalence: "Common attack vector"

  harmful_outputs:
    description: "Model generates inappropriate content"
    impact: "Brand damage, legal liability, user harm"
    prevalence: "Depends on use case"

  data_leakage:
    description: "Model reveals training or context data"
    impact: "Privacy violations, competitive disclosure"
    prevalence: "Subtle but real"

  misuse:
    description: "System used for unintended harmful purposes"
    impact: "Varies by application"
    prevalence: "Application dependent"

Input Safeguards

Prompt Injection Defense

import re

class InputSanitizer:
    """Layered defense against prompt injection: fast pattern matching,
    then LLM-based detection, then stripping of special markup."""

    def __init__(self, llm_detector):
        self.llm_detector = llm_detector
        self.blocked_patterns = [
            r"ignore (all )?(previous|above) instructions",
            r"you are now",
            r"pretend you are",
            r"act as if",
            r"system prompt:",
            r"<\|.*\|>",  # Special tokens
        ]

    async def sanitize(self, user_input: str) -> SanitizeResult:
        # Pattern matching (fast)
        for pattern in self.blocked_patterns:
            if re.search(pattern, user_input, re.IGNORECASE):
                return SanitizeResult(
                    safe=False,
                    reason="Blocked pattern detected",
                    sanitized=None
                )

        # LLM-based detection (thorough)
        is_injection = await self.llm_detector.detect_injection(user_input)
        if is_injection.likely:
            return SanitizeResult(
                safe=False,
                reason=is_injection.reason,
                sanitized=None
            )

        # Escape special characters
        sanitized = self._escape_special(user_input)

        return SanitizeResult(safe=True, sanitized=sanitized)

    def _escape_special(self, text: str) -> str:
        """Neutralize characters and markup that might be interpreted specially."""
        # Remove XML/HTML-like tags instead of passing them through to the model
        text = re.sub(r'<[^>]+>', '', text)
        return text
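The SanitizeResult container and the llm_detector interface are left implicit above. A minimal sketch of what they might look like, assuming a dataclass result and an async classifier-backed detector (llm.classify is a placeholder for your provider's call):

from dataclasses import dataclass
from typing import Optional

@dataclass
class SanitizeResult:
    safe: bool
    sanitized: Optional[str] = None
    reason: Optional[str] = None

@dataclass
class InjectionVerdict:
    likely: bool
    reason: str = ""

class LLMInjectionDetector:
    """Hypothetical detector that asks a classifier model whether the
    input tries to override instructions."""

    def __init__(self, llm):
        self.llm = llm

    async def detect_injection(self, text: str) -> InjectionVerdict:
        # llm.classify is an assumed interface; substitute your provider's call
        label = await self.llm.classify(
            text, labels=["benign", "injection_attempt"]
        )
        return InjectionVerdict(
            likely=(label == "injection_attempt"),
            reason="Classifier flagged an instruction-override attempt",
        )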

Input Validation Layer

class InputValidator:
    """Validate and constrain user inputs."""

    def __init__(self, config: ValidationConfig):
        self.max_length = config.max_length
        self.allowed_languages = config.allowed_languages
        self.content_filter = ContentFilter()

    async def validate(self, input_data: UserInput) -> ValidationResult:
        errors = []

        # Length check
        if len(input_data.text) > self.max_length:
            errors.append(f"Input exceeds max length of {self.max_length}")

        # Language detection
        language = detect_language(input_data.text)
        if language not in self.allowed_languages:
            errors.append(f"Language {language} not supported")

        # Content filtering
        content_check = await self.content_filter.check(input_data.text)
        if content_check.flagged:
            errors.append(f"Content flagged: {content_check.categories}")

        # PII detection
        pii = detect_pii(input_data.text)
        if pii and not input_data.pii_consent:
            errors.append("PII detected without consent")

        return ValidationResult(
            valid=len(errors) == 0,
            errors=errors
        )
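ValidationConfig, detect_language, detect_pii, and ContentFilter are assumed dependencies. A minimal sketch of the config and how the validator might be wired into request handling:

from dataclasses import dataclass, field

@dataclass
class ValidationConfig:
    max_length: int = 4000
    allowed_languages: set = field(default_factory=lambda: {"en"})

async def handle_request(user_input: UserInput) -> None:
    validator = InputValidator(ValidationConfig(max_length=2000))
    result = await validator.validate(user_input)
    if not result.valid:
        # Reject before the request ever reaches the model
        raise ValueError(f"Input rejected: {result.errors}")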

Output Safeguards

Output Filtering

import asyncio

class OutputFilter:
    """Filter and validate model outputs before they reach the user."""

    def __init__(self, llm, config: FilterConfig):
        self.llm = llm
        self.config = config

    async def filter_output(
        self,
        output: str,
        context: RequestContext
    ) -> FilterResult:
        checks = await asyncio.gather(
            self._check_harmful_content(output),
            self._check_data_leakage(output, context),
            self._check_policy_compliance(output),
            self._check_hallucination_risk(output, context)
        )

        issues = [c for c in checks if c.flagged]

        if issues:
            if self.config.mode == "redact":
                redacted = await self._redact_issues(output, issues)
                return FilterResult(
                    allowed=True,
                    output=redacted,
                    redactions=len(issues)
                )
            # Fail closed: any mode other than "redact" blocks flagged output
            return FilterResult(
                allowed=False,
                reason=issues[0].reason,
                original=output
            )

        return FilterResult(allowed=True, output=output)

    async def _check_data_leakage(
        self,
        output: str,
        context: RequestContext
    ) -> CheckResult:
        """Detect if output leaks sensitive data."""
        # Check for system prompt leakage
        if context.system_prompt and context.system_prompt[:50] in output:
            return CheckResult(flagged=True, reason="System prompt leaked")

        # Check for PII in output that wasn't in input
        output_pii = detect_pii(output)
        input_pii = detect_pii(context.user_input)
        new_pii = output_pii - input_pii
        if new_pii:
            return CheckResult(flagged=True, reason="New PII in output")

        return CheckResult(flagged=False)
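The leakage check assumes detect_pii returns a set-like collection, so that genuinely new PII can be computed by set difference. A minimal regex-based sketch (a real deployment would use a dedicated PII detection library or service):

import re

# Hypothetical regex-based PII detector; the patterns here are illustrative
PII_PATTERNS = {
    "email": r"[\w.+-]+@[\w-]+\.[\w.]+",
    "phone": r"\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b",
    "ssn": r"\b\d{3}-\d{2}-\d{4}\b",
}

def detect_pii(text: str) -> set:
    """Return a set of (category, value) pairs found in the text."""
    found = set()
    for category, pattern in PII_PATTERNS.items():
        for match in re.findall(pattern, text):
            found.add((category, match))
    return found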

Structured Output Enforcement

import json

class SafeOutputParser:
    """Enforce structured output with safety constraints."""

    async def parse_safe(
        self,
        output: str,
        schema: OutputSchema
    ) -> ParseResult:
        try:
            # Parse JSON
            data = json.loads(output)

            # Validate against schema
            validated = schema.validate(data)

            # Apply safety constraints
            for field, value in validated.items():
                constraint = schema.safety_constraints.get(field)
                if constraint:
                    if not constraint.check(value):
                        return ParseResult(
                            success=False,
                            error=f"Safety constraint violated: {field}"
                        )

            return ParseResult(success=True, data=validated)

        except json.JSONDecodeError:
            return ParseResult(
                success=False,
                error="Invalid JSON output"
            )
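OutputSchema and its safety constraints are not defined above; one way they could look, using per-field constraint objects with a check callable:

from dataclasses import dataclass, field
from typing import Any, Callable

@dataclass
class SafetyConstraint:
    check: Callable[[Any], bool]   # Returns True when the value is acceptable
    description: str = ""

@dataclass
class OutputSchema:
    fields: dict                   # field name -> expected type(s)
    safety_constraints: dict = field(default_factory=dict)

    def validate(self, data: dict) -> dict:
        # Keep only declared fields and verify their types
        validated = {}
        for name, expected_type in self.fields.items():
            if name not in data or not isinstance(data[name], expected_type):
                raise ValueError(f"Missing or mistyped field: {name}")
            validated[name] = data[name]
        return validated

# Example: a refund amount that must stay within policy limits
refund_schema = OutputSchema(
    fields={"refund_amount": (int, float), "reason": str},
    safety_constraints={
        "refund_amount": SafetyConstraint(
            check=lambda v: 0 <= v <= 500,
            description="Refunds capped at 500 without human approval",
        )
    },
)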

System-Level Safeguards

Rate Limiting and Abuse Prevention

class AbusePreventor:
    """Prevent abuse of the AI system: rate limits, pattern analysis, cost caps."""

    def __init__(self, rate_limiter, cost_tracker, config):
        self.rate_limiter = rate_limiter
        self.cost_tracker = cost_tracker
        self.config = config

    async def check_request(
        self,
        user_id: str,
        request: Request
    ) -> AbuseCheckResult:
        # Rate limiting
        rate_ok = await self.rate_limiter.check(user_id)
        if not rate_ok:
            return AbuseCheckResult(allowed=False, reason="Rate limit exceeded")

        # Pattern detection
        recent_requests = await self.get_recent_requests(user_id)
        pattern = self._analyze_pattern(recent_requests)
        if pattern.suspicious:
            return AbuseCheckResult(
                allowed=False,
                reason=f"Suspicious pattern: {pattern.description}"
            )

        # Cost tracking
        user_cost = await self.cost_tracker.get_user_cost(user_id)
        if user_cost > self.config.max_user_cost:
            return AbuseCheckResult(allowed=False, reason="Cost limit exceeded")

        return AbuseCheckResult(allowed=True)
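The rate limiter itself is assumed above. A minimal in-memory sliding-window sketch, matching the awaited check(user_id) call (production systems typically back this with Redis or a gateway-level limiter):

import time
from collections import defaultdict, deque

class SlidingWindowRateLimiter:
    """Hypothetical in-memory limiter: N requests per user per window."""

    def __init__(self, max_requests: int = 60, window_seconds: int = 60):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self._requests = defaultdict(deque)

    async def check(self, user_id: str) -> bool:
        now = time.monotonic()
        window = self._requests[user_id]
        # Drop timestamps that have fallen out of the window
        while window and now - window[0] > self.window_seconds:
            window.popleft()
        if len(window) >= self.max_requests:
            return False
        window.append(now)
        return True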

Monitoring and Alerting

safety_monitoring:
  metrics_to_track:
    - input_filter_blocks_rate
    - output_filter_flags_rate
    - injection_attempts_detected
    - rate_limit_hits
    - cost_anomalies

  alert_conditions:
    - injection_rate > 0.1%
    - output_filter_rate > 5%
    - single_user_cost_spike > 10x
    - unusual_query_patterns

  logging:
    - Log all filtered inputs/outputs
    - Log safety check failures
    - Retain for incident investigation
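How these metrics are emitted depends on the observability stack. A minimal sketch using prometheus_client counters (the metric names here are illustrative, not a standard):

from prometheus_client import Counter

# Counters corresponding to the metrics listed above
INPUT_BLOCKS = Counter(
    "ai_input_filter_blocks_total", "Inputs blocked by the input filter", ["reason"]
)
OUTPUT_FLAGS = Counter(
    "ai_output_filter_flags_total", "Outputs flagged by the output filter", ["check"]
)
INJECTION_ATTEMPTS = Counter(
    "ai_injection_attempts_total", "Detected prompt injection attempts"
)

def record_input_block(reason: str) -> None:
    INPUT_BLOCKS.labels(reason=reason).inc()

def record_injection_attempt() -> None:
    INJECTION_ATTEMPTS.inc()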

Defense in Depth

layered_defense:
  layer_1_input:
    - Pattern matching
    - Length limits
    - Content filtering
    - LLM injection detection

  layer_2_system:
    - Constrained system prompts
    - Limited tool access
    - Sandboxed execution
    - Clear boundaries

  layer_3_output:
    - Content filtering
    - PII detection
    - Leakage detection
    - Schema validation

  layer_4_monitoring:
    - Real-time alerting
    - Pattern analysis
    - Human review queue
    - Incident response
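Putting the layers together, a single request might flow through a pipeline like the sketch below. The component instances, llm.complete, and SYSTEM_PROMPT are assumptions that reuse the classes from earlier sections:

async def safe_completion(user_id: str, raw_input: UserInput) -> str:
    # Layer 4 wraps everything: abuse checks up front, monitoring throughout
    abuse = await abuse_preventor.check_request(user_id, raw_input)
    if not abuse.allowed:
        raise PermissionError(abuse.reason)

    # Layer 1: validate and sanitize the input
    validation = await input_validator.validate(raw_input)
    if not validation.valid:
        raise ValueError(validation.errors)
    sanitized = await input_sanitizer.sanitize(raw_input.text)
    if not sanitized.safe:
        raise ValueError(sanitized.reason)

    # Layer 2: the model call runs with a constrained system prompt and
    # limited tool access (llm.complete is an assumed interface)
    output = await llm.complete(system=SYSTEM_PROMPT, user=sanitized.sanitized)

    # Layer 3: filter the output before it reaches the user
    context = RequestContext(
        system_prompt=SYSTEM_PROMPT, user_input=sanitized.sanitized
    )
    filtered = await output_filter.filter_output(output, context)
    if not filtered.allowed:
        raise ValueError(filtered.reason)
    return filtered.output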

Key Takeaways

Safe AI requires engineering discipline: validate inputs, filter outputs, constrain the system, and monitor everything, because no single layer catches every failure. Build it in from day one.