Training a model and serving a model are different problems. The first one is mostly captured in libraries; the second one is mostly captured in operational practice. This piece is the reference I wish I’d had when I first put models in front of real traffic — serving patterns, containerization, orchestration, monitoring, drift, optimization, security, and three case studies from production systems I built.
It is not a guide to writing a model. It is a guide to keeping one alive after you deploy it.
A trained model is just a file. The hard problems are everything around it:
The rest of this piece walks each of those, in the order you’ll hit them.
from datetime import datetime, timezone

from flask import Flask, request, jsonify
import joblib
import numpy as np

app = Flask(__name__)
model = joblib.load('model.pkl')  # load once at startup, not per request

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json
    # Reshape the flat feature list into a single-row 2D array
    features = np.array(data['features']).reshape(1, -1)
    prediction = model.predict(features)
    return jsonify({
        'prediction': prediction.tolist(),
        'model_version': '1.2.3',
        'timestamp': datetime.now(timezone.utc).isoformat()
    })
Trade-off: HTTP overhead adds 5–20ms of latency. For most internal services and batch consumers, that’s fine. Reach for gRPC or streaming only when the latency budget says you must.
syntax = "proto3";

service ModelService {
  rpc Predict(PredictRequest) returns (PredictResponse);
  rpc BatchPredict(BatchPredictRequest) returns (BatchPredictResponse);
}

message PredictRequest {
  repeated float features = 1;
  string model_version = 2;
}

message PredictResponse {
  repeated float predictions = 1;
  float confidence = 2;
  int64 latency_ms = 3;
}
gRPC gives 50–90% lower wire latency than REST, plus strong typing. The cost is the proto schema, language-specific clients, and tooling that’s less universal than curl.
import json
import time

from kafka import KafkaConsumer, KafkaProducer

# Assumes a broker at the kafka-python default (localhost:9092) and a `model` loaded elsewhere
consumer = KafkaConsumer('feature_stream')
# KafkaProducer takes no topic argument; the topic is passed to send()
producer = KafkaProducer(
    value_serializer=lambda v: json.dumps(v).encode('utf-8')
)

for message in consumer:
    features = json.loads(message.value)
    prediction = model.predict([features['data']])
    result = {
        'prediction': prediction.tolist(),
        'input_id': features['id'],
        'timestamp': time.time()
    }
    producer.send('prediction_stream', result)
Use when predictions are an output of an event pipeline rather than a request/response — fraud scoring, sensor telemetry, real-time personalization.
A production Dockerfile for an ML service does five things: pins the runtime, installs build tools only for the build, drops privileges, declares a healthcheck, and runs a real WSGI server (not the Flask dev server).
FROM python:3.9-slim
# Install system dependencies (curl is required by the HEALTHCHECK below)
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    g++ \
    curl \
    && rm -rf /var/lib/apt/lists/*
# Create non-root user
RUN useradd --create-home --shell /bin/bash mluser
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY --chown=mluser:mluser . .
USER mluser
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1
EXPOSE 8000
CMD ["gunicorn", "--bind", "0.0.0.0:8000", "--workers", "4", "app:app"]
Training images carry build tools and training data. Serving images shouldn’t. Split them:
# Training stage
FROM python:3.9 as trainer
WORKDIR /app
COPY training/ .
RUN python train_model.py --output /models/
# Serving stage
FROM python:3.9-slim as serving
WORKDIR /app
COPY --from=trainer /models/ /app/models/
COPY requirements-serving.txt .
RUN pip install -r requirements-serving.txt
COPY serve.py .
CMD ["python", "serve.py"]
The serving image ships only what’s needed to load the model and answer requests.
The deployment manifest below is the minimum I’d ship: resource requests and limits, readiness and liveness probes (different intervals — readiness fast, liveness slow), and version labels for traffic routing.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: ml-model-serving
  labels:
    app: ml-model
    version: v1.2.3
spec:
  replicas: 3
  selector:
    matchLabels:
      app: ml-model
  template:
    metadata:
      labels:
        app: ml-model
        version: v1.2.3
    spec:
      containers:
        - name: model-server
          image: your-registry/ml-model:v1.2.3
          ports:
            - containerPort: 8000
          resources:
            requests:
              cpu: 500m
              memory: 1Gi
            limits:
              cpu: 2000m
              memory: 4Gi
          readinessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 30
            periodSeconds: 10
          livenessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 60
            periodSeconds: 30
          env:
            - name: MODEL_VERSION
              value: "v1.2.3"
            - name: LOG_LEVEL
              value: "INFO"
For most model servers, CPU and memory are the right signals. GPU servers want custom metrics — queue length or in-flight requests.
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: ml-model-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: ml-model-serving
  minReplicas: 2
  maxReplicas: 20
  metrics:
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: 70
    - type: Resource
      resource:
        name: memory
        target:
          type: Utilization
          averageUtilization: 80
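As a rough guide to how the autoscaler reacts to these targets, the Kubernetes documentation gives the core scaling rule as desiredReplicas = ceil(currentReplicas × currentMetric / targetMetric). The sketch below applies that rule to the 70% CPU target; the function name and the clamping to the min/max replica counts are illustrative, and the real controller also applies tolerances and stabilization windows that are left out here.

```python
import math

def desired_replicas(current_replicas: int, current_utilization: float,
                     target_utilization: float,
                     min_replicas: int = 2, max_replicas: int = 20) -> int:
    """Approximate the HPA scaling rule for a single metric (illustrative only)."""
    raw = math.ceil(current_replicas * current_utilization / target_utilization)
    # Clamp to the configured replica bounds
    return max(min_replicas, min(max_replicas, raw))

print(desired_replicas(3, 90, 70))   # 3 pods at 90% CPU against a 70% target
print(desired_replicas(3, 20, 70))   # light load: clamped to the minimum
```

With several metrics configured, as in the manifest, the autoscaler evaluates each one and takes the largest replica count.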
class ModelVersion:
    def __init__(self, major, minor, patch):
        self.major = major
        self.minor = minor
        self.patch = patch

    def __str__(self):
        return f"{self.major}.{self.minor}.{self.patch}"

    def is_compatible(self, other):
        # Same major version means the API contract is unchanged
        return self.major == other.major

current_version = ModelVersion(2, 1, 3)
new_version = ModelVersion(2, 2, 0)
if current_version.is_compatible(new_version):
    deploy_model(new_version)  # deploy_model is defined elsewhere
10% of traffic to v2, rest to v1. Headers let you force a single client onto the canary for testing.
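apiVersion: networking.istio.io/v1beta1
kind: VirtualService
metadata:
  name: ml-model-route
spec:
  hosts:
    - ml-model-serving
  http:
    # Requests carrying the header canary: "true" always go to v2
    - match:
        - headers:
            canary:
              exact: "true"
      route:
        - destination:
            host: ml-model-serving
            subset: v2
    # Everything else is split 90/10 (subsets v1 and v2 must exist in a DestinationRule)
    - route:
        - destination:
            host: ml-model-serving
            subset: v1
          weight: 90
        - destination:
            host: ml-model-serving
            subset: v2
          weight: 10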
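To make the routing rule concrete, here is a small Python simulation of the same policy: a request carrying canary: true always goes to v2, and everything else is split 90/10. This is not how Istio implements weighted routing. The hash-based bucketing and the pick_subset name are assumptions, chosen so the split is deterministic per request ID.

```python
import hashlib

def pick_subset(request_id: str, headers: dict, canary_weight: int = 10) -> str:
    """Return 'v2' for forced-canary requests, else split by a stable hash."""
    if headers.get("canary") == "true":
        return "v2"
    # Map the request ID to a stable bucket in [0, 100)
    digest = hashlib.sha256(request_id.encode("utf-8")).hexdigest()
    bucket = int(digest, 16) % 100
    return "v2" if bucket < canary_weight else "v1"

counts = {"v1": 0, "v2": 0}
for i in range(10_000):
    counts[pick_subset(f"req-{i}", {})] += 1
print(counts)  # roughly a 90/10 split
```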
Don’t compare two models by eyeballing accuracy. Test it.
import scipy.stats as stats
from dataclasses import dataclass
from typing import List

@dataclass
class ModelMetrics:
    predictions: List[float]
    ground_truth: List[float]
    response_times: List[float]

def calculate_accuracy(predictions: List[float], ground_truth: List[float]) -> float:
    """Assumed helper: exact-match accuracy for classification labels"""
    correct = sum(p == t for p, t in zip(predictions, ground_truth))
    return correct / len(ground_truth)

def compare_models(model_a: ModelMetrics, model_b: ModelMetrics) -> dict:
    """Compare two models using statistical tests"""
    accuracy_a = calculate_accuracy(model_a.predictions, model_a.ground_truth)
    accuracy_b = calculate_accuracy(model_b.predictions, model_b.ground_truth)
    # The t-test compares response times only; the accuracy gap itself is not tested here
    t_stat, p_value = stats.ttest_ind(
        model_a.response_times,
        model_b.response_times
    )
    return {
        'accuracy_improvement': accuracy_b - accuracy_a,
        'performance_p_value': p_value,
        'significant_performance_diff': p_value < 0.05,
        'recommendation': 'deploy_b' if accuracy_b > accuracy_a and p_value < 0.05 else 'keep_a'
    }
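The t-test above compares latency. To test the accuracy difference itself, one common option is a two-proportion z-test on the counts of correct predictions. This is a minimal sketch, assuming each model was scored on its own independent sample; for two models scored on the same examples, a paired test such as McNemar's would fit better. The function name is illustrative.

```python
import math

def accuracy_z_test(correct_a: int, n_a: int, correct_b: int, n_b: int):
    """Two-sided two-proportion z-test for a difference in accuracy."""
    p_a, p_b = correct_a / n_a, correct_b / n_b
    pooled = (correct_a + correct_b) / (n_a + n_b)
    se = math.sqrt(pooled * (1 - pooled) * (1 / n_a + 1 / n_b))
    if se == 0:
        # Both models are all-right or all-wrong: no evidence of a difference
        return 0.0, 1.0
    z = (p_b - p_a) / se
    # Two-sided p-value from the standard normal distribution
    p_value = math.erfc(abs(z) / math.sqrt(2))
    return z, p_value

z, p = accuracy_z_test(correct_a=850, n_a=1000, correct_b=880, n_b=1000)
print(f"z={z:.2f}, p={p:.3f}")
```

A three-point accuracy gain on 1,000 examples per arm sits close to the usual 0.05 cutoff, which is why eyeballing is unreliable at this sample size.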
System metrics and ML metrics are different things. Track both.
from functools import wraps
import time

from prometheus_client import Counter, Histogram, Gauge, start_http_server

REQUEST_COUNT = Counter('ml_requests_total', 'Total ML prediction requests')
REQUEST_LATENCY = Histogram('ml_request_duration_seconds', 'ML prediction latency')
MODEL_ACCURACY = Gauge('ml_model_accuracy', 'Current model accuracy')
PREDICTION_DRIFT = Gauge('ml_prediction_drift', 'Distribution drift score')

def track_prediction(func):
    @wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        start_time = time.time()
        REQUEST_COUNT.inc()
        try:
            result = func(*args, **kwargs)
            return result
        finally:
            # Record latency even when the prediction raises
            REQUEST_LATENCY.observe(time.time() - start_time)
    return wrapper

@track_prediction
def predict(features):
    return model.predict(features)
Run a Kolmogorov-Smirnov test on each feature against a reference distribution, then combine the per-feature results with an any() aggregator. Set the threshold conservatively: with many features, a 0.05 cutoff applied on every request is going to false-positive constantly.
import numpy as np
from scipy.stats import ks_2samp
from sklearn.preprocessing import StandardScaler

class DriftDetector:
    def __init__(self, reference_data: np.ndarray, threshold: float = 0.05):
        self.reference_data = reference_data
        self.threshold = threshold
        self.scaler = StandardScaler()
        self.scaler.fit(reference_data)

    def detect_drift(self, new_data: np.ndarray) -> dict:
        """Detect drift using Kolmogorov-Smirnov test"""
        ref_normalized = self.scaler.transform(self.reference_data)
        new_normalized = self.scaler.transform(new_data)
        drift_scores = []
        # Test each feature column independently
        for feature_idx in range(ref_normalized.shape[1]):
            statistic, p_value = ks_2samp(
                ref_normalized[:, feature_idx],
                new_normalized[:, feature_idx]
            )
            drift_scores.append({
                'feature': feature_idx,
                'ks_statistic': statistic,
                'p_value': p_value,
                'drift_detected': p_value < self.threshold
            })
        overall_drift = any(score['drift_detected'] for score in drift_scores)
        return {
            'overall_drift': overall_drift,
            'feature_scores': drift_scores,
            'recommendation': 'retrain_model' if overall_drift else 'continue_monitoring'
        }
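One way to keep an any() aggregator from firing constantly is to correct for the number of features tested. The sketch below applies a Bonferroni correction to a list of per-feature p-values. This is a swapped-in technique, not part of the detector above, and the function name is hypothetical.

```python
from typing import List

def drifted_features(p_values: List[float], alpha: float = 0.05) -> List[int]:
    """Return indices of features whose p-value passes a Bonferroni-corrected cutoff."""
    if not p_values:
        return []
    # Dividing alpha by the number of tests bounds the chance of any false alarm by alpha
    corrected_alpha = alpha / len(p_values)
    return [i for i, p in enumerate(p_values) if p < corrected_alpha]

p_values = [0.04, 0.20, 0.0004, 0.03, 0.60]
print(drifted_features(p_values))  # only index 2 is below 0.05 / 5 = 0.01
```

Bonferroni is deliberately strict. With an uncorrected 0.05 cutoff, three of these five features would have raised an alert.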
groups:
  - name: ml-model-alerts
    rules:
      - alert: ModelAccuracyDrop
        expr: ml_model_accuracy < 0.85
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "Model accuracy below threshold"
          description: "Model accuracy has dropped to {{ $value }}"
      - alert: HighPredictionLatency
        # histogram_quantile needs the per-bucket rates, aggregated by the le label
        expr: histogram_quantile(0.95, sum(rate(ml_request_duration_seconds_bucket[5m])) by (le)) > 0.5
        for: 2m
        labels:
          severity: critical
        annotations:
          summary: "High prediction latency detected"
          description: "95th percentile latency: {{ $value }}s"
      - alert: DataDriftDetected
        expr: ml_prediction_drift > 0.1
        for: 10m
        labels:
          severity: warning
        annotations:
          summary: "Data drift detected"
          description: "Input distribution has changed significantly"
Quantization cuts model size and inference time, and costs you some accuracy. Worth measuring the trade.
import torch
import torch.quantization as quantization

def quantize_model(model, representative_data):
    """Apply post-training static quantization to a PyTorch model"""
    # Eager-mode static quantization expects QuantStub/DeQuantStub in the model
    model.eval()
    model.qconfig = quantization.get_default_qconfig('fbgemm')
    model_prepared = quantization.prepare(model, inplace=False)
    # Calibrate: run representative inputs so observers record activation ranges
    with torch.no_grad():
        for data in representative_data:
            model_prepared(data)
    model_quantized = quantization.convert(model_prepared, inplace=False)
    return model_quantized

def quantize_keras_model(model):
    """Convert a Keras model to a quantized TFLite flatbuffer"""
    import tensorflow as tf
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    quantized_model = converter.convert()
    return quantized_model
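To see where the accuracy loss comes from, it helps to look at the arithmetic underneath 8-bit quantization. The sketch below is a plain-Python illustration of affine (scale and zero-point) quantization to unsigned 8-bit values. It is not what PyTorch or TFLite run internally, and the helper names are made up for the example.

```python
from typing import List, Tuple

def quantize(values: List[float]) -> Tuple[List[int], float, int]:
    """Map floats to the range 0..255 using an affine scale and zero point."""
    # The representable range must include 0 so that 0.0 maps exactly
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / 255 or 1.0  # avoid a zero scale for all-zero input
    zero_point = round(-lo / scale)
    q = [max(0, min(255, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point

def dequantize(q: List[int], scale: float, zero_point: int) -> List[float]:
    """Recover approximate floats from the quantized values."""
    return [(x - zero_point) * scale for x in q]

weights = [-1.2, -0.3, 0.0, 0.45, 2.1]
q, scale, zp = quantize(weights)
restored = dequantize(q, scale, zp)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, zp)
print(f"max round-trip error {max_error:.5f}, quantization step {scale:.5f}")
```

Each weight is rounded to the nearest representable step, so the per-weight error is bounded by roughly one step. Whether those small errors add up to a visible accuracy drop depends on the model, which is why the trade has to be measured.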
Drop the smallest-magnitude weights. 30% pruning is a safe default for most fully-connected layers.
import torch
import torch.nn.utils.prune as prune

def prune_model(model, pruning_amount=0.3):
    """Apply magnitude-based pruning"""
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # Zero the smallest-magnitude weights, then make the pruning permanent
            prune.l1_unstructured(module, name='weight', amount=pruning_amount)
            prune.remove(module, 'weight')
    return model
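As a plain-Python picture of what magnitude pruning does to a weight vector, the sketch below zeroes the 30% of entries with the smallest absolute value. It illustrates the idea only and is not the torch.nn.utils.prune implementation; the function name is hypothetical.

```python
from typing import List

def prune_smallest(weights: List[float], amount: float = 0.3) -> List[float]:
    """Zero out the given fraction of weights with the smallest absolute value."""
    n_prune = round(amount * len(weights))
    # Indices ordered by magnitude, smallest first
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    to_zero = set(order[:n_prune])
    return [0.0 if i in to_zero else w for i, w in enumerate(weights)]

weights = [0.8, -0.05, 0.3, -0.9, 0.01, 0.4, -0.2, 0.6, 0.02, -0.7]
pruned = prune_smallest(weights, amount=0.3)
print(pruned)  # the entries -0.05, 0.01 and 0.02 become 0.0
```

Unstructured pruning like this leaves the tensor shape unchanged, so any speed-up depends on the runtime being able to exploit the zeros.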
If the same input shows up repeatedly — and in many real services it does — cache the answer.
import redis
import pickle
import hashlib
from functools import wraps

class PredictionCache:
    def __init__(self, redis_url: str, ttl: int = 3600):
        self.redis_client = redis.from_url(redis_url)
        self.ttl = ttl

    def cache_key(self, features: list) -> str:
        """Generate consistent cache key from features"""
        # Keep feature order: sorting would make [1, 2] and [2, 1] share a key
        feature_str = str(list(features))
        return hashlib.md5(feature_str.encode()).hexdigest()

    def get(self, features: list):
        key = self.cache_key(features)
        cached = self.redis_client.get(key)
        if cached is not None:
            return pickle.loads(cached)
        return None

    def set(self, features: list, prediction):
        key = self.cache_key(features)
        self.redis_client.setex(
            key,
            self.ttl,
            pickle.dumps(prediction)
        )

def cached_prediction(cache: PredictionCache):
    def decorator(predict_func):
        @wraps(predict_func)
        def wrapper(features):
            cached_result = cache.get(features)
            if cached_result is not None:
                return cached_result
            result = predict_func(features)
            cache.set(features, result)
            return result
        return wrapper
    return decorator
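One detail worth checking in any feature cache is that the key depends on feature order, since most models treat position as meaning. Below is a small sketch of an order-preserving key that uses only the standard library. The function name, the SHA-256-over-JSON encoding, and the model-version namespace are assumptions for illustration, not the class's own method.

```python
import hashlib
import json
from typing import Sequence

def feature_cache_key(features: Sequence[float], model_version: str = "1.2.3") -> str:
    """Hash the features in their original order, namespaced by model version."""
    payload = json.dumps({"v": model_version, "f": list(features)})
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

a = feature_cache_key([1.0, 2.0, 3.0])
b = feature_cache_key([3.0, 2.0, 1.0])
print(a == b)                                   # False: reordered features get a different key
print(a == feature_cache_key([1.0, 2.0, 3.0]))  # True: same input, same key
```

Including the model version in the key means a newly deployed model never serves predictions cached by its predecessor.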
For throughput-bound workloads, accumulating a few requests and running them as a batch beats serving them one-at-a-time — especially on GPU. The trade is added tail latency.
import asyncio
from collections import deque
import time

class BatchProcessor:
    def __init__(self, model, max_batch_size=32, max_wait_time=0.1):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.batch_queue = deque()
        self.processing = False

    async def predict(self, features):
        """Add prediction request to batch queue"""
        future = asyncio.Future()
        request = {
            'features': features,
            'future': future,
            'timestamp': time.time()
        }
        self.batch_queue.append(request)
        if not self.processing:
            asyncio.create_task(self._process_batch())
        return await future

    async def _process_batch(self):
        if self.processing:
            return
        self.processing = True
        try:
            # Give other requests a chance to join the batch
            await asyncio.sleep(self.max_wait_time)
            if not self.batch_queue:
                return
            batch_requests = []
            batch_features = []
            while (self.batch_queue and
                   len(batch_requests) < self.max_batch_size):
                request = self.batch_queue.popleft()
                batch_requests.append(request)
                batch_features.append(request['features'])
            predictions = self.model.predict(batch_features)
            # Resolve each caller's future with its own prediction
            for request, prediction in zip(batch_requests, predictions):
                request['future'].set_result(prediction)
        finally:
            self.processing = False
            if self.batch_queue:
                asyncio.create_task(self._process_batch())
The model trusts whatever the API hands it. The API shouldn’t trust whatever the client hands it.
import numpy as np
from typing import List, Tuple

class InputValidator:
    def __init__(self, feature_ranges: dict, max_request_size: int = 1000):
        self.feature_ranges = feature_ranges
        self.max_request_size = max_request_size

    def validate_features(self, features: List[float]) -> Tuple[bool, str]:
        if len(features) > self.max_request_size:
            return False, f"Request too large: {len(features)} features"
        for i, value in enumerate(features):
            if not isinstance(value, (int, float)):
                return False, f"Invalid type for feature {i}: {type(value)}"
            if np.isnan(value) or np.isinf(value):
                return False, f"Invalid value for feature {i}: {value}"
            if i in self.feature_ranges:
                min_val, max_val = self.feature_ranges[i]
                if not (min_val <= value <= max_val):
                    return False, f"Feature {i} out of range: {value}"
        return True, "Valid"

    def detect_adversarial_pattern(self, features: List[float]) -> bool:
        """Simple adversarial input detection"""
        feature_array = np.array(features)
        # Heuristic 1: unusually high variance across features
        if np.var(feature_array) > 1000:
            return True
        # Heuristic 2: most features sitting exactly on a range boundary
        boundary_count = sum(1 for i, val in enumerate(features)
                             if i in self.feature_ranges and
                             (val == self.feature_ranges[i][0] or
                              val == self.feature_ranges[i][1]))
        if boundary_count > len(features) * 0.8:
            return True
        return False
API key for service-to-service, JWT for user-bound requests. Both for endpoints that handle either.
from functools import wraps
import jwt
from flask import request, jsonify

# validate_api_key, JWT_SECRET, and app are assumed to be defined elsewhere

def require_api_key(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        api_key = request.headers.get('X-API-Key')
        if not api_key:
            return jsonify({'error': 'API key required'}), 401
        if not validate_api_key(api_key):
            return jsonify({'error': 'Invalid API key'}), 403
        return f(*args, **kwargs)
    return decorated_function

def require_jwt(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        token = request.headers.get('Authorization', '').replace('Bearer ', '')
        try:
            # Pin the algorithm so a token cannot choose its own
            payload = jwt.decode(token, JWT_SECRET, algorithms=['HS256'])
            request.user = payload
        except jwt.ExpiredSignatureError:
            return jsonify({'error': 'Token expired'}), 401
        except jwt.InvalidTokenError:
            return jsonify({'error': 'Invalid token'}), 401
        return f(*args, **kwargs)
    return decorated_function

@app.route('/predict', methods=['POST'])
@require_api_key
@require_jwt
def predict():
    pass
Sliding window per client. The model is expensive; protect it.
import time
from collections import defaultdict, deque

class RateLimiter:
    def __init__(self, max_requests: int = 100, time_window: int = 60):
        self.max_requests = max_requests
        self.time_window = time_window
        self.requests = defaultdict(deque)

    def is_allowed(self, client_id: str) -> bool:
        now = time.time()
        client_requests = self.requests[client_id]
        # Drop timestamps that have slid out of the window
        while client_requests and client_requests[0] < now - self.time_window:
            client_requests.popleft()
        if len(client_requests) < self.max_requests:
            client_requests.append(now)
            return True
        return False

rate_limiter = RateLimiter(max_requests=100, time_window=60)

@app.before_request
def check_rate_limit():
    client_id = request.remote_addr
    if not rate_limiter.is_allowed(client_id):
        return jsonify({'error': 'Rate limit exceeded'}), 429
50K+ predictions per day with a sub-100ms SLA. Initial deploy had memory leaks (model loaded per worker) and inconsistent latency (cold database connections).
Three changes:
…functools.lru_cache).
Result: P95 latency 85ms (target <100ms), memory down 60%, uptime 99.95% (up from 97.2%).
from functools import lru_cache

import joblib

@lru_cache(maxsize=1)
def get_model():
    """Lazy-load the model once and cache it for the lifetime of this process"""
    return joblib.load('/models/lead_scoring_v2.pkl')

class SharedModelServer:
    def __init__(self):
        self.model = get_model()
        # create_db_pool is an application helper defined elsewhere
        self.connection_pool = create_db_pool(max_connections=20)

    def predict(self, features):
        # Probability of the positive class for a single feature row
        return self.model.predict_proba([features])[0][1]
Collaborative filtering for product recs. SLA was 10ms — too tight for live model inference, so the architecture moved the work elsewhere.
import numpy as np
import onnxruntime as ort

class ONNXRecommendationModel:
    def __init__(self, model_path):
        self.session = ort.InferenceSession(
            model_path,
            providers=['CPUExecutionProvider']
        )
        self.input_name = self.session.get_inputs()[0].name

    def predict(self, user_embedding, item_embeddings):
        # Passing None as the output list returns every model output; take the first
        scores = self.session.run(
            None,
            {self.input_name: np.hstack([user_embedding, item_embeddings])}
        )[0]
        return scores
Object detection on 1M+ images per day, with a cost ceiling.
Infrastructure cost dropped 45%. Accuracy dropped from 94.2% to 94.1% — a trade the business was happy to make.
The pattern across all three: nothing exotic. Standard operational practices applied to a model the same way you’d apply them to any production service. ML deployment isn’t a separate discipline; it’s regular deployment with a few extra failure modes — drift, version skew, the temptation to treat the model as magic — that you plan for from day one.