Part 1 of 2 in the ML Deployment series
Production ML requires 5-10x more infrastructure code than model training [1]. Deployment success depends on serving architecture, containerization, lifecycle management, performance optimization, and operational monitoring. This guide synthesizes industry practices from Netflix [2][3], Uber [4][5], AWS [6], and NVIDIA [7][8][9] into actionable patterns with concrete benchmarks and decision frameworks.
| Pattern | Latency | Throughput | Use Case | Complexity |
|---|---|---|---|---|
| REST API (FastAPI) | 50-200ms | Medium | Public APIs, <10KB payloads | Low |
| gRPC | 10-50ms | High | Internal services, >100KB data | Medium |
| Kafka Streaming | <10ms | Very High | Real-time events, millions msg/sec | High |
| Batch (Celery) | Minutes-Hours | Very High | Offline scoring, cacheable inputs | Low |
| Serverless (Lambda) | Variable | Low-Medium | <1000 req/day, sporadic traffic | Low |
Critical mistake: Blocking the event loop with synchronous model inference degrades async I/O performance by 10-20x (5ms → 100ms for I/O tasks). FastAPI and async frameworks require proper isolation of CPU-bound ML inference.
Solution: ProcessPoolExecutor Pattern
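import asyncio
from concurrent.futures import ProcessPoolExecutor

from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

model = None  # set once per worker process by the initializer

def create_model():
    global model
    model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')

# Assumed helper (not shown in the original): runs inside the worker process
def model_predict(text):
    return model.encode(text).tolist()  # plain list so it is JSON-serializable

# Assumed request schema (not shown in the original)
class PredictionRequest(BaseModel):
    text: str

pool = ProcessPoolExecutor(max_workers=1, initializer=create_model)
app = FastAPI()

async def predict_async(text):
    loop = asyncio.get_running_loop()
    vector = await loop.run_in_executor(pool, model_predict, text)
    return vector

@app.post("/predict")
async def predict(request: PredictionRequest):
    result = await predict_async(request.text)
    return {"prediction": result}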
Performance: Isolates CPU-bound inference while preserving async I/O performance.
gRPC benchmarks show a 9-10x throughput improvement over REST [10] for large payloads, and substantially better P99 latency (7s vs 30s at 500 concurrent threads, roughly 4x). Computer vision models see a 75-85% latency reduction.
Why gRPC Wins:
Kafka delivers millions of messages/second with single-digit millisecond P99 latency.
Two Architecture Patterns:
| Pattern | Latency | Coupling | Use Case |
|---|---|---|---|
| Remote Serving | Higher | Coupled availability | Separation of concerns |
| Embedded Inference | Lowest | Decoupled | Best latency, exactly-once semantics |
Embedded Pattern:
Kafka → Flink (with embedded model) → Kafka
Ideal for fraud detection, real-time recommendations, IoT sensor processing.
| Criteria | Use Batch | Use Real-Time |
|---|---|---|
| Update latency | Hours/days acceptable | Seconds required |
| Input space | Cacheable, limited variety | Unpredictable inputs |
| Traffic pattern | Predictable, periodic | Sporadic, on-demand |
| Cost sensitivity | Optimize compute cost | Optimize user experience |
| Mesh | Latency Overhead | When to Use |
|---|---|---|
| Linkerd | 5-10% | Latency-critical ML, GPU cost significant |
| Istio | 25-35% | Complex routing, multi-cluster, large teams |
Rule: Small teams → Linkerd’s simplicity; Large platform engineering → Istio’s features.
Multi-stage builds can reduce ML images from 450MB to 158MB (65% reduction).
Production Dockerfile:
# Stage 1: Dependency Builder
FROM python:3.11-slim as deps-builder
RUN apt-get update && apt-get install -y build-essential && \
rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY requirements.txt .
RUN python -m venv /venv && \
/venv/bin/pip install --no-cache-dir -r requirements.txt
# Stage 2: Model Preparation
FROM python:3.11-slim as model-stage
COPY --from=deps-builder /venv /venv
ENV PATH="/venv/bin:$PATH"
WORKDIR /app
COPY download_models.py ./
RUN python download_models.py
# Stage 3: Production Runtime
FROM python:3.11-slim as runtime
RUN apt-get update && apt-get install -y libgomp1 && \
rm -rf /var/lib/apt/lists/* && \
groupadd -r appuser && useradd -r -g appuser appuser
COPY --from=deps-builder /venv /venv
ENV PATH="/venv/bin:$PATH"
COPY --from=model-stage /app/models /models
WORKDIR /app
COPY . .
RUN chown -R appuser:appuser /app
USER appuser
HEALTHCHECK --interval=30s --timeout=10s \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health', timeout=5)"
EXPOSE 8000
CMD ["gunicorn", "--bind", "0.0.0.0:8000", "--workers", "2", "app:app"]
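Critical: Use a Debian-based slim image such as python:3.11-slim-bookworm, NOT Alpine. Alpine uses musl libc, which cannot install the standard manylinux wheels, so any dependency without a musllinux wheel is compiled from source and builds can be up to 50x slower.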
| Strategy | Cold Start (140MB) | Cold Start (1GB) | When to Use |
|---|---|---|---|
| Baked into image | 5.6s | 17.6s | Small models (<500MB) |
| S3 loading | 5.7s | 20.2s | Large models, frequent updates |
| Provisioned concurrency | 0s | 0s | Production APIs (cost: ~$15/mo) |
| Lazy loading (@lru_cache) | First request: 10s | First request: 30s | Low-traffic endpoints |
CPU metrics respond in 30-60 seconds. GPU and queue metrics: 5-15 seconds.
GPU-Based HPA:
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: ml-inference-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: ml-inference-deployment
  minReplicas: 2
  maxReplicas: 15
  metrics:
    - type: External
      external:
        metric:
          name: dcgm_fi_dev_gpu_util # NVIDIA DCGM exporter
        target:
          type: AverageValue
          averageValue: "75"
    - type: Pods
      pods:
        metric:
          name: tgi_queue_size
        target:
          type: AverageValue
          averageValue: "10"
  behavior:
    scaleDown:
      stabilizationWindowSeconds: 300 # Prevent thrashing
      policies:
        - type: Percent
          value: 50
          periodSeconds: 60
    scaleUp:
      stabilizationWindowSeconds: 0 # Immediate
      policies:
        - type: Percent
          value: 100
          periodSeconds: 30
Critical: GPUs must be specified under limits. Kubernetes uses the limit as the request when requests is omitted; if requests is set, it must equal limits.
resources:
  limits:
    nvidia.com/gpu: 1
  requests:
    nvidia.com/gpu: 1
Node affinity for GPU types:
affinity:
  nodeAffinity:
    requiredDuringSchedulingIgnoredDuringExecution:
      nodeSelectorTerms:
        - matchExpressions:
            - key: accelerator
              operator: In
              values:
                - nvidia-tesla-v100
                - nvidia-tesla-a100
| Metric | Use Serverless | Use Kubernetes |
|---|---|---|
| Traffic | <1000 req/day | >10K req/day |
| Latency SLA | Best effort | <100ms required |
| GPU needs | None | Required |
| Cost | Optimize for idle | Optimize for throughput |
Cold start reality:
Mitigation hierarchy:
Use dependency injection for testability and resource management.
from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject

# MLModel and PredictionService are application classes defined elsewhere


class MLServiceContainer(containers.DeclarativeContainer):
    config = providers.Configuration()

    # Singleton: load model once
    model = providers.Singleton(
        MLModel,
        model_path=config.model_path
    )

    # Factory: per-request instances
    prediction_service = providers.Factory(
        PredictionService,
        model=model
    )


@inject
def predict(
    data: dict,
    service: PredictionService = Provide[MLServiceContainer.prediction_service]
):
    return service.predict(data)
| Probe | Endpoint | Purpose | Timeout |
|---|---|---|---|
| Liveness | /health | Process alive | 5s |
| Readiness | /health/ready | Dependencies + model loaded | 10s |
| Startup | /health/startup | Extended initialization | 150s (30×5s) |
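import logging
import signal
import sys
import time
from threading import Lock

from flask import Flask, abort, g

app = Flask(__name__)
logger = logging.getLogger(__name__)

is_shutting_down = False
active_requests = 0
request_lock = Lock()


def shutdown_handler(signum, frame):
    global is_shutting_down
    is_shutting_down = True
    logger.info("Shutdown initiated, draining requests...")
    # Wait for active requests to complete
    while active_requests > 0:
        time.sleep(0.1)
    # Cleanup (model and db_connection are application objects defined elsewhere)
    model.cleanup()
    db_connection.close()
    logger.info("Shutdown complete")
    sys.exit(0)


signal.signal(signal.SIGTERM, shutdown_handler)


@app.before_request
def check_shutdown():
    global active_requests
    if is_shutting_down:
        abort(503, "Service shutting down")
    with request_lock:
        active_requests += 1
    g.counted = True  # rejected requests are never counted


@app.teardown_request
def decrement_counter(exc):
    global active_requests
    # teardown_request also runs when the view raised; decrement only
    # requests that were counted, so 503 rejections cannot push it negative
    if g.pop("counted", False):
        with request_lock:
            active_requests -= 1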
Semantic Versioning:
MLflow Registry:
import mlflow
# Register model
mlflow.register_model(
model_uri="runs:/abc123/model",
name="fraud_detector"
)
# Promote to production
client = mlflow.tracking.MlflowClient()
client.transition_model_version_stage(
name="fraud_detector",
version=5,
stage="Production",
archive_existing_versions=True
)
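A 50/50 split is optimal for statistical power. For a fixed total sample, the variance of the measured difference is proportional to 1/(p × (1 − p)), where p is the share of traffic in one arm. An even split maximizes p × (1 − p) (0.5 × 0.5 = 0.25, versus 0.95 × 0.05 = 0.0475 for a 95/5 split), so it yields roughly 5x lower variance.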
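The arithmetic behind the split comparison can be checked with a short sketch. It assumes equal per-user variance in both arms and a fixed total sample size; `relative_variance` is an illustrative helper, not part of any library.

```python
def relative_variance(treatment_fraction: float) -> float:
    """Variance of the difference in means, up to the constant sigma^2 / N.

    With N total users, n_t = p * N and n_c = (1 - p) * N, so
    Var = sigma^2 / n_t + sigma^2 / n_c = (sigma^2 / N) / (p * (1 - p)).
    """
    p = treatment_fraction
    if not 0.0 < p < 1.0:
        raise ValueError("treatment_fraction must be strictly between 0 and 1")
    return 1.0 / (p * (1.0 - p))


even = relative_variance(0.50)    # 1 / 0.25
skewed = relative_variance(0.05)  # 1 / 0.0475
ratio = skewed / even             # how much noisier the 95/5 split is
print(f"50/50: {even:.2f}, 95/5: {skewed:.2f}, ratio: {ratio:.2f}x")
```

Put differently, a 95/5 test needs about five times the total traffic of a 50/50 test to reach the same precision.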
Sample size calculation:
import numpy as np
from scipy import stats


def calculate_sample_size(
    baseline_rate: float,
    minimum_detectable_effect: float,
    alpha: float = 0.05,
    power: float = 0.8
) -> int:
    # Two-sided test for a difference between two proportions
    z_alpha = stats.norm.ppf(1 - alpha/2)
    z_beta = stats.norm.ppf(power)
    p1 = baseline_rate
    p2 = baseline_rate * (1 + minimum_detectable_effect)  # relative lift
    p_avg = (p1 + p2) / 2
    n = ((z_alpha * np.sqrt(2 * p_avg * (1 - p_avg)) +
          z_beta * np.sqrt(p1 * (1 - p1) + p2 * (1 - p2))) /
         (p2 - p1)) ** 2
    return int(np.ceil(n))


# Example: 3% baseline, 20% relative improvement (3.0% -> 3.6%)
n = calculate_sample_size(0.03, 0.20)  # ~13,900 per group
Minimum test duration: max(sample_size/daily_users, 14 days) to capture weekly patterns.
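The duration rule can be written as a small helper. This is a sketch under the assumption that both the sample size and the daily traffic are counted per group; `minimum_test_days` and its parameter names are illustrative.

```python
import math


def minimum_test_days(sample_size_per_group: int,
                      daily_users_per_group: int,
                      floor_days: int = 14) -> int:
    """Days needed to reach the sample size, but never fewer than floor_days."""
    if daily_users_per_group <= 0:
        raise ValueError("daily_users_per_group must be positive")
    days_for_sample = math.ceil(sample_size_per_group / daily_users_per_group)
    return max(days_for_sample, floor_days)


# High traffic: the sample fills in 7 days, so the 14-day floor applies
print(minimum_test_days(14_000, 2_000))
# Low traffic: the sample size dominates
print(minimum_test_days(14_000, 500))
```

The 14-day floor covers two full weekly cycles, so weekday and weekend behavior each appear twice in the data.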
Standard progression: 5% → 25% → 50% → 100%
Pause durations: 15min, 30min, 60min between stages
apiVersion: argoproj.io/v1alpha1
kind: Rollout
metadata:
  name: ml-model
spec:
  replicas: 10
  strategy:
    canary:
      steps:
        - setWeight: 5
        - pause: {duration: 15m}
        - setWeight: 25
        - pause: {duration: 30m}
        - setWeight: 50
        - pause: {duration: 1h}
      analysis:
        templates:
          - templateName: success-metrics
        startingStep: 1
Automated rollback triggers:
| Pattern | Use When | Traffic Split | Duration | Rollback Speed |
|---|---|---|---|---|
| A/B Testing | Statistical comparison | 50/50 | 1-2 weeks | Moderate |
| Canary | Gradual rollout | 5→25→50→100% | 2-4 hours | Fast (<1 min) |
| Blue-Green | Instant rollback needed | 0→100% | Immediate | Instant |
| Shadow | Zero user impact | 0% (observe) | Continuous | N/A |
| Technique | Speedup | Memory Reduction | Accuracy Impact | Complexity |
|---|---|---|---|---|
| INT8 Quantization | 2-4x | 4x | 1-3% loss | Medium |
| FP16 Quantization | 2x | 2x | <1% loss | Low |
| Structured Pruning | 1.5-2x | 3x | 2-5% loss | Medium-High |
| Knowledge Distillation | 2-5x | 2-10x | 2-5% loss | High |
| TensorRT | 5x | - | Minimal | Low-Medium |
| ONNX Runtime | 3x | - | None | Low |
| Redis Caching | 100-1000x | - | None | Low |
| Dynamic Batching | 3-4x | - | None | Low |
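INT8 quantization: 2-4x speedup and 4x memory reduction (32-bit floats stored as 8-bit integers), with accuracy loss ranging from under 1% to about 3% depending on the model and calibration.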
Quantization-Aware Training (QAT) vs Post-Training (PTQ):
QAT is mandatory for complex models.
# PyTorch post-training static quantization (PTQ with calibration)
import torch
import torch.quantization

model.eval()  # static quantization expects the model in eval mode
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)

# Calibrate with 100 samples
with torch.no_grad():
    for data in calibration_loader:
        model(data)

torch.quantization.convert(model, inplace=True)
Framework comparison:
| Framework | Platform | Speedup | Use Case |
|---|---|---|---|
| TensorRT | NVIDIA GPU | 5x | Maximum GPU performance |
| ONNX Runtime | Cross-platform | 3x | Compatibility, portability |
| OpenVINO | Intel CPU/GPU | 27% | Intel hardware, edge |
ONNX Export:
import torch.onnx
import onnxruntime as ort
# Export to ONNX
torch.onnx.export(
model,
dummy_input,
"model.onnx",
opset_version=13,
do_constant_folding=True
)
# Load and run
session = ort.InferenceSession("model.onnx")
output = session.run(None, {"input": input_data.numpy()})
TensorRT benchmarks:
100-1000x speedup for repeat predictions.
import hashlib
import json

import redis


class RedisCachingDecorator:
    def __init__(self, host='localhost', port=6379, prefix='predictions'):
        self._redis_client = redis.Redis(host=host, port=port, db=0)
        self.prefix = prefix

    def _create_cache_key(self, model_name, version, input_data):
        input_hash = hashlib.sha256(
            json.dumps(input_data, sort_keys=True).encode()
        ).hexdigest()[:16]
        return f"{self.prefix}/{model_name}/{version}/{input_hash}"

    def predict(self, model, input_data):
        cache_key = self._create_cache_key(
            model.__class__.__name__,
            getattr(model, 'version', '1.0'),
            input_data
        )
        # Try cache first
        cached_result = self._redis_client.get(cache_key)
        if cached_result is not None:
            return json.loads(cached_result)
        # Compute and cache (result must be JSON-serializable)
        result = model.predict(input_data)
        self._redis_client.setex(cache_key, 3600, json.dumps(result))
        return result
TTL guidelines:
NVIDIA Triton benchmarks: 3.7x throughput improvement [13].
Without batching: Inception plateaus at 73 infer/sec.
With dynamic batching: reaches 272 infer/sec at concurrency 8.
# Triton config.pbtxt
max_batch_size: 16
dynamic_batching {
preferred_batch_size: [4, 8, 16]
max_queue_delay_microseconds: 100
priority_levels: 2
default_priority_level: 1
}
instance_group [{
count: 2
kind: KIND_GPU
gpus: [0]
}]
Latency-throughput tradeoff (Llama-2-70B on A100):
| Batch Size | Latency/Token | Throughput | Improvement |
|---|---|---|---|
| 1 | 30ms | 33 tok/sec | Baseline |
| 4 | 35ms | 114 tok/sec | 3.5x |
| 16 | 40ms | 400 tok/sec | 12x |
| 64 | 120ms | 533 tok/sec | 16x (plateau) |
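The throughput column follows from batch size and per-token latency. The sketch below reproduces it, assuming each decoding step emits one token for every sequence in the batch; the `rows` list simply restates the table.

```python
# (batch size, latency per token in ms), copied from the table above
rows = [(1, 30), (4, 35), (16, 40), (64, 120)]


def tokens_per_second(batch_size: int, latency_ms: float) -> float:
    """One token per sequence per step, so throughput = batch / step time."""
    return batch_size * 1000.0 / latency_ms


baseline = tokens_per_second(*rows[0])
for batch_size, latency_ms in rows:
    throughput = tokens_per_second(batch_size, latency_ms)
    print(f"batch={batch_size:>2}  {throughput:6.0f} tok/sec  "
          f"{throughput / baseline:4.1f}x")
```

Between batch 16 and batch 64 the latency triples while throughput rises by only a third, which is the plateau the table points to.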
| Percentile | Meaning | Alert Threshold |
|---|---|---|
| P50 (median) | Typical experience | < 200ms |
| P95 | Catch tail latency early | < 500ms |
| P99 | Architectural bottlenecks | < 1000ms |
Never rely on averages - they mislead in skewed distributions.
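A small worked example shows how an average hides a slow tail. The latency values are invented for illustration, and `percentile` is a simple nearest-rank helper rather than a library function.

```python
import statistics

# Hypothetical sample (assumption): 95 fast requests and 5 very slow ones, in ms
latencies_ms = [20] * 95 + [2000] * 5


def percentile(values, pct: int):
    """Nearest-rank percentile: smallest value with at least pct% at or below it."""
    ordered = sorted(values)
    rank = max(1, (pct * len(ordered) + 99) // 100)  # integer ceiling
    return ordered[rank - 1]


mean = statistics.mean(latencies_ms)
p50 = percentile(latencies_ms, 50)
p95 = percentile(latencies_ms, 95)
p99 = percentile(latencies_ms, 99)
print(f"mean={mean}ms p50={p50}ms p95={p95}ms p99={p99}ms")
```

The mean lands at 119ms, a latency no request actually experienced: most users see 20ms, while one in twenty waits two seconds, which only P99 reveals here.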
from prometheus_client import Histogram
latency_histogram = Histogram(
'http_request_duration_seconds',
'Duration of HTTP requests',
['status', 'path', 'method'],
buckets=[0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10]
)
# Query P95 in PromQL
# histogram_quantile(0.95, sum(rate(http_request_duration_seconds_bucket[5m])) by (le))
| Metric Type | Aggregation | Query-time Percentiles | Use Case |
|---|---|---|---|
| Histogram | ✅ Across instances | ✅ Flexible | Distributed systems, Kubernetes |
| Summary | ❌ Per-instance only | ❌ Pre-defined | Single instance, exact values |
Use Histogram for ML deployments.
Track:
Recording rules (precompute expensive queries):
# ml_recording_rules.yml
groups:
  - name: ml_sli_rules
    interval: 30s
    rules:
      - record: ml:prediction_latency:p95
        expr: |
          histogram_quantile(0.95,
            sum(rate(ml_prediction_duration_seconds_bucket[5m]))
            by (le, model_version))
      - record: ml:prediction_error_rate
        expr: |
          sum(rate(ml_predictions_total{result="error"}[5m]))
            by (model_version) /
          sum(rate(ml_predictions_total[5m]))
            by (model_version)
      - record: ml:slo:prediction_latency_good_ratio
        expr: |
          sum(rate(ml_prediction_duration_seconds_bucket{le="0.5"}[5m]))
            by (model_version) /
          sum(rate(ml_prediction_duration_seconds_count[5m]))
            by (model_version)
Fast burn (1h): Immediate paging
Slow burn (6h): Create ticket
- alert: MLLatencyFastBurn
  expr: |
    (ml:prediction_latency:p95 > 0.5 and
    (1 - ml:slo:prediction_latency_good_ratio) > (14.4 * 0.001))
  for: 2m
  labels:
    severity: page
- alert: MLLatencySlowBurn
  expr: |
    (ml:prediction_latency:p95 > 0.5 and
    (1 - ml:slo:prediction_latency_good_ratio) > (6 * 0.001))
  for: 15m
  labels:
    severity: ticket
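The multipliers 14.4 and 6 in the alert expressions are burn rates applied to a 0.001 error budget, which corresponds to a 99.9% SLO. The following sketch shows what they mean in terms of budget consumed; the 30-day SLO period and the helper names are assumptions.

```python
def burn_rate(bad_ratio: float, slo_target: float) -> float:
    """How many times faster than budgeted the error budget is being spent."""
    error_budget = 1.0 - slo_target
    if error_budget <= 0:
        raise ValueError("slo_target must be below 1.0")
    return bad_ratio / error_budget


def budget_consumed(rate: float, window_hours: float,
                    period_hours: float = 30 * 24) -> float:
    """Fraction of the period's budget used if `rate` holds for the window."""
    return rate * window_hours / period_hours


# 1.44% of requests slower than the threshold against a 99.9% SLO
print(round(burn_rate(0.0144, 0.999), 1))
# Fast burn: rate 14.4 sustained for 1 hour
print(f"{budget_consumed(14.4, 1):.0%}")
# Slow burn: rate 6 sustained for 6 hours
print(f"{budget_consumed(6, 6):.0%}")
```

Under these assumptions the fast-burn alert fires when about 2% of the monthly budget would be spent in an hour, and the slow-burn alert when about 5% would be spent in six hours.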
Use structlog with JSON in production, console in development.
import uuid

import structlog
from flask import Flask, request, g

app = Flask(__name__)

structlog.configure(
    processors=[
        structlog.contextvars.merge_contextvars,
        structlog.processors.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.JSONRenderer()
    ]
)
logger = structlog.get_logger()


@app.before_request
def add_correlation_id():
    # Drop context left over from a previous request on this worker
    structlog.contextvars.clear_contextvars()
    correlation_id = request.headers.get(
        'X-Correlation-ID',
        str(uuid.uuid4())
    )
    g.correlation_id = correlation_id
    structlog.contextvars.bind_contextvars(
        correlation_id=correlation_id
    )


@app.route('/predict', methods=['POST'])
def predict():
    # Placeholder values; a real handler would log the model's actual output
    logger.info(
        "prediction_made",
        prediction=0.85,
        confidence=0.92,
        latency_ms=45
    )
    return {"prediction": 0.85, "confidence": 0.92}
| Metric | Target | Measurement |
|---|---|---|
| P50 latency | < 200ms | Typical user experience |
| P95 latency | < 500ms | Good experience threshold |
| P99 latency | < 1000ms | Worst acceptable case |
| Availability | 99.5% | ~216 min (3.6 h) downtime per 30-day month |
| Error rate | < 0.1% | 1 in 1000 requests |
Capacity planning: Track at P95 latency, alert at 80% capacity utilization.
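An availability target translates directly into a downtime budget. The helper below is a sketch of that arithmetic, assuming a 30-day month; `allowed_downtime_minutes` is an illustrative name.

```python
def allowed_downtime_minutes(availability: float, days: int = 30) -> float:
    """Minutes of downtime permitted over the period at a given availability."""
    if not 0.0 <= availability <= 1.0:
        raise ValueError("availability must be between 0 and 1")
    return (1.0 - availability) * days * 24 * 60


for target in (0.995, 0.999, 0.9995):
    print(f"{target:.2%}: {allowed_downtime_minutes(target):.1f} min/month")
```

Each additional nine cuts the budget tenfold, so the chosen target should match what on-call response times can realistically deliver.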
| Method | Data Type | Sensitivity | Sample Dependency | Use Case |
|---|---|---|---|---|
| KS Test | Continuous | Very High | Strong | Small datasets, detect tiny changes |
| PSI | Both | Low-Medium | None | Large datasets, finance/credit |
| Wasserstein | Continuous | Medium | Weak | General purpose, balanced |
| Chi-Square | Categorical | Medium | Moderate | Categorical features |
| JS Divergence | Both | Medium | Weak | Symmetric comparison |
PSI = Σ((%_new - %_old) × ln(%_new / %_old))
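import numpy as np


def calculate_psi(expected, actual, buckets=10, eps=1e-6):
    """Calculate Population Stability Index"""
    expected = np.asarray(expected, dtype=float)
    actual = np.asarray(actual, dtype=float)
    # Bin edges are quantiles of the expected (training) distribution,
    # so the data does not need to lie in the 0-100 range
    percentiles = np.arange(0, buckets + 1) / buckets * 100
    breakpoints = np.percentile(expected, percentiles)
    # Clamp production values so outliers fall into the edge buckets
    actual = np.clip(actual, breakpoints[0], breakpoints[-1])
    expected_percents = np.histogram(expected, bins=breakpoints)[0] / len(expected)
    actual_percents = np.histogram(actual, bins=breakpoints)[0] / len(actual)
    # Replace empty buckets with eps to avoid log(0) and division by zero
    expected_percents = np.clip(expected_percents, eps, None)
    actual_percents = np.clip(actual_percents, eps, None)
    psi_values = (actual_percents - expected_percents) * \
        np.log(actual_percents / expected_percents)
    return float(np.sum(psi_values))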
Thresholds:
Tiered system:
| Severity | PSI Range | Accuracy Drop | Action |
|---|---|---|---|
| Small | 0.10-0.15 | < 2% | Automated retraining |
| Moderate | 0.15-0.25 | 2-5% | Human review |
| Severe | > 0.25 | > 5% | Emergency intervention |
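The tiers above can be encoded as a small lookup. This is a sketch only: the table does not say which tier owns the shared boundary at 0.15, nor what happens below 0.10, so the choices here (0.15 counts as moderate, below 0.10 means no action) are assumptions.

```python
ACTIONS = {
    "none": "No action",
    "small": "Automated retraining",
    "moderate": "Human review",
    "severe": "Emergency intervention",
}


def drift_severity(psi: float) -> str:
    """Map a PSI value to the severity tier from the table."""
    if psi < 0.10:
        return "none"      # assumption: the table starts at 0.10
    if psi < 0.15:
        return "small"
    if psi <= 0.25:
        return "moderate"  # assumption: 0.15 itself is treated as moderate
    return "severe"


for value in (0.05, 0.12, 0.20, 0.40):
    tier = drift_severity(value)
    print(f"PSI={value:.2f} -> {tier}: {ACTIONS[tier]}")
```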
Additional triggers:
# Pydantic v1 style; in Pydantic v2 the decorator is field_validator
from pydantic import BaseModel, validator, Field


class PredictionInput(BaseModel):
    age: int = Field(..., ge=18, le=100)
    income: float = Field(..., gt=0, lt=10000000)
    credit_score: int = Field(..., ge=300, le=850)

    @validator('income')
    def income_outlier_check(cls, v):
        # Reject beyond 99th percentile from training
        if v > 500000:  # Example threshold
            raise ValueError('income value suspicious')
        return v
| Layer | Practice | Implementation |
|---|---|---|
| Input | Schema validation | Pydantic with strict types |
| Input | Adversarial detection | Confidence threshold, perturbation |
| API | Authentication | JWT RS256, API key rotation |
| API | Rate limiting | Token bucket, 100 req/min |
| Model | Access control | RBAC, audit logs, least privilege |
| Model | Encryption | AES-256 at rest, TLS 1.2+ in transit |
| Pipeline | Dependency scanning | safety, Snyk, version pinning |
| Compliance | Audit trails | Append-only logs, hash chaining |
Fixed-window counter (a simpler approximation of the token bucket listed in the table):
import time

import redis


class RateLimiter:
    def __init__(self, redis_client):
        self.redis = redis_client

    def check_rate_limit(
        self,
        user_id: str,
        max_requests: int = 100,
        window_seconds: int = 60
    ):
        # Derive the window index from window_seconds so key and TTL agree
        window = int(time.time() // window_seconds)
        key = f"rate_limit:{user_id}:{window}"
        current = self.redis.incr(key)
        if current == 1:
            self.redis.expire(key, window_seconds)
        if current > max_requests:
            return False, f"Rate limit exceeded: {current}/{max_requests}"
        return True, f"Requests remaining: {max_requests - current}"
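The Redis snippet above counts requests per time window. For comparison, a token bucket proper refills continuously and allows short bursts up to its capacity. The sketch below is a minimal single-process version, not thread-safe and not backed by Redis; the class name and the injectable `clock` parameter are illustrative choices.

```python
import time


class TokenBucket:
    """In-memory token bucket: refills at a steady rate up to a fixed capacity."""

    def __init__(self, capacity: float, refill_per_second: float,
                 clock=time.monotonic):
        self.capacity = capacity
        self.refill_per_second = refill_per_second
        self._clock = clock           # injectable so tests can control time
        self._tokens = capacity       # start full, allowing an initial burst
        self._last = clock()

    def allow(self, cost: float = 1.0) -> bool:
        now = self._clock()
        elapsed = now - self._last
        self._last = now
        # Add the tokens earned since the last call, capped at capacity
        self._tokens = min(self.capacity,
                           self._tokens + elapsed * self.refill_per_second)
        if self._tokens >= cost:
            self._tokens -= cost
            return True
        return False


# 100 requests per minute with bursts of up to 100
limiter = TokenBucket(capacity=100, refill_per_second=100 / 60)
print(limiter.allow())
```

A shared limiter across several API replicas would need the refill-and-take step to run atomically in the store, which this sketch does not attempt.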
Thresholds:
| Pattern | Purpose | When to Use | Complexity |
|---|---|---|---|
| Retry | Transient failures | Network timeouts, rate limits | Low |
| Circuit Breaker | Cascading failures | Service degradation, outages | Medium |
| Fallback | Graceful degradation | Circuit open, service down | Low-Medium |
| Timeout | Bound latency | Prevent hung requests | Low |
| Bulkhead | Resource isolation | Prevent resource exhaustion | Medium |
import requests
from tenacity import retry, stop_after_attempt, wait_exponential

# API_URL is assumed to be defined in application configuration


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10)
)
def call_external_api(data):
    response = requests.post(API_URL, json=data, timeout=5)
    response.raise_for_status()
    return response.json()
Retry: with three attempts there are at most two waits, each clamped to the 4-10s range.
import requests
from pybreaker import CircuitBreaker, CircuitBreakerError

# PREDICTION_URL, input_data and get_cached_prediction are assumed
# to be defined elsewhere in the application

breaker = CircuitBreaker(
    fail_max=5,           # Open after 5 failures
    reset_timeout=60,     # Retry after 60 seconds
    exclude=[ValueError]  # Don't count these
)


@breaker
def call_prediction_service(data):
    response = requests.post(PREDICTION_URL, json=data, timeout=5)
    response.raise_for_status()
    return response.json()


try:
    result = call_prediction_service(input_data)
except CircuitBreakerError:
    # Use fallback
    result = get_cached_prediction(input_data)
States:
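A circuit breaker moves between three states: closed (calls pass through), open (calls fail fast), and half-open (one trial call decides whether to close again). The sketch below illustrates those transitions in plain Python. It is not pybreaker's implementation; the class name and the injectable `clock` are assumptions made for illustration.

```python
import time


class SimpleCircuitBreaker:
    CLOSED, OPEN, HALF_OPEN = "closed", "open", "half-open"

    def __init__(self, fail_max: int = 5, reset_timeout: float = 60.0,
                 clock=time.monotonic):
        self.fail_max = fail_max
        self.reset_timeout = reset_timeout
        self._clock = clock
        self._failures = 0
        self._opened_at = None
        self.state = self.CLOSED

    def call(self, func, *args, **kwargs):
        if self.state == self.OPEN:
            if self._clock() - self._opened_at < self.reset_timeout:
                raise RuntimeError("circuit open")  # fail fast, skip the call
            self.state = self.HALF_OPEN             # timeout elapsed: one trial
        try:
            result = func(*args, **kwargs)
        except Exception:
            self._failures += 1
            # A failed trial call, or too many failures, opens the circuit
            if self.state == self.HALF_OPEN or self._failures >= self.fail_max:
                self.state = self.OPEN
                self._opened_at = self._clock()
            raise
        # Any success resets the failure count and closes the circuit
        self._failures = 0
        self.state = self.CLOSED
        return result
```

While the circuit is open the downstream service receives no traffic at all, which gives it room to recover and keeps callers from queuing behind timeouts.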
When to use:
First optimizations (ROI order):
Monitoring essentials:
Security baseline:
Continue the series:
[1] Production ML Systems - Google for Developers. Infrastructure requirements for production ML.
[2] Supporting Diverse ML Systems at Netflix - Netflix Technology Blog. Metaflow and ML platform architecture.
[3] Scaling Media Machine Learning at Netflix - Netflix Technology Blog. Multimodal ML at scale.
[4] Meet Michelangelo: Uber’s Machine Learning Platform - Uber Engineering. End-to-end ML platform architecture.
[5] Upgrading Uber’s MySQL Fleet to version 8.0 - Uber Engineering. Infrastructure handling 3M queries/second.
[6] AWS HealthLake - FHIR-native healthcare data platform with ML capabilities.
[7] NVIDIA TensorRT Best Practices Guide - Official NVIDIA documentation.
[8] DCGM Exporter Documentation - NVIDIA GPU monitoring for Kubernetes.
[9] Monitoring GPUs in Kubernetes with DCGM - NVIDIA Technical Blog.
[10] gRPC vs REST Performance Comparison - Google Cloud benchmarks showing 9-10x throughput improvement.
[11] NVIDIA TensorRT Best Practices Guide - ResNet benchmarks.
[12] Torch-TensorRT ResNet-50 Example - PyTorch TensorRT integration.
[13] NVIDIA Triton Optimization Guide - Dynamic batching benchmarks.