Part 1 of 2 in the ML deployment series. Part 2: building Atlas’s forecasting system in production.
Production ML systems carry 5–10× more infrastructure code than model-training code[1]. The model is the part everyone focuses on; the rest (serving, containerization, lifecycle management, monitoring, drift) is what decides whether it works. This piece is the reference I've used to think through each stage, drawing on practice from Netflix[2][3], Uber[4][5], AWS[6], and NVIDIA[7][8][9] alongside systems I've built.
| Pattern | Latency | Throughput | Fits | Complexity |
|---|---|---|---|---|
| REST API (FastAPI) | 50–200ms | Medium | Public APIs, <10KB payloads | Low |
| gRPC | 10–50ms | High | Internal services, >100KB data | Medium |
| Kafka streaming | <10ms | Very high | Real-time events, millions msg/s | High |
| Batch (Celery) | Minutes–hours | Very high | Offline scoring, cacheable inputs | Low |
| Serverless (Lambda) | Variable | Low–medium | <1,000 req/day, sporadic traffic | Low |
The subtle failure mode: synchronous model inference inside an async framework blocks the event loop, degrading I/O performance 10–20× (5ms → 100ms). FastAPI plus a model is a textbook example.
The fix is to push the model into a process pool and await it.
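```python
import asyncio
from concurrent.futures import ProcessPoolExecutor

from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

model = None  # set inside each pool worker by create_model()

def create_model():
    # Runs once per worker process, so the model loads outside the event loop.
    global model
    model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')

def model_predict(text):
    # Executes in the worker; .tolist() makes the vector JSON-serializable.
    return model.encode(text).tolist()

pool = ProcessPoolExecutor(max_workers=1, initializer=create_model)
app = FastAPI()

class PredictionRequest(BaseModel):
    text: str

async def predict_async(text):
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(pool, model_predict, text)

@app.post("/predict")
async def predict(request: PredictionRequest):
    result = await predict_async(request.text)
    return {"prediction": result}
```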
CPU-bound inference runs in the pool; async I/O stays responsive.
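The effect is easy to measure without a model. The sketch below runs a 10ms "ticker" coroutine and records the largest gap between its ticks while a blocking call runs either inline or through `run_in_executor`. `blocking_inference` is a hypothetical stand-in that sleeps rather than computes, so the default thread pool is enough here; real CPU-bound inference holds the GIL, which is why the code above uses a process pool.

```python
import asyncio
import time

def blocking_inference(seconds: float) -> str:
    # Hypothetical stand-in for a synchronous model call (sleeps instead of computing).
    time.sleep(seconds)
    return "done"

async def ticker(gaps: list, stop: asyncio.Event, interval: float = 0.01):
    # Records how long each 10ms tick actually took; a long gap means a blocked loop.
    last = time.perf_counter()
    while not stop.is_set():
        await asyncio.sleep(interval)
        now = time.perf_counter()
        gaps.append(now - last)
        last = now

async def max_gap(offload: bool, work_s: float = 0.2) -> float:
    gaps, stop = [], asyncio.Event()
    task = asyncio.create_task(ticker(gaps, stop))
    await asyncio.sleep(0.05)  # let the ticker get going
    if offload:
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, blocking_inference, work_s)
    else:
        blocking_inference(work_s)  # runs on the event-loop thread and freezes it
    stop.set()
    await task
    return max(gaps)

if __name__ == "__main__":
    print(f"inline:    {asyncio.run(max_gap(False)):.3f}s")
    print(f"offloaded: {asyncio.run(max_gap(True)):.3f}s")
```

Inline, the worst gap is roughly the full 200ms of "inference"; offloaded, it stays near the 10ms tick.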
For internal service-to-service traffic, gRPC delivers 9–10× the throughput of REST[10], with P99 latency about 4× better at 500 concurrent threads (7s vs. 30s). Computer vision models see 75–85% latency reductions:
Three reasons it wins:
Kafka streaming handles millions of messages per second with single-digit-millisecond P99. Two patterns:
| Pattern | Latency | Coupling | Fits |
|---|---|---|---|
| Remote serving | Higher | Coupled availability | Separation of concerns |
| Embedded inference | Lowest | Decoupled | Best latency, exactly-once semantics |
Embedded looks like:
Kafka → Flink (with embedded model) → Kafka
Good for fraud detection, real-time recommendations, IoT telemetry.
| Criterion | Batch | Real-time |
|---|---|---|
| Update latency | Hours/days OK | Seconds required |
| Input space | Cacheable, limited variety | Unpredictable |
| Traffic | Predictable, periodic | Sporadic, on-demand |
| Cost sensitivity | Optimize compute | Optimize experience |
| Mesh | Latency overhead | When |
|---|---|---|
| Linkerd | 5–10% | Latency-critical ML, GPU cost significant |
| Istio | 25–35% | Complex routing, multi-cluster, large teams |
Small teams default to Linkerd. Platform teams with feature requirements default to Istio[11][12].
The right multi-stage build takes an ML image from 450MB to 158MB:
```dockerfile
# Stage 1: dependency builder
FROM python:3.11-slim-bookworm AS deps-builder
RUN apt-get update && apt-get install -y --no-install-recommends build-essential && \
    rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY requirements.txt .
RUN python -m venv /venv && \
    /venv/bin/pip install --no-cache-dir -r requirements.txt

# Stage 2: model preparation
FROM python:3.11-slim-bookworm AS model-stage
COPY --from=deps-builder /venv /venv
ENV PATH="/venv/bin:$PATH"
WORKDIR /app
COPY download_models.py ./
# Expected to write its files to /app/models
RUN python download_models.py

# Stage 3: production runtime (no compilers, non-root user)
FROM python:3.11-slim-bookworm AS runtime
RUN apt-get update && apt-get install -y --no-install-recommends libgomp1 && \
    rm -rf /var/lib/apt/lists/* && \
    groupadd -r appuser && useradd -r -g appuser appuser
COPY --from=deps-builder /venv /venv
ENV PATH="/venv/bin:$PATH"
COPY --from=model-stage /app/models /models
WORKDIR /app
# --chown avoids a separate chown layer that would duplicate /app in the image
COPY --chown=appuser:appuser . .
USER appuser
# raise_for_status() makes a non-2xx /health fail the check (requests must be in requirements.txt)
HEALTHCHECK --interval=30s --timeout=10s \
    CMD python -c "import requests; requests.get('http://localhost:8000/health', timeout=5).raise_for_status()"
EXPOSE 8000
CMD ["gunicorn", "--bind", "0.0.0.0:8000", "--workers", "2", "app:app"]
```
Use python:3.11-slim-bookworm, not Alpine. Most ML packages publish only glibc (manylinux) wheels, so on Alpine's musl libc pip falls back to compiling from source and builds run roughly 50× slower.
| Strategy | Cold start (140MB) | Cold start (1GB) | When |
|---|---|---|---|
| Baked into image | 5.6s | 17.6s | Small models (<500MB) |
| S3 loading | 5.7s | 20.2s | Large models, frequent updates |
| Provisioned concurrency | 0s | 0s | Production APIs (~$15/mo) |
| Lazy (`@lru_cache`) | First request: 10s | First request: 30s | Low-traffic endpoints |
CPU-based autoscaling responds in 30–60s. GPU and queue-length signals: 5–15s.
```yaml
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: ml-inference-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: ml-inference-deployment
  minReplicas: 2
  maxReplicas: 15
  metrics:
    - type: External
      external:
        metric:
          name: dcgm_fi_dev_gpu_util  # NVIDIA DCGM exporter
        target:
          type: AverageValue
          averageValue: "75"
    - type: Pods
      pods:
        metric:
          name: tgi_queue_size
        target:
          type: AverageValue
          averageValue: "10"
  behavior:
    scaleDown:
      stabilizationWindowSeconds: 300
      policies:
        - type: Percent
          value: 50
          periodSeconds: 60
    scaleUp:
      stabilizationWindowSeconds: 0
      policies:
        - type: Percent
          value: 100
          periodSeconds: 30
```
Scale up fast; scale down slow. Thrashing is more expensive than excess capacity.
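Kubernetes documents the HPA's core rule as desired = ceil(current × observed / target), computed per metric, with the largest proposal winning. The sketch below adds per-period rate limits shaped like the `behavior` block (at most double on the way up, at most halve on the way down) and clamps to min/max replicas. `desired_replicas` is illustrative and simplified: it ignores the HPA's default 10% tolerance, the stabilization windows, and pod readiness.

```python
import math

def desired_replicas(
    current: int,
    metrics: list[tuple[float, float]],  # (observed average, target average) per metric
    min_replicas: int = 2,
    max_replicas: int = 15,
    scale_up_percent: int = 100,   # like the scaleUp Percent policy
    scale_down_percent: int = 50,  # like the scaleDown Percent policy
) -> int:
    # Per-metric proposal; the autoscaler acts on the largest one.
    desired = max(math.ceil(current * observed / target) for observed, target in metrics)
    if desired > current:
        # 100% growth per period means at most double.
        desired = min(desired, math.ceil(current * (1 + scale_up_percent / 100)))
    elif desired < current:
        # 50% shrinkage per period means keep at least half.
        desired = max(desired, math.ceil(current * (1 - scale_down_percent / 100)))
    return max(min_replicas, min(max_replicas, desired))
```

With 4 pods at 150% of the GPU target, the proposal is 8 and the doubling cap allows it; with 10 pods nearly idle, the proposal of 2 is held at 5 for this period, which is the "scale down slow" half of the rule.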
GPUs are requested through limits. Requests can be omitted (they default to the limit), but if you set them they must equal the limits:
```yaml
resources:
  limits:
    nvidia.com/gpu: 1
  requests:
    nvidia.com/gpu: 1
```
Node affinity for specific GPU types:
```yaml
affinity:
  nodeAffinity:
    requiredDuringSchedulingIgnoredDuringExecution:
      nodeSelectorTerms:
        - matchExpressions:
            - key: accelerator
              operator: In
              values:
                - nvidia-tesla-v100
                - nvidia-tesla-a100
```
| Metric | Serverless | Kubernetes |
|---|---|---|
| Traffic | <1,000 req/day | >10K req/day |
| Latency SLA | Best-effort | <100ms required |
| GPU needs | None | Required |
| Cost | Optimize idle | Optimize throughput |
Cold-start reality:
Mitigations in order of cost:
Dependency injection makes the service testable and the resource lifecycle explicit.
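```python
from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject

# MLModel and PredictionService are your own classes, imported from your code (not shown).

class MLServiceContainer(containers.DeclarativeContainer):
    config = providers.Configuration()
    # Singleton: the model is loaded once and shared by every service instance.
    model = providers.Singleton(
        MLModel,
        model_path=config.model_path
    )
    # Factory: a new service per injection, each receiving the shared model.
    prediction_service = providers.Factory(
        PredictionService,
        model=model
    )

@inject
def predict(
    data: dict,
    service: PredictionService = Provide[MLServiceContainer.prediction_service]
):
    return service.predict(data)

# Wiring resolves Provide[...]; without it, `service` stays an unresolved marker.
container = MLServiceContainer()
container.config.from_dict({"model_path": "models/model.pt"})  # hypothetical path
container.wire(modules=[__name__])
```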
| Probe | Endpoint | Purpose | Timeout |
|---|---|---|---|
| Liveness | /health | Process alive | 5s |
| Readiness | /health/ready | Dependencies + model loaded | 10s |
| Startup | /health/startup | Extended initialization | 150s (30 × 5s) |
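Each probe answers a different question, so each handler should check different things. A framework-agnostic sketch, assuming a hypothetical `HealthState` object that your startup code and dependency checks update; wire `handle(path)` into whichever router you use.

```python
import threading
from dataclasses import dataclass, field

@dataclass
class HealthState:
    """Hypothetical shared state updated by startup code and dependency checks."""
    startup_complete: bool = False
    model_loaded: bool = False
    dependencies_ok: bool = False
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)

    def update(self, **flags: bool) -> None:
        with self._lock:
            for name, value in flags.items():
                setattr(self, name, value)

    def handle(self, path: str) -> tuple[int, dict]:
        with self._lock:
            if path == "/health":
                # Liveness: the process answers at all. Checking dependencies here
                # would make a database outage restart otherwise healthy pods.
                return 200, {"status": "alive"}
            if path == "/health/startup":
                # Startup: holds off the other probes until slow initialization ends.
                ok = self.startup_complete
                return (200 if ok else 503), {"startup_complete": ok}
            if path == "/health/ready":
                # Readiness: receive traffic only when the model and dependencies are usable.
                ok = self.model_loaded and self.dependencies_ok
                return (200 if ok else 503), {"model_loaded": self.model_loaded,
                                              "dependencies_ok": self.dependencies_ok}
        return 404, {"error": "unknown probe"}
```

A failed readiness check only removes the pod from the Service's endpoints; a failed liveness check restarts it. That asymmetry is why the dependency checks live in readiness.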
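```python
import logging
import signal
import sys
import time
from threading import Lock

from flask import Flask, abort, g

app = Flask(__name__)
logger = logging.getLogger(__name__)

is_shutting_down = False
active_requests = 0
request_lock = Lock()
DRAIN_TIMEOUT_S = 25  # stay under Kubernetes' default 30s terminationGracePeriodSeconds

def shutdown_handler(signum, frame):
    global is_shutting_down
    is_shutting_down = True
    logger.info("Shutdown initiated, draining requests...")
    # Wait for in-flight requests, but never past the grace period (SIGKILL follows).
    deadline = time.monotonic() + DRAIN_TIMEOUT_S
    while time.monotonic() < deadline:
        with request_lock:
            if active_requests == 0:
                break
        time.sleep(0.1)
    model.cleanup()        # model and db_connection are created elsewhere in the app
    db_connection.close()
    logger.info("Shutdown complete")
    sys.exit(0)

signal.signal(signal.SIGTERM, shutdown_handler)

@app.before_request
def check_shutdown():
    global active_requests
    if is_shutting_down:
        abort(503, "Service shutting down")
    with request_lock:
        active_requests += 1
    g.counted = True

@app.teardown_request
def decrement_counter(exc):
    # Unlike after_request, teardown runs even when the view raised, so the count
    # cannot leak; the g flag skips requests rejected before they were counted.
    global active_requests
    if g.pop("counted", False):
        with request_lock:
            active_requests -= 1
```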
Semantic versioning:
MLflow registry:
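```python
import mlflow

# register_model returns a ModelVersion; use its number instead of hard-coding one.
mv = mlflow.register_model(
    model_uri="runs:/abc123/model",
    name="fraud_detector"
)
client = mlflow.tracking.MlflowClient()
# Stages are deprecated since MLflow 2.9 in favor of aliases
# (client.set_registered_model_alias); this call still works but warns.
client.transition_model_version_stage(
    name="fraud_detector",
    version=mv.version,
    stage="Production",
    archive_existing_versions=True
)
```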
50/50 splits are optimal: the variance of the difference between arms scales with 1/n_A + 1/n_B, which is smallest when both arms are equal. A 95/5 split has about 5× the variance for the same total traffic, so you need roughly 5× more samples or must accept less power.
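The 5× figure is quick to check. With total traffic N split as p and 1 − p, the variance of the difference in means is proportional to 1/p + 1/(1 − p); dividing by the 50/50 value gives the penalty. A minimal sketch of that arithmetic (function names are illustrative):

```python
def relative_variance(treatment_share: float) -> float:
    """Variance of the arm difference, up to a constant, for a p / (1 - p) split."""
    p = treatment_share
    return 1 / p + 1 / (1 - p)

def variance_vs_even_split(treatment_share: float) -> float:
    # How much noisier this split is than 50/50 at the same total traffic.
    return relative_variance(treatment_share) / relative_variance(0.5)

if __name__ == "__main__":
    for share in (0.5, 0.2, 0.1, 0.05):
        print(f"{share:.0%} treatment: {variance_vs_even_split(share):.2f}x variance")
```

The 95/5 case comes out at about 5.26×; even an 80/20 split costs about 1.56×.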
Sample size:
```python
import numpy as np
from scipy import stats

def calculate_sample_size(
    baseline_rate: float,
    minimum_detectable_effect: float,  # relative, e.g. 0.20 = +20%
    alpha: float = 0.05,
    power: float = 0.8
) -> int:
    """Per-group sample size for a two-sided test of two proportions."""
    z_alpha = stats.norm.ppf(1 - alpha/2)
    z_beta = stats.norm.ppf(power)
    p1 = baseline_rate
    p2 = baseline_rate * (1 + minimum_detectable_effect)
    p_avg = (p1 + p2) / 2
    n = ((z_alpha * np.sqrt(2 * p_avg * (1 - p_avg)) +
          z_beta * np.sqrt(p1 * (1 - p1) + p2 * (1 - p2))) /
         (p2 - p1)) ** 2
    return int(np.ceil(n))

# 3% baseline, 20% relative improvement
n = calculate_sample_size(0.03, 0.20)  # ~13,900 per group
```
Minimum test duration: max(total sample across both arms / daily users entering the test, 14 days). The 14-day floor covers two full weekly cycles.
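As a function, assuming every daily user enters the test and traffic splits evenly across arms (`min_test_days` is a hypothetical helper):

```python
import math

def min_test_days(per_group_n: int, daily_users: int, arms: int = 2, floor_days: int = 14) -> int:
    """Days needed to fill every arm, never shorter than the weekly-seasonality floor."""
    days_for_sample = math.ceil(per_group_n * arms / daily_users)
    return max(days_for_sample, floor_days)
```

With 13,914 users per group and 1,000 daily users, the sample takes 28 days; at 10,000 daily users the sample fills in 3 days, but the 14-day floor still applies.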
Standard progression: 5% → 25% → 50% → 100%, with 15min / 30min / 60min pauses.
```yaml
apiVersion: argoproj.io/v1alpha1
kind: Rollout
metadata:
  name: ml-model
spec:
  replicas: 10
  # selector and template omitted; they are the same as in a Deployment
  strategy:
    canary:
      steps:
        - setWeight: 5
        - pause: {duration: 15m}
        - setWeight: 25
        - pause: {duration: 30m}
        - setWeight: 50
        - pause: {duration: 1h}
      analysis:
        templates:
          - templateName: success-metrics
        startingStep: 1
```
Automatic rollback on: error rate >10%, latency >1s, accuracy <70%, or three consecutive checks past a warning threshold.
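The rule combines hard limits with a streak of soft warnings. A minimal sketch of that decision logic, assuming each analysis run produces one `Check`. The hard limits come from the rule above; the warning thresholds are hypothetical placeholders, since they are not specified here.

```python
from dataclasses import dataclass

@dataclass
class Check:
    error_rate: float  # fraction of failed requests, 0.0-1.0
    latency_s: float   # the latency metric your analysis tracks, in seconds
    accuracy: float    # online accuracy proxy, 0.0-1.0

# Hard limits: any single breach triggers rollback.
MAX_ERROR_RATE, MAX_LATENCY_S, MIN_ACCURACY = 0.10, 1.0, 0.70
# Hypothetical warning thresholds: `streak` consecutive breaches trigger rollback.
WARN_ERROR_RATE, WARN_LATENCY_S, WARN_ACCURACY = 0.05, 0.5, 0.80

def is_warning(c: Check) -> bool:
    return (c.error_rate > WARN_ERROR_RATE or c.latency_s > WARN_LATENCY_S
            or c.accuracy < WARN_ACCURACY)

def should_rollback(history: list[Check], streak: int = 3) -> bool:
    latest = history[-1]
    if (latest.error_rate > MAX_ERROR_RATE or latest.latency_s > MAX_LATENCY_S
            or latest.accuracy < MIN_ACCURACY):
        return True
    recent = history[-streak:]
    return len(recent) == streak and all(is_warning(c) for c in recent)
```

The streak rule catches slow degradation that never crosses a hard limit, while a single recovered check resets it.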
| Pattern | When | Split | Duration | Rollback |
|---|---|---|---|---|
| A/B test | Statistical comparison | 50/50 | 1–2 weeks | Moderate |
| Canary | Gradual rollout | 5 → 25 → 50 → 100% | 2–4 hours | Fast (<1min) |
| Blue-green | Instant rollback needed | 0 → 100% | Immediate | Instant |
| Shadow | Zero user impact | 0% (observe) | Continuous | N/A |
| Technique | Speedup | Memory reduction | Accuracy cost | Complexity |
|---|---|---|---|---|
| INT8 quantization | 2–4× | 4× | 1–3% | Medium |
| FP16 quantization | 2× | 2× | <1% | Low |
| Structured pruning | 1.5–2× | 3× | 2–5% | Medium–high |
| Knowledge distillation | 2–5× | 2–10× | 2–5% | High |
| TensorRT | 5× | — | Minimal | Low–medium |
| ONNX Runtime | 3× | — | None | Low |
| Redis caching | 100–1000× | — | None | Low |
| Dynamic batching | 3–4× | — | None | Low |
Quantization is the best return on the list. INT8: 2–4× speedup, 4× memory cut, <1% accuracy loss in most cases.
TensorRT measured speedups[13][14]:
QAT vs. PTQ matters for complex models:
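PTQ on a complex model can crater accuracy; use QAT there. For models that tolerate it, eager-mode static PTQ in PyTorch is a short calibrate-then-convert pass:
```python
import torch
import torch.ao.quantization as tq
# Eager mode: the model must wrap forward() in QuantStub/DeQuantStub.
model.eval()
model.qconfig = tq.get_default_qconfig('fbgemm')  # x86 backend; 'qnnpack' for ARM
tq.prepare(model, inplace=True)  # insert observers
with torch.no_grad():
    for data in calibration_loader:  # representative calibration batches
        model(data)  # observers record activation ranges
tq.convert(model, inplace=True)  # swap observed modules for INT8 ones
```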
| Framework | Platform | Speedup | Fits |
|---|---|---|---|
| TensorRT | NVIDIA GPU | 5× | Maximum GPU performance |
| ONNX Runtime | Cross-platform | 3× | Compatibility, portability |
| OpenVINO | Intel CPU/GPU | 27% | Intel hardware, edge |
ONNX export:
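```python
import onnxruntime as ort
import torch.onnx

model.eval()  # export inference behavior (dropout off, batch norm frozen)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],    # the name session.run() feeds below
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},  # any batch size
    opset_version=13,
    do_constant_folding=True
)
# onnxruntime 1.9+ requires explicit providers when more than the CPU provider is built in.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
output = session.run(None, {"input": input_data.numpy()})
```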
TensorRT in concrete numbers: DenseNet goes from 113 to 273 inferences/sec (2.4× throughput), and latency drops from 9ms to 5ms.
Redis caching: 100–1000× faster for repeat inputs.
```python
import hashlib
import json

import redis

class RedisCachingDecorator:
    def __init__(self, host='localhost', port=6379, prefix='ml', ttl_seconds=3600):
        self._redis_client = redis.Redis(host=host, port=port, db=0)
        self.prefix = prefix
        self.ttl_seconds = ttl_seconds

    def _create_cache_key(self, model_name, version, input_data):
        # Stable hash of the input: sort_keys makes dict key order irrelevant.
        input_hash = hashlib.sha256(
            json.dumps(input_data, sort_keys=True).encode()
        ).hexdigest()[:16]
        # Versioned keys: a new model version never reads the old version's entries.
        return f"{self.prefix}/{model_name}/{version}/{input_hash}"

    def predict(self, model, input_data):
        cache_key = self._create_cache_key(
            model.__class__.__name__,
            getattr(model, 'version', '1.0'),
            input_data
        )
        cached_result = self._redis_client.get(cache_key)
        if cached_result is not None:
            return json.loads(cached_result)
        result = model.predict(input_data)  # must be JSON-serializable
        self._redis_client.setex(cache_key, self.ttl_seconds, json.dumps(result))
        return result
```
TTLs:
Dynamic batching in NVIDIA Triton[15]: 3.7× throughput. Inception plateaus at 73 inferences/sec without batching; with dynamic batching it reaches 272 at concurrency 8.
```protobuf
max_batch_size: 16
dynamic_batching {
  preferred_batch_size: [ 4, 8, 16 ]
  max_queue_delay_microseconds: 100
  priority_levels: 2
  default_priority_level: 1
}
instance_group [
  {
    count: 2
    kind: KIND_GPU
    gpus: [ 0 ]
  }
]
```
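Triton's scheduler is more elaborate (preferred sizes, priority levels, per-instance queues), but the trade-off behind `max_queue_delay_microseconds` is simple: hold the first request briefly so others can join its batch. A simplified single-queue sketch of that idea; `collect_batch` is illustrative and ignores preferred batch sizes and priorities.

```python
import queue
import time

def collect_batch(requests_q: queue.Queue, max_batch_size: int = 16,
                  max_queue_delay_s: float = 0.0001) -> list:
    """Block for the first request, then fill the batch until full or the delay expires."""
    batch = [requests_q.get()]  # wait (indefinitely) for at least one request
    deadline = time.monotonic() + max_queue_delay_s
    while len(batch) < max_batch_size:
        remaining = max(0.0, deadline - time.monotonic())
        try:
            # timeout=0 still returns a request that is already queued
            batch.append(requests_q.get(timeout=remaining))
        except queue.Empty:
            break  # delay budget spent and nothing else is waiting
    return batch
```

Requests already waiting always join the batch; the delay only bounds how long the first request waits for company. A larger delay yields fuller batches at the cost of added latency for the first request in each.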
Latency/throughput trade-off for Llama-2-70B on A100:
| Batch size | Latency/token | Throughput | Speedup |
|---|---|---|---|
| 1 | 30ms | 33 tok/s | 1× |
| 4 | 35ms | 114 tok/s | 3.5× |
| 16 | 40ms | 400 tok/s | 12× |
| 64 | 120ms | 533 tok/s | 16× (plateau) |
The right batch size depends on the SLA. Hit batch 16 if your latency budget allows it.
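Throughput is batch × 1000 / per-token latency in ms, which reproduces the table's tok/s column. Given a per-token latency budget, the choice is the highest-throughput batch that fits. A small sketch over the table's measurements (function names are illustrative):

```python
# (batch size, per-token latency in ms) for Llama-2-70B on A100, from the table
MEASUREMENTS = [(1, 30), (4, 35), (16, 40), (64, 120)]

def throughput(batch: int, latency_ms: float) -> float:
    """Tokens/sec across the batch: each decode step emits one token per sequence."""
    return batch * 1000 / latency_ms

def best_batch(latency_budget_ms: float, measurements=MEASUREMENTS) -> int:
    """Highest-throughput batch size whose per-token latency fits the budget."""
    feasible = [(throughput(b, lat), b) for b, lat in measurements if lat <= latency_budget_ms]
    if not feasible:
        raise ValueError("no measured batch size meets the latency budget")
    return max(feasible)[1]
```

A 50ms budget selects batch 16 (400 tok/s); only a budget of 120ms or more unlocks batch 64, which adds 33% throughput for 3× the per-token latency.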
| Percentile | Meaning | Target |
|---|---|---|
| P50 (median) | Typical experience | <200ms |
| P95 | Tail latency (1 in 20 requests is slower) | <500ms |
| P99 | Architectural bottlenecks | <1000ms |
```python
from prometheus_client import Histogram

latency_histogram = Histogram(
    'http_request_duration_seconds',
    'Duration of HTTP requests',
    ['status', 'path', 'method'],
    buckets=[0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10]
)
# P95 in PromQL:
# histogram_quantile(0.95, sum(rate(http_request_duration_seconds_bucket[5m])) by (le))
```
| Type | Aggregates across instances | Flexible percentiles | Use |
|---|---|---|---|
| Histogram | Yes | Yes | Distributed systems, Kubernetes |
| Summary | No (per-instance only) | No (pre-defined) | Single instance, exact values |
Default to Histogram for ML.
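Histograms aggregate because buckets are plain counters: sum them across instances, then estimate the percentile. Prometheus's `histogram_quantile` finds the bucket that contains the target rank and interpolates linearly inside it. A simplified Python re-implementation for intuition; it skips Prometheus's edge cases such as negative bounds and non-monotonic buckets.

```python
import math

def histogram_quantile_estimate(q: float, buckets: list[tuple[float, float]]) -> float:
    """Estimate quantile q (0 < q <= 1) from cumulative (upper_bound, count) buckets.

    Buckets are sorted by upper bound and end with (math.inf, total_count),
    mirroring Prometheus's le="+Inf" bucket.
    """
    total = buckets[-1][1]
    if total == 0:
        return math.nan
    rank = q * total
    lower_bound, lower_count = 0.0, 0.0  # the first bucket is assumed to start at 0
    for upper_bound, count in buckets:
        if count >= rank:
            if math.isinf(upper_bound):
                return lower_bound  # rank in the +Inf bucket: report the highest finite bound
            # Linear interpolation inside the bucket that contains the rank.
            fraction = (rank - lower_count) / (count - lower_count)
            return lower_bound + (upper_bound - lower_bound) * fraction
        lower_bound, lower_count = upper_bound, count
    return lower_bound

# Two instances' cumulative counts merge by simple addition, which Summaries cannot do.
instance_a = [(0.1, 30), (0.25, 45), (0.5, 50), (1.0, 50), (math.inf, 50)]
instance_b = [(0.1, 20), (0.25, 35), (0.5, 45), (1.0, 50), (math.inf, 50)]
merged = [(ub, ca + cb) for (ub, ca), (_, cb) in zip(instance_a, instance_b)]
```

The estimate is only as precise as the bucket layout: a P90 landing in the 0.25–0.5s bucket is interpolated as about 0.417s, so place bucket boundaries near your SLO thresholds.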
Beyond standard service metrics, track:
Recording rules to precompute the expensive queries:
```yaml
groups:
  - name: ml_sli_rules
    interval: 30s
    rules:
      - record: ml:prediction_latency:p95
        expr: |
          histogram_quantile(0.95,
            sum(rate(ml_prediction_duration_seconds_bucket[5m]))
            by (le, model_version))
      - record: ml:prediction_error_rate
        expr: |
          sum(rate(ml_predictions_total{result="error"}[5m]))
            by (model_version) /
          sum(rate(ml_predictions_total[5m]))
            by (model_version)
      - record: ml:slo:prediction_latency_good_ratio
        expr: |
          sum(rate(ml_prediction_duration_seconds_bucket{le="0.5"}[5m]))
            by (model_version) /
          sum(rate(ml_prediction_duration_seconds_count[5m]))
            by (model_version)
```
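Fast burn (14.4× the error-budget rate, 1h window): page. Slow burn (6×, 6h window): ticket. Both alerts below assume a 99.9% latency SLO (error budget 0.001) and, for brevity, reuse the 5m good-ratio recording rule; a full multiwindow setup also records the ratio over 1h and 6h.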
```yaml
- alert: MLLatencyFastBurn
  expr: |
    (ml:prediction_latency:p95 > 0.5 and
     (1 - ml:slo:prediction_latency_good_ratio) > (14.4 * 0.001))
  for: 2m
  labels:
    severity: page
- alert: MLLatencySlowBurn
  expr: |
    (ml:prediction_latency:p95 > 0.5 and
     (1 - ml:slo:prediction_latency_good_ratio) > (6 * 0.001))
  for: 15m
  labels:
    severity: ticket
```
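The 14.4 and 6 multipliers match the multiwindow burn-rate thresholds in Google's SRE Workbook for a 30-day SLO period: burning at 14.4× for one hour spends 2% of the period's error budget, and 6× for six hours spends 5%. The arithmetic as a quick sketch (function names are illustrative):

```python
def burn_rate(bad_ratio: float, slo_target: float = 0.999) -> float:
    """Observed bad-event ratio divided by the budgeted bad-event ratio."""
    return bad_ratio / (1 - slo_target)

def budget_consumed(rate: float, window_hours: float, slo_period_hours: float = 30 * 24) -> float:
    """Fraction of the period's error budget spent if `rate` holds for the window."""
    return rate * window_hours / slo_period_hours

def hours_to_exhaustion(rate: float, slo_period_hours: float = 30 * 24) -> float:
    """Time until the whole budget is gone at a constant burn rate."""
    return slo_period_hours / rate
```

At 14.4× the monthly budget is gone in about 50 hours, which is why that tier pages; at 6× it lasts 5 days, long enough for a ticket.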
structlog with JSON in production, plain console in development.
```python
import uuid

import structlog
from flask import Flask, g, request

app = Flask(__name__)

structlog.configure(
    processors=[
        structlog.contextvars.merge_contextvars,  # adds context bound for this request
        structlog.processors.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.JSONRenderer()
    ]
)
logger = structlog.get_logger()

@app.before_request
def add_correlation_id():
    # Reuse the caller's ID when present so one request can be traced across services.
    correlation_id = request.headers.get(
        'X-Correlation-ID',
        str(uuid.uuid4())
    )
    g.correlation_id = correlation_id
    structlog.contextvars.clear_contextvars()  # drop context left by a previous request
    structlog.contextvars.bind_contextvars(
        correlation_id=correlation_id
    )

@app.route('/predict', methods=['POST'])
def predict():
    logger.info(
        "prediction_made",
        prediction=0.85,
        confidence=0.92,
        latency_ms=45
    )
    return {"prediction": 0.85}
```
| Metric | Target | What it means |
|---|---|---|
| P50 latency | <200ms | Typical user experience |
| P95 latency | <500ms | Good experience threshold |
| P99 latency | <1000ms | Worst acceptable case |
| Availability | 99.5% | ~3.6 h downtime/month (30 days) |
| Error rate | <0.1% | 1 in 1,000 requests |
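Downtime budgets follow directly from the availability target: allowed downtime = (1 − target) × period. A one-line helper, assuming a 30-day month:

```python
def allowed_downtime_minutes(availability: float, days: int = 30) -> float:
    """Minutes of downtime an availability target permits over `days`."""
    return (1 - availability) * days * 24 * 60
```

For a 30-day month, 99.5% allows 216 minutes, 99.9% allows 43.2 minutes, and 99.95% allows 21.6 minutes; each extra "nine" cuts the budget tenfold.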
Capacity planning: track at P95, alert at 80% utilization.
| Method | Data type | Sensitivity | Sample dependency | Fits |
|---|---|---|---|---|
| KS test | Continuous | Very high | Strong | Small datasets, tiny shifts |
| PSI | Both | Low–medium | None | Large datasets, finance/credit |
| Wasserstein | Continuous | Medium | Weak | General purpose, balanced |
| Chi-square | Categorical | Medium | Moderate | Categorical features |
| JS divergence | Both | Medium | Weak | Symmetric comparison |
PSI = Σ((%_new − %_old) × ln(%_new / %_old))
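```python
import numpy as np

def calculate_psi(expected, actual, buckets=10, eps=1e-6):
    """Calculate Population Stability Index (expected = baseline, actual = new data)."""
    expected = np.asarray(expected, dtype=float)
    actual = np.asarray(actual, dtype=float)
    # Bucket edges are baseline quantiles, so each bucket holds ~1/buckets of the baseline.
    inner_edges = np.percentile(expected, np.linspace(0, 100, buckets + 1)[1:-1])
    # digitize puts values outside the baseline range into the outermost buckets.
    expected_percents = np.bincount(np.digitize(expected, inner_edges), minlength=buckets) / len(expected)
    actual_percents = np.bincount(np.digitize(actual, inner_edges), minlength=buckets) / len(actual)
    # Floor empty buckets instead of zeroing their term, so shifts into them still count.
    expected_percents = np.clip(expected_percents, eps, None)
    actual_percents = np.clip(actual_percents, eps, None)
    psi_values = (actual_percents - expected_percents) * \
        np.log(actual_percents / expected_percents)
    return float(np.sum(psi_values))
```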
Thresholds:
| Severity | PSI | Accuracy drop | Action |
|---|---|---|---|
| Small | 0.10–0.15 | <2% | Automated retrain |
| Moderate | 0.15–0.25 | 2–5% | Human review |
| Severe | >0.25 | >5% | Emergency |
Plus scheduled retraining (weekly/monthly), minimum-data thresholds, and cost-benefit checks.
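The thresholds table plus a minimum-data gate fit in a few lines. A sketch with illustrative names; the 1,000-row minimum is a placeholder, and treating PSI below 0.10 as "no action" is an assumption, since the table starts at 0.10.

```python
def drift_action(psi: float, n_new: int, min_samples: int = 1000) -> str:
    """Map a PSI value to the response tier in the thresholds table."""
    # Minimum-data gate: PSI over a handful of rows is mostly noise (placeholder threshold).
    if n_new < min_samples:
        return "insufficient data"
    if psi > 0.25:
        return "emergency"
    if psi >= 0.15:
        return "human review"
    if psi >= 0.10:
        return "automated retrain"
    return "no action"  # assumed: below the table's lowest tier
```

Checking the data gate first keeps a quiet weekend's worth of traffic from paging anyone.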
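```python
from pydantic import BaseModel, Field, field_validator

class PredictionInput(BaseModel):
    age: int = Field(..., ge=18, le=100)
    income: float = Field(..., gt=0, lt=10000000)
    credit_score: int = Field(..., ge=300, le=850)

    # Pydantic v2 API; v1's @validator is deprecated.
    @field_validator('income')
    @classmethod
    def income_outlier_check(cls, v: float) -> float:
        # Valid for the schema, but implausible for this model's population.
        if v > 500000:
            raise ValueError('income value suspicious')
        return v
```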
| Layer | Practice | Implementation |
|---|---|---|
| Input | Schema validation | Pydantic, strict types |
| Input | Adversarial detection | Confidence threshold, perturbation |
| API | Authentication | JWT RS256, API key rotation |
| API | Rate limiting | Token bucket, 100 req/min |
| Model | Access control | RBAC, audit logs, least privilege |
| Model | Encryption | AES-256 at rest, TLS 1.2+ in transit |
| Pipeline | Dependency scanning | safety, Snyk, version pinning |
| Compliance | Audit trails | Append-only logs, hash chaining |
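Rate limiting with a fixed-window counter in Redis (simpler than a true token bucket, but it can admit up to 2× the limit in bursts straddling a window boundary):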
```python
import time

import redis

class RateLimiter:
    """Fixed-window counter: at most max_requests per user per window."""

    def __init__(self, redis_client: redis.Redis):
        self.redis = redis_client

    def check_rate_limit(self, user_id: str, max_requests: int = 100, window_seconds: int = 60):
        window = int(time.time() // window_seconds)  # index of the current window
        key = f"rate_limit:{user_id}:{window}"
        # INCR and EXPIRE in one transaction, so no key is ever left without a TTL.
        pipe = self.redis.pipeline()
        pipe.incr(key)
        pipe.expire(key, window_seconds)
        current, _ = pipe.execute()
        if current > max_requests:
            return False, f"Rate limit exceeded: {current}/{max_requests}"
        return True, f"Requests remaining: {max_requests - current}"
```
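A true token bucket refills continuously and caps bursts at the bucket capacity. A single-process, in-memory sketch (the `TokenBucket` class is hypothetical; a shared Redis version would need the refill-and-take step to run atomically, for example in a Lua script):

```python
import time

class TokenBucket:
    """In-memory token bucket: refills at `rate` tokens/sec, holds at most `capacity`."""

    def __init__(self, capacity: float, rate: float, clock=time.monotonic):
        self.capacity = capacity
        self.rate = rate
        self.clock = clock
        self.tokens = capacity  # start full, so an idle client can burst immediately
        self.updated = clock()

    def allow(self, cost: float = 1.0) -> bool:
        now = self.clock()
        # Refill for the time elapsed since the last call, capped at capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.updated) * self.rate)
        self.updated = now
        if self.tokens >= cost:
            self.tokens -= cost
            return True
        return False
```

For 100 requests/minute with bursts of up to 100, use `TokenBucket(capacity=100, rate=100 / 60)`. The injectable `clock` makes the refill behavior testable without sleeping.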
Defaults:
| Pattern | Purpose | When | Complexity |
|---|---|---|---|
| Retry | Transient failures | Network timeouts, rate limits | Low |
| Circuit breaker | Cascading failures | Service degradation | Medium |
| Fallback | Graceful degradation | Circuit open | Low–medium |
| Timeout | Bound latency | Hung requests | Low |
| Bulkhead | Resource isolation | Prevent exhaustion | Medium |
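```python
import requests
from tenacity import retry, stop_after_attempt, wait_exponential

API_URL = "http://scoring-service/api"  # hypothetical upstream endpoint

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    reraise=True,  # after the last attempt, raise the original error, not RetryError
)
def call_external_api(data):
    response = requests.post(API_URL, json=data, timeout=5)
    response.raise_for_status()
    return response.json()
```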
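Delays: 4s, then 4s. Three attempts means two waits, and the raw exponential values (1s, 2s) are both clamped up to min=4. Raising the attempt count would continue the schedule as 4s, 8s, then the 10s cap.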
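The schedule can be reproduced without tenacity. This sketch assumes tenacity's formula for `wait_exponential`, multiplier × 2^(attempt − 1) clamped to [min, max]; check it against the version you pin. `exponential_waits` is an illustrative helper, not part of tenacity.

```python
def exponential_waits(attempts: int, multiplier: float = 1, minimum: float = 4,
                      maximum: float = 10, base: float = 2) -> list[float]:
    """Waits between attempts: multiplier * base**(n - 1) after attempt n, clamped."""
    waits = []
    for attempt_number in range(1, attempts):  # no wait after the final attempt
        raw = multiplier * base ** (attempt_number - 1)
        waits.append(max(minimum, min(raw, maximum)))
    return waits
```

`exponential_waits(3)` gives `[4, 4]`, and `exponential_waits(6)` gives `[4, 4, 4, 8, 10]`: with a high `min`, the exponential part only shows up after several attempts.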
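```python
import requests
from pybreaker import CircuitBreaker, CircuitBreakerError

PREDICTION_URL = "http://prediction-service/predict"  # hypothetical endpoint

breaker = CircuitBreaker(
    fail_max=5,           # open after 5 consecutive failures
    reset_timeout=60,     # seconds before a half-open probe is allowed
    exclude=[ValueError]  # caller errors do not count as service failures
)

@breaker
def call_prediction_service(data):
    response = requests.post(PREDICTION_URL, json=data, timeout=5)
    response.raise_for_status()
    return response.json()

try:
    result = call_prediction_service(input_data)
except (CircuitBreakerError, requests.RequestException):
    # Open circuit, or a failed call while still closed: degrade to a cached answer.
    result = get_cached_prediction(input_data)  # your own fallback (not shown)
```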
States: CLOSED (normal), OPEN (fail fast after 5 consecutive failures), HALF-OPEN (probe recovery after 60s).
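The three states fit in a small state machine. A minimal, single-threaded sketch for intuition (`SimpleCircuitBreaker` is hypothetical; use pybreaker or similar in production): consecutive failures open the circuit, calls fail fast until the timeout elapses, then one probe decides whether to close or reopen.

```python
import time

class SimpleCircuitBreaker:
    """Minimal CLOSED -> OPEN -> HALF_OPEN state machine (illustrative, not thread-safe)."""

    def __init__(self, fail_max: int = 5, reset_timeout: float = 60.0, clock=time.monotonic):
        self.fail_max = fail_max
        self.reset_timeout = reset_timeout
        self.clock = clock
        self.state = "CLOSED"
        self.failures = 0
        self.opened_at = 0.0

    def call(self, func, *args, **kwargs):
        if self.state == "OPEN":
            if self.clock() - self.opened_at < self.reset_timeout:
                raise RuntimeError("circuit open: failing fast")
            self.state = "HALF_OPEN"  # timeout elapsed: let one probe through
        try:
            result = func(*args, **kwargs)
        except Exception:
            self._record_failure()
            raise
        # Success closes the circuit and resets the consecutive-failure count.
        self.state = "CLOSED"
        self.failures = 0
        return result

    def _record_failure(self):
        self.failures += 1
        # A failed probe reopens immediately; otherwise open after fail_max in a row.
        if self.state == "HALF_OPEN" or self.failures >= self.fail_max:
            self.state = "OPEN"
            self.opened_at = self.clock()
```

Failing fast while OPEN is the point: callers get an immediate error they can route to the fallback instead of waiting on a 5s timeout per request.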
Serving:
First optimizations (return-on-effort order):
Monitoring baseline:
Security baseline:
Production ML Systems, Google for Developers.
Supporting Diverse ML Systems at Netflix, Netflix Technology Blog.
Scaling Media Machine Learning at Netflix, Netflix Technology Blog.
Meet Michelangelo: Uber's Machine Learning Platform, Uber Engineering.
Upgrading Uber's MySQL Fleet to version 8.0, Uber Engineering.
NVIDIA TensorRT Best Practices Guide, ResNet benchmarks.