ML deployment: a working reference for getting models into production

September 10, 2024

Training a model and serving a model are different problems. The first one is mostly captured in libraries; the second one is mostly captured in operational practice. This piece is the reference I wish I’d had when I first put models in front of real traffic — serving patterns, containerization, orchestration, monitoring, drift, optimization, security, and three case studies from production systems I built.

It is not a guide to writing a model. It is a guide to keeping one alive after you deploy it.

What makes ML deployment different

A trained model is just a file. The hard problems are everything around it:

  • Input distributions drift. A model that was 92% accurate in July can be 78% accurate in November because the world changed, not because the model did.
  • Multiple versions live in parallel. Production traffic, canary traffic, the rollback path, the training baseline. Each needs its own provenance.
  • Inference runs under latency budgets that training never faces. Real-time endpoints typically get under 100ms; recommendation engines often get far less (the one in the case studies below had a 10ms SLA).
  • A/B testing requires statistical machinery. Two models with similar accuracy may differ on tails that matter.
  • Compute is expensive. Especially at peak load. Autoscaling that works for stateless web servers needs adjustment for stateful, memory-heavy model containers.

The rest of this piece walks each of those, in the order you’ll hit them.

Serving patterns

REST: the default, and the right default most of the time

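from datetime import datetime, timezone

from flask import Flask, request, jsonify
import joblib
import numpy as np

app = Flask(__name__)
model = joblib.load('model.pkl')  # load once at startup, never per request
MODEL_VERSION = '1.2.3'

@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json(silent=True) or {}
    if 'features' not in data:
        return jsonify({'error': "request body must include 'features'"}), 400

    # scikit-learn style models expect a 2D array: one row per sample
    features = np.array(data['features'], dtype=float).reshape(1, -1)
    prediction = model.predict(features)

    return jsonify({
        'prediction': prediction.tolist(),
        'model_version': MODEL_VERSION,
        'timestamp': datetime.now(timezone.utc).isoformat()
    })

@app.route('/health', methods=['GET'])
def health():
    # Target for the container HEALTHCHECK and the Kubernetes probes
    return jsonify({'status': 'ok'})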

Trade-off: HTTP overhead adds 5–20ms of latency. For most internal services and batch consumers, that’s fine. Reach for gRPC or streaming only when the latency budget says you must.

gRPC: when the latency budget says you must

syntax = "proto3";

service ModelService {
  rpc Predict(PredictRequest) returns (PredictResponse);
  rpc BatchPredict(BatchPredictRequest) returns (BatchPredictResponse);
}

message PredictRequest {
  repeated float features = 1;
  string model_version = 2;
}

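message PredictResponse {
  repeated float predictions = 1;
  float confidence = 2;
  int64 latency_ms = 3;
}

message BatchPredictRequest {
  repeated PredictRequest requests = 1;
}

message BatchPredictResponse {
  repeated PredictResponse responses = 1;
}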

The payoff is 50–90% lower wire latency than REST, from binary Protobuf encoding over persistent HTTP/2 connections, plus strong typing. The cost is the proto schema, language-specific clients, and tooling that's less universal than curl.

Streaming: continuous inference on event flows

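import json
import time

import joblib
from kafka import KafkaConsumer, KafkaProducer

model = joblib.load('model.pkl')

# Assumption: a broker at localhost:9092; point bootstrap_servers at your cluster
consumer = KafkaConsumer(
    'feature_stream',
    bootstrap_servers='localhost:9092',
    group_id='model-scorer',
    value_deserializer=lambda v: json.loads(v),
)
# Producers take no topic at construction; the topic is passed to send()
producer = KafkaProducer(
    bootstrap_servers='localhost:9092',
    value_serializer=lambda v: json.dumps(v).encode('utf-8'),
)

for message in consumer:
    features = message.value
    prediction = model.predict([features['data']])

    result = {
        'prediction': prediction.tolist(),
        'input_id': features['id'],
        'timestamp': time.time()
    }

    producer.send('prediction_stream', result)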

Use when predictions are an output of an event pipeline rather than a request/response — fraud scoring, sensor telemetry, real-time personalization.

Containerization

A production Dockerfile for an ML service does five things: pins the runtime, installs build tools only for the build, drops privileges, declares a healthcheck, and runs a real WSGI server (not the Flask dev server).

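FROM python:3.9-slim

# Create non-root user
RUN useradd --create-home --shell /bin/bash mluser

WORKDIR /app

# Install build tools, compile dependencies, then purge the tools in the same
# layer so they never reach the final image. gunicorn must be in requirements.txt.
COPY requirements.txt .
RUN apt-get update && apt-get install -y --no-install-recommends gcc g++ \
    && pip install --no-cache-dir -r requirements.txt \
    && apt-get purge -y --auto-remove gcc g++ \
    && rm -rf /var/lib/apt/lists/*

COPY --chown=mluser:mluser . .

USER mluser

# python:3.9-slim ships without curl, so probe with the standard library instead
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health', timeout=5)" || exit 1

EXPOSE 8000

CMD ["gunicorn", "--bind", "0.0.0.0:8000", "--workers", "4", "app:app"]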

Multi-stage for big models

Training images carry build tools and training data. Serving images shouldn’t. Split them:

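# Training stage: full image; assumes training/ contains its own requirements.txt
FROM python:3.9 AS trainer
WORKDIR /app
COPY training/ .
RUN pip install --no-cache-dir -r requirements.txt
RUN python train_model.py --output /models/

# Serving stage: slim image with only the model artifact and serving code
FROM python:3.9-slim AS serving
WORKDIR /app
COPY --from=trainer /models/ /app/models/
COPY requirements-serving.txt .
RUN pip install --no-cache-dir -r requirements-serving.txt
COPY serve.py .
CMD ["python", "serve.py"]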

The serving image ships only what’s needed to load the model and answer requests.

Kubernetes

The deployment manifest below is the minimum I’d ship: resource requests and limits, readiness and liveness probes (different intervals — readiness fast, liveness slow), and version labels for traffic routing.

apiVersion: apps/v1
kind: Deployment
metadata:
  name: ml-model-serving
  labels:
    app: ml-model
    version: v1.2.3
spec:
  replicas: 3
  selector:
    matchLabels:
      app: ml-model
  template:
    metadata:
      labels:
        app: ml-model
        version: v1.2.3
    spec:
      containers:
      - name: model-server
        image: your-registry/ml-model:v1.2.3
        ports:
        - containerPort: 8000
        resources:
          requests:
            cpu: 500m
            memory: 1Gi
          limits:
            cpu: 2000m
            memory: 4Gi
        readinessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 10
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 60
          periodSeconds: 30
        env:
        - name: MODEL_VERSION
          value: "v1.2.3"
        - name: LOG_LEVEL
          value: "INFO"

Autoscaling

For most model servers, CPU and memory are the right signals. GPU servers want custom metrics — queue length or in-flight requests.

apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: ml-model-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: ml-model-serving
  minReplicas: 2
  maxReplicas: 20
  metrics:
  - type: Resource
    resource:
      name: cpu
      target:
        type: Utilization
        averageUtilization: 70
  - type: Resource
    resource:
      name: memory
      target:
        type: Utilization
        averageUtilization: 80
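The Kubernetes documentation describes the core HPA rule as desiredReplicas = ceil(currentReplicas * currentMetricValue / desiredMetricValue), with no change when the ratio is within a default 10% tolerance, and with the largest proposal winning when several metrics are configured. A minimal sketch of that arithmetic for the manifest above; it ignores stabilization windows, pod readiness, and missing metrics, which the real controller also handles, and `desired_replicas` is an illustrative name:

```python
import math
from typing import Sequence, Tuple


def desired_replicas(current_replicas: int,
                     metrics: Sequence[Tuple[float, float]],
                     min_replicas: int = 2,
                     max_replicas: int = 20,
                     tolerance: float = 0.1) -> int:
    """Approximate the HPA scaling rule.

    metrics holds (current_utilization, target_utilization) pairs, e.g. CPU and memory.
    """
    proposals = []
    for current_value, target_value in metrics:
        ratio = current_value / target_value
        if abs(ratio - 1.0) <= tolerance:
            proposals.append(current_replicas)  # within tolerance: no change
        else:
            proposals.append(math.ceil(current_replicas * ratio))
    # With several metrics, the largest proposal wins
    desired = max(proposals) if proposals else current_replicas
    return max(min_replicas, min(max_replicas, desired))
```

For example, 3 replicas at 140% CPU against the 70% target propose 6, memory at 60% against 80% proposes 3, and the HPA picks 6. Because the maximum wins, a memory metric that stays high simply because the model is resident in every pod can hold the replica count up even when CPU is idle; worth checking before relying on the 80% memory target.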

Versioning and canary rollouts

Semantic versioning, for models

  • Major: input/output schema changes. Callers need to update.
  • Minor: new features or improved accuracy. Backward compatible.
  • Patch: bug fixes, performance, no behavior change.
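
from dataclasses import dataclass

# order=True compares (major, minor, patch) left to right, like a tuple
@dataclass(frozen=True, order=True)
class ModelVersion:
    major: int
    minor: int
    patch: int

    def __str__(self):
        return f"{self.major}.{self.minor}.{self.patch}"

    def is_compatible(self, other):
        # Same major version means the input/output schema is unchanged
        return self.major == other.major

current_version = ModelVersion(2, 1, 3)
new_version = ModelVersion(2, 2, 0)

# deploy_model is a placeholder for your own rollout step
if current_version.is_compatible(new_version) and new_version > current_version:
    deploy_model(new_version)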
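The bump rules in the list above are mechanical enough to encode. A small sketch, assuming version labels of the form `v1.2.3` as in the Kubernetes manifest; `parse_version`, `bump`, and the change names `schema`, `behavior`, and `fix` are hypothetical:

```python
import re
from typing import Tuple

Version = Tuple[int, int, int]

# MAJOR.MINOR.PATCH with an optional leading "v", as in the v1.2.3 labels
_VERSION_RE = re.compile(r"v?(\d+)\.(\d+)\.(\d+)")


def parse_version(label: str) -> Version:
    match = _VERSION_RE.fullmatch(label)
    if match is None:
        raise ValueError(f"not a MAJOR.MINOR.PATCH version: {label!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return (major, minor, patch)


def bump(version: Version, change: str) -> Version:
    """Return the next version; a higher-level bump resets the fields below it."""
    major, minor, patch = version
    if change == "schema":    # input/output schema changed: callers must update
        return (major + 1, 0, 0)
    if change == "behavior":  # new features or accuracy change, backward compatible
        return (major, minor + 1, 0)
    if change == "fix":       # bug fix or performance, no behavior change
        return (major, minor, patch + 1)
    raise ValueError(f"unknown change type: {change!r}")
```

Under these rules the 2.1.3 to 2.2.0 upgrade in the example above is a `behavior` change, so it stays within the same major version and passes `is_compatible`.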

Canary deployments with Istio

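10% of traffic goes to v2 and the rest to v1. Requests carrying the header canary: true always route to v2, which lets you force a single client onto the canary for testing. The v1 and v2 subsets are not defined here: they come from a DestinationRule that maps each subset to the pods' version label.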

apiVersion: networking.istio.io/v1beta1
kind: VirtualService
metadata:
  name: ml-model-route
spec:
  hosts:
  - ml-model-serving
  http:
  - match:
    - headers:
        canary:
          exact: "true"
    route:
    - destination:
        host: ml-model-serving
        subset: v2
  - route:
    - destination:
        host: ml-model-serving
        subset: v1
      weight: 90
    - destination:
        host: ml-model-serving
        subset: v2
      weight: 10
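To reason about what clients will see, it can help to model the route table in code. This is a sketch of the routing semantics only, not of Istio itself: the first matching rule wins, so a request with `canary: true` always lands on v2, and everything else is split by weight independently on each request, which means a client without the header can hit both versions across calls. `choose_subset` is illustrative, and it assumes header names arrive lowercased.

```python
import random
from typing import Dict, Mapping, Optional

DEFAULT_WEIGHTS: Dict[str, int] = {"v1": 90, "v2": 10}


def choose_subset(headers: Mapping[str, str],
                  weights: Optional[Mapping[str, int]] = None,
                  rng: Optional[random.Random] = None) -> str:
    """Mimic the VirtualService above: header match first, then a weighted split."""
    # First rule: an exact header match pins the request to the canary
    if headers.get("canary") == "true":
        return "v2"
    # Second rule: pick a subset at random in proportion to its weight
    weights = weights or DEFAULT_WEIGHTS
    rng = rng or random.Random()
    subsets = list(weights)
    return rng.choices(subsets, weights=[weights[s] for s in subsets], k=1)[0]
```

Over many requests without the header, roughly one in ten should land on v2.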

A/B testing with statistical significance

Don’t compare two models by eyeballing accuracy. Test it.

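from dataclasses import dataclass
from typing import List

import numpy as np
import scipy.stats as stats

@dataclass
class ModelMetrics:
    predictions: List[float]
    ground_truth: List[float]
    response_times: List[float]

def calculate_accuracy(predictions, ground_truth) -> float:
    # Assumes class labels, so exact equality is the right notion of "correct"
    return float(np.mean(np.asarray(predictions) == np.asarray(ground_truth)))

def compare_models(model_a: ModelMetrics, model_b: ModelMetrics, alpha: float = 0.05) -> dict:
    """Test accuracy and latency separately; one test's p-value must not decide the other."""
    n_a, n_b = len(model_a.ground_truth), len(model_b.ground_truth)
    accuracy_a = calculate_accuracy(model_a.predictions, model_a.ground_truth)
    accuracy_b = calculate_accuracy(model_b.predictions, model_b.ground_truth)

    # Two-proportion z-test on accuracy
    pooled = (accuracy_a * n_a + accuracy_b * n_b) / (n_a + n_b)
    se = np.sqrt(pooled * (1 - pooled) * (1 / n_a + 1 / n_b))
    z = (accuracy_b - accuracy_a) / se if se > 0 else 0.0
    accuracy_p_value = float(2 * stats.norm.sf(abs(z)))

    # Welch's t-test on latency (no equal-variance assumption)
    _, latency_p_value = stats.ttest_ind(
        model_a.response_times, model_b.response_times, equal_var=False
    )
    latency_regression = (np.mean(model_b.response_times) > np.mean(model_a.response_times)
                          and latency_p_value < alpha)

    deploy_b = accuracy_b > accuracy_a and accuracy_p_value < alpha and not latency_regression

    return {
        'accuracy_improvement': accuracy_b - accuracy_a,
        'accuracy_p_value': accuracy_p_value,
        'latency_p_value': float(latency_p_value),
        'significant_latency_regression': bool(latency_regression),
        'recommendation': 'deploy_b' if deploy_b else 'keep_a'
    }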
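Before running the comparison, it helps to know how many labeled predictions each model needs for the test to have a fair chance of detecting a real difference. A sketch using the standard normal-approximation sample-size formula for comparing two proportions; the baseline accuracy and the lift worth detecting are inputs you choose, and `samples_per_arm` is a hypothetical name:

```python
import math
from statistics import NormalDist


def samples_per_arm(p_a: float, p_b: float, alpha: float = 0.05, power: float = 0.8) -> int:
    """Labeled predictions needed per model to detect accuracy p_a vs p_b (two-sided)."""
    z_alpha = NormalDist().inv_cdf(1 - alpha / 2)  # critical value for the significance level
    z_beta = NormalDist().inv_cdf(power)           # quantile for the desired power
    p_bar = (p_a + p_b) / 2
    numerator = (z_alpha * math.sqrt(2 * p_bar * (1 - p_bar))
                 + z_beta * math.sqrt(p_a * (1 - p_a) + p_b * (1 - p_b))) ** 2
    return math.ceil(numerator / (p_a - p_b) ** 2)
```

Detecting a one-point lift from 90% to 91% accuracy at the usual 0.05 significance and 80% power needs roughly 13,500 labeled predictions per model, which is why small accuracy differences on a few hundred requests rarely mean anything.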

Monitoring

System metrics and ML metrics are different things. Track both.

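import time
from functools import wraps

from prometheus_client import Counter, Histogram, Gauge, start_http_server

REQUEST_COUNT = Counter('ml_requests_total', 'Total ML prediction requests')
REQUEST_LATENCY = Histogram('ml_request_duration_seconds', 'ML prediction latency')
MODEL_ACCURACY = Gauge('ml_model_accuracy', 'Current model accuracy')
PREDICTION_DRIFT = Gauge('ml_prediction_drift', 'Distribution drift score')

def track_prediction(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        REQUEST_COUNT.inc()
        start_time = time.perf_counter()  # monotonic, unaffected by clock changes

        try:
            return func(*args, **kwargs)
        finally:
            REQUEST_LATENCY.observe(time.perf_counter() - start_time)

    return wrapper

@track_prediction
def predict(features):
    return model.predict(features)  # model: loaded once at startup, as in the REST example

# Expose /metrics for Prometheus to scrape (port is an assumption; it must differ
# from the app's 8000). Under multiple gunicorn workers, use prometheus_client's
# multiprocess mode instead.
start_http_server(8001)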
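The `ml_model_accuracy` gauge needs ground truth, which usually arrives well after the prediction (for example, whether a scored lead later converts). A minimal sketch, assuming you can join delayed labels back to their predictions by request ID: keep a fixed window of recent outcomes and publish their mean. `RollingAccuracy` and the 1,000-prediction window are illustrative choices.

```python
from collections import deque
from typing import Any, Optional


class RollingAccuracy:
    """Accuracy over the most recent `window` labeled predictions."""

    def __init__(self, window: int = 1000):
        # True means correct; deque(maxlen=...) drops the oldest outcome automatically
        self.outcomes = deque(maxlen=window)

    def record(self, prediction: Any, label: Any) -> None:
        self.outcomes.append(prediction == label)

    def value(self) -> Optional[float]:
        if not self.outcomes:
            return None  # nothing labeled yet: better to publish nothing than 0.0
        return sum(self.outcomes) / len(self.outcomes)
```

When a label arrives, call `tracker.record(prediction, label)` and then `MODEL_ACCURACY.set(tracker.value())` if the value is not None.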

Drift detection

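Kolmogorov-Smirnov on each feature against a reference distribution. Per-feature p-values plus an any() aggregator. Set the threshold conservatively. The test runs on a window of recent inputs, not per request, but testing every feature at 0.05 still false-positives constantly: with 20 independent features, the chance of at least one spurious alarm per window is 1 - 0.95^20, about 64%. Dividing the threshold by the number of features (a Bonferroni correction) keeps that overall rate at or below 0.05.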

import numpy as np
from scipy.stats import ks_2samp

class DriftDetector:
    def __init__(self, reference_data: np.ndarray, threshold: float = 0.05):
        # No scaling step: the KS statistic is unchanged by any strictly increasing
        # transform applied to both samples, so standardizing first has no effect.
        self.reference_data = np.asarray(reference_data)
        self.threshold = threshold

    def detect_drift(self, new_data: np.ndarray) -> dict:
        """Detect drift per feature with the two-sample Kolmogorov-Smirnov test"""

        new_data = np.asarray(new_data)
        n_features = self.reference_data.shape[1]
        # Bonferroni correction: split the false-alarm budget across features
        per_feature_threshold = self.threshold / n_features

        drift_scores = []

        for feature_idx in range(n_features):
            statistic, p_value = ks_2samp(
                self.reference_data[:, feature_idx],
                new_data[:, feature_idx]
            )
            drift_scores.append({
                'feature': feature_idx,
                'ks_statistic': float(statistic),
                'p_value': float(p_value),
                'drift_detected': p_value < per_feature_threshold
            })

        overall_drift = any(score['drift_detected'] for score in drift_scores)

        return {
            'overall_drift': overall_drift,
            # Largest per-feature KS statistic, a natural value for the drift gauge
            'max_ks_statistic': max(score['ks_statistic'] for score in drift_scores),
            'feature_scores': drift_scores,
            'recommendation': 'retrain_model' if overall_drift else 'continue_monitoring'
        }
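What the KS test measures is simple enough to compute by hand: the largest vertical gap between the two samples' empirical CDFs, from 0 (identical) to 1 (no overlap at all). A stdlib-only sketch; publishing the largest per-feature value as the `ml_prediction_drift` gauge is one way to feed the `DataDriftDetected` alert below, though the 0.1 cutoff there is a threshold to tune, not a standard. `ks_statistic` is a hypothetical name and assumes both samples are non-empty.

```python
from bisect import bisect_right
from typing import Sequence


def ks_statistic(reference: Sequence[float], current: Sequence[float]) -> float:
    """Two-sample KS statistic D: the largest gap between the two empirical CDFs."""
    ref, cur = sorted(reference), sorted(current)
    d = 0.0
    # Both ECDFs are step functions, so the gap can only peak at an observed value
    for x in ref + cur:
        cdf_ref = bisect_right(ref, x) / len(ref)  # fraction of reference values <= x
        cdf_cur = bisect_right(cur, x) / len(cur)  # fraction of current values <= x
        d = max(d, abs(cdf_ref - cdf_cur))
    return d
```

Shifting half of a four-point sample out of range, as in `[1, 2, 3, 4]` against `[3, 4, 5, 6]`, gives D = 0.5; scipy's `ks_2samp` reports the same statistic.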

Alerts

groups:
- name: ml-model-alerts
  rules:
  - alert: ModelAccuracyDrop
    expr: ml_model_accuracy < 0.85
    for: 5m
    labels:
      severity: warning
    annotations:
      summary: "Model accuracy below threshold"
      description: "Model accuracy has dropped to {{ $value }}"

  - alert: HighPredictionLatency
    expr: histogram_quantile(0.95, sum by (le) (rate(ml_request_duration_seconds_bucket[5m]))) > 0.5
    for: 2m
    labels:
      severity: critical
    annotations:
      summary: "High prediction latency detected"
      description: "95th percentile latency: {{ $value }}s"

  - alert: DataDriftDetected
    expr: ml_prediction_drift > 0.1
    for: 10m
    labels:
      severity: warning
    annotations:
      summary: "Data drift detected"
      description: "Input distribution has changed significantly"

Optimization

Quantization

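Quantization stores weights (and optionally activations) as 8-bit integers instead of 32-bit floats. That cuts weight storage about 4x and usually speeds up CPU inference, at the cost of some accuracy. Worth measuring the trade on your own validation set.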

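import torch
import torch.quantization as quantization

def quantize_model_dynamic(model):
    """Dynamic quantization: int8 weights for Linear layers, activations quantized at
    runtime. Needs no calibration data and works on an unmodified model."""
    model.eval()
    return quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

def quantize_model(model, representative_data):
    """Static post-training quantization, calibrated on representative inputs.
    The model must already wrap its forward pass in QuantStub/DeQuantStub."""

    model.eval()
    model.qconfig = quantization.get_default_qconfig('fbgemm')  # x86 server backend
    model_prepared = quantization.prepare(model, inplace=False)

    # Observers record activation ranges while calibration data flows through
    with torch.no_grad():
        for data in representative_data:
            model_prepared(data)

    return quantization.convert(model_prepared, inplace=False)

def quantize_keras_model(model):
    """Post-training dynamic-range quantization via TFLite; returns serialized bytes."""
    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    quantized_model = converter.convert()  # write these bytes to a .tflite file

    return quantized_model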
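The arithmetic behind int8 quantization fits in a few lines. A sketch of symmetric per-tensor quantization, the simplest scheme (PyTorch and TFLite also support per-channel scales and zero points, which this ignores): each weight becomes one signed byte plus a shared float scale, and the rounding error per weight is at most half the scale. The function names are illustrative.

```python
from typing import List, Tuple


def quantize_int8(values: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor int8: one float scale, integer codes in [-127, 127]."""
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / 127 if max_abs > 0 else 1.0  # all-zero tensor: any scale works
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize_int8(codes: List[int], scale: float) -> List[float]:
    """Map integer codes back to approximate float values."""
    return [q * scale for q in codes]
```

With weights `[0.42, -1.27, 0.003, 0.9]` the scale is 0.01 and the codes are `[42, -127, 0, 90]`; the small 0.003 weight rounds to zero, which is exactly the kind of loss worth measuring.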

Pruning

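Drop the smallest-magnitude weights. 30% is a reasonable starting point for most fully-connected layers, but re-check accuracy afterwards. Unstructured pruning only zeroes weights; the tensors stay dense, so size and speed gains appear only with sparse storage or sparse kernels.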

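import torch
import torch.nn.utils.prune as prune

def prune_model(model, pruning_amount=0.3):
    """Apply magnitude-based (L1) unstructured pruning to every Linear layer"""

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # Mask the pruning_amount fraction of weights with the smallest |w|
            prune.l1_unstructured(module, name='weight', amount=pruning_amount)
            # Make it permanent: drop the mask, keep the zeroed weight tensor
            prune.remove(module, 'weight')

    return model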
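What `l1_unstructured` does to a single tensor is easy to state in plain Python: rank entries by absolute value and zero the smallest fraction. A sketch on a flat list (PyTorch does this on tensors through a mask, which `prune.remove` then bakes into the weight); `magnitude_prune` is a hypothetical name, and ties are broken by position.

```python
from typing import List


def magnitude_prune(weights: List[float], amount: float = 0.3) -> List[float]:
    """Zero the `amount` fraction of weights with the smallest absolute value."""
    n_prune = round(amount * len(weights))
    # Indices sorted by |w|; the first n_prune are the ones to drop
    smallest = sorted(range(len(weights)), key=lambda i: abs(weights[i]))[:n_prune]
    pruned = list(weights)
    for i in smallest:
        pruned[i] = 0.0
    return pruned
```

On ten weights with the default 30%, the three closest to zero are dropped and the large-magnitude weights, which carry most of the layer's output, survive untouched.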

Caching predictions

If the same input shows up repeatedly — and in many real services it does — cache the answer.

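import hashlib
import json
from functools import wraps

import redis

class PredictionCache:
    def __init__(self, redis_url: str, model_version: str, ttl: int = 3600):
        self.redis_client = redis.from_url(redis_url)
        self.model_version = model_version
        self.ttl = ttl

    def cache_key(self, features: list) -> str:
        """Generate a consistent cache key from features.

        Feature order is meaningful, so the list is serialized as-is (never sorted:
        [1, 2] and [2, 1] are different inputs). The model version is part of the
        key so a new deploy never serves the old model's predictions.
        """
        feature_str = json.dumps(features, separators=(',', ':'))
        digest = hashlib.sha256(feature_str.encode()).hexdigest()
        return f"pred:{self.model_version}:{digest}"

    def get(self, features: list):
        cached = self.redis_client.get(self.cache_key(features))

        # JSON rather than pickle: unpickling bytes from a shared store can execute code
        if cached is not None:
            return json.loads(cached)
        return None

    def set(self, features: list, prediction):
        # prediction must be JSON-serializable (call .tolist() on NumPy output)
        self.redis_client.setex(
            self.cache_key(features),
            self.ttl,
            json.dumps(prediction)
        )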

def cached_prediction(cache: PredictionCache):
    def decorator(predict_func):
        @wraps(predict_func)
        def wrapper(features):
            cached_result = cache.get(features)
            if cached_result is not None:
                return cached_result

            result = predict_func(features)
            cache.set(features, result)

            return result
        return wrapper
    return decorator

Batch processing

For throughput-bound workloads, accumulating a few requests and running them as a batch beats serving them one-at-a-time — especially on GPU. The trade is added tail latency.

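import asyncio
import time
from collections import deque

class BatchProcessor:
    def __init__(self, model, max_batch_size=32, max_wait_time=0.1):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time  # seconds a batch waits to fill
        self.batch_queue = deque()
        self.processing = False
        self._tasks = set()  # strong references so pending tasks aren't garbage-collected

    def _schedule(self):
        task = asyncio.create_task(self._process_batch())
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def predict(self, features):
        """Add a prediction request to the batch queue and wait for its result"""
        future = asyncio.get_running_loop().create_future()
        self.batch_queue.append({
            'features': features,
            'future': future,
            'timestamp': time.time()
        })

        if not self.processing:
            self._schedule()

        return await future

    async def _process_batch(self):
        if self.processing:
            return

        self.processing = True

        try:
            # Let concurrent requests join this batch
            await asyncio.sleep(self.max_wait_time)

            batch_requests = []
            while self.batch_queue and len(batch_requests) < self.max_batch_size:
                batch_requests.append(self.batch_queue.popleft())

            if not batch_requests:
                return

            try:
                # A synchronous predict blocks the event loop while the batch runs
                predictions = self.model.predict([req['features'] for req in batch_requests])
            except Exception as exc:
                # Fail every waiting caller instead of leaving their futures pending forever
                for req in batch_requests:
                    if not req['future'].done():
                        req['future'].set_exception(exc)
                return

            for req, prediction in zip(batch_requests, predictions):
                if not req['future'].done():  # the caller may have been cancelled
                    req['future'].set_result(prediction)

        finally:
            self.processing = False

            # Requests that arrived while this batch ran start the next one
            if self.batch_queue:
                self._schedule()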

Security

Input validation

The model trusts whatever the API hands it. The API shouldn’t trust whatever the client hands it.

import math
from typing import Dict, List, Tuple

import numpy as np

class InputValidator:
    def __init__(self, feature_ranges: Dict[int, Tuple[float, float]], max_request_size: int = 1000):
        self.feature_ranges = feature_ranges  # feature index -> (min, max)
        self.max_request_size = max_request_size

    def validate_features(self, features: List[float]) -> Tuple[bool, str]:
        if len(features) > self.max_request_size:
            return False, f"Request too large: {len(features)} features"

        for i, value in enumerate(features):
            # bool is a subclass of int in Python, so reject it explicitly
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                return False, f"Invalid type for feature {i}: {type(value).__name__}"

            # Python ints are always finite; only floats can be NaN or inf
            if isinstance(value, float) and not math.isfinite(value):
                return False, f"Invalid value for feature {i}: {value}"

            if i in self.feature_ranges:
                min_val, max_val = self.feature_ranges[i]
                if not (min_val <= value <= max_val):
                    return False, f"Feature {i} out of range: {value}"

        return True, "Valid"

    def detect_adversarial_pattern(self, features: List[float]) -> bool:
        """Crude anomaly heuristics. They flag obviously odd inputs; they are not a
        defense against carefully crafted adversarial examples."""

        if not features:
            return False

        # The variance threshold is an assumption that depends on your feature scales
        if np.var(np.asarray(features, dtype=float)) > 1000:
            return True

        # Most features pinned exactly to a range boundary suggests probing
        boundary_count = sum(1 for i, val in enumerate(features)
                             if i in self.feature_ranges and
                             val in self.feature_ranges[i])

        if boundary_count > len(features) * 0.8:
            return True

        return False
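`InputValidator` needs `feature_ranges` from somewhere. One option, sketched here under the assumption that the training set is representative of legitimate traffic: take each feature's observed min and max and pad them by a margin, so slightly novel but valid inputs still pass. `ranges_from_training` and the 10% default margin are illustrative choices, not tuned values.

```python
from typing import Dict, Sequence, Tuple


def ranges_from_training(rows: Sequence[Sequence[float]],
                         margin: float = 0.1) -> Dict[int, Tuple[float, float]]:
    """Per-feature (min, max) from training rows, widened by `margin` of each span."""
    ranges = {}
    # zip(*rows) turns row-major training data into one tuple per feature column
    for i, column in enumerate(zip(*rows)):
        lo, hi = min(column), max(column)
        pad = (hi - lo) * margin
        ranges[i] = (lo - pad, hi + pad)
    return ranges
```

The result plugs straight in: `InputValidator(ranges_from_training(training_rows))`. Recompute it whenever the model is retrained, since the ranges describe the data the current model has seen.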

Auth

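API key for service-to-service, JWT for user-bound requests. Stack both decorators when an endpoint must verify both the calling service and the end user; stacked, they require both credentials, not either one.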

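import hmac
import os
from functools import wraps

import jwt  # PyJWT
from flask import Flask, g, jsonify, request

app = Flask(__name__)

# Assumption: secrets come from the environment (e.g. a Kubernetes Secret), never from
# source. A missing JWT_SECRET fails at startup rather than on the first request.
JWT_SECRET = os.environ['JWT_SECRET']
API_KEYS = {k.strip() for k in os.environ.get('API_KEYS', '').split(',') if k.strip()}

def validate_api_key(api_key: str) -> bool:
    # Constant-time comparison so response timing doesn't leak key contents
    return any(hmac.compare_digest(api_key.encode(), known.encode()) for known in API_KEYS)

def require_api_key(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        api_key = request.headers.get('X-API-Key')

        if not api_key:
            return jsonify({'error': 'API key required'}), 401

        if not validate_api_key(api_key):
            return jsonify({'error': 'Invalid API key'}), 403

        return f(*args, **kwargs)
    return decorated_function

def require_jwt(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        auth_header = request.headers.get('Authorization', '')
        if not auth_header.startswith('Bearer '):
            return jsonify({'error': 'Bearer token required'}), 401
        token = auth_header[len('Bearer '):]

        try:
            # Store claims on flask.g, the per-request scratch space
            g.user = jwt.decode(token, JWT_SECRET, algorithms=['HS256'])
        except jwt.ExpiredSignatureError:
            return jsonify({'error': 'Token expired'}), 401
        except jwt.InvalidTokenError:
            return jsonify({'error': 'Invalid token'}), 401

        return f(*args, **kwargs)
    return decorated_function

@app.route('/predict', methods=['POST'])
@require_api_key
@require_jwt
def predict():
    # Placeholder: input validation and inference go here; g.user holds the JWT claims
    return jsonify({'caller': g.user.get('sub')})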

Rate limiting

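Sliding window per client. The model is expensive; protect it. One caveat: this limiter keeps its state in process memory, so with 4 gunicorn workers in each of 3 replicas, a client can get up to 12 times the configured limit. For a hard cap, keep the counters in a shared store such as Redis.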

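import threading
import time
from collections import defaultdict, deque

from flask import Flask, jsonify, request

app = Flask(__name__)

class RateLimiter:
    """Sliding-window log limiter. State lives in this process only."""

    def __init__(self, max_requests: int = 100, time_window: int = 60):
        self.max_requests = max_requests
        self.time_window = time_window
        self.requests = defaultdict(deque)
        self.lock = threading.Lock()  # threaded servers call is_allowed concurrently

    def is_allowed(self, client_id: str) -> bool:
        now = time.monotonic()  # immune to wall-clock adjustments

        with self.lock:
            client_requests = self.requests[client_id]

            # Evict timestamps that have slid out of the window
            while client_requests and client_requests[0] < now - self.time_window:
                client_requests.popleft()

            if len(client_requests) < self.max_requests:
                client_requests.append(now)
                return True

            return False

rate_limiter = RateLimiter(max_requests=100, time_window=60)

@app.before_request
def check_rate_limit():
    # Behind a load balancer or mesh, remote_addr is the proxy's address; prefer a
    # trusted client identifier such as the validated API key
    client_id = request.remote_addr
    if not rate_limiter.is_allowed(client_id):
        return jsonify({'error': 'Rate limit exceeded'}), 429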

Three production examples

Lead scoring at BTL Industries

50K+ predictions per day with a sub-100ms SLA. Initial deploy had memory leaks (model loaded per worker) and inconsistent latency (cold database connections).

Three changes:

  • Load the model once and share it across workers: functools.lru_cache caches it per process, and loading it before the workers fork lets them share that one copy.
  • Connection pool the database calls.
  • Batch similar requests on GPU when scoring leads in bulk.

Result: P95 latency 85ms (target <100ms), memory down 60%, uptime 99.95% (from 97.2%).

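from functools import lru_cache

import joblib

@lru_cache(maxsize=1)
def get_model():
    """Lazy-load the model once per process; later calls return the cached object.

    lru_cache does not share memory across processes. For gunicorn workers to share
    one copy, call get_model() before the workers fork (e.g. gunicorn --preload).
    """
    return joblib.load('/models/lead_scoring_v2.pkl')

class SharedModelServer:
    def __init__(self):
        self.model = get_model()
        # create_db_pool is a placeholder for your database driver's connection pool
        self.connection_pool = create_db_pool(max_connections=20)

    def predict(self, features):
        # Probability of the positive class (column 1) for a single lead
        return self.model.predict_proba([features])[0][1]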

Real-time recommendations

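Collaborative filtering for product recs. The SLA was 10ms, too tight for full model inference on the request path, so the architecture moved the heavy work offline: embeddings are precomputed, and each request only does lookups and a light scoring pass.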

  • Serving: gRPC, persistent connections.
  • Feature store: Redis cluster holding precomputed user/item embeddings.
  • Format: ONNX, for portability and faster inference than the original framework.
  • Search: Faiss for approximate-nearest-neighbor lookups.
  • Cold start: cache top-N popular items as a fallback.
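
import numpy as np
import onnxruntime as ort

class ONNXRecommendationModel:
    def __init__(self, model_path):
        self.session = ort.InferenceSession(
            model_path,
            providers=['CPUExecutionProvider']
        )
        self.input_name = self.session.get_inputs()[0].name

    def predict(self, user_embedding, item_embeddings):
        # Assumes the exported model scores [user | item] concatenations, one row per
        # candidate item; ONNX Runtime expects float32 input
        item_embeddings = np.asarray(item_embeddings, dtype=np.float32)
        user_rows = np.tile(np.asarray(user_embedding, dtype=np.float32),
                            (item_embeddings.shape[0], 1))
        scores = self.session.run(
            None,
            {self.input_name: np.hstack([user_rows, item_embeddings])}
        )[0]
        return scores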
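With embeddings precomputed, the request path reduces to two steps: score candidates against the user's embedding, and fall back to popular items when a user has no embedding yet. A brute-force, stdlib-only sketch of that logic; Faiss replaces the exhaustive dot-product loop with an approximate index at scale, and `recommend`, the dict-based store, and the item IDs are hypothetical.

```python
import heapq
from typing import Dict, List, Optional, Sequence


def recommend(user_embedding: Optional[Sequence[float]],
              item_embeddings: Dict[str, Sequence[float]],
              popular_items: List[str],
              n: int = 10) -> List[str]:
    """Top-n items by dot-product score, or the popular list on cold start."""
    if user_embedding is None:
        # Cold start: no embedding for this user yet, serve the cached popular items
        return popular_items[:n]
    scores = {
        item: sum(u * v for u, v in zip(user_embedding, emb))
        for item, emb in item_embeddings.items()
    }
    # nlargest avoids sorting every candidate when n is small
    return heapq.nlargest(n, scores, key=scores.get)
```

The exhaustive loop is O(items x dimensions) per request, which is exactly the cost an ANN index exists to avoid once the catalog is large.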

Computer vision at scale

Object detection on 1M+ images per day, with a cost ceiling.

  • Kubernetes HPA on queue length, not CPU.
  • TensorRT quantization: inference time 200ms → 60ms.
  • Perceptual hashing on inputs to cache results for near-duplicate images.

Infrastructure cost dropped 45%. Accuracy dropped from 94.2% to 94.1% — a trade the business was happy to make.
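Perceptual hashing works because visually similar images produce hashes that differ in only a few bits, unlike cryptographic hashes, where one changed pixel changes everything. A sketch of average hashing (aHash), one of the simplest perceptual hashes and not necessarily the variant used in this system; it assumes the image is already grayscale and downscaled to 8x8 (an image library would do that step), and the 5-bit threshold and linear scan are illustrative. A production cache would index hashes rather than scan them.

```python
from typing import Dict, Optional, Sequence


def average_hash(gray_8x8: Sequence[Sequence[float]]) -> int:
    """aHash: 64 bits, one per pixel, set where the pixel is brighter than the mean."""
    pixels = [p for row in gray_8x8 for p in row]
    mean = sum(pixels) / len(pixels)
    bits = 0
    for p in pixels:
        bits = (bits << 1) | (1 if p > mean else 0)
    return bits


def hamming(a: int, b: int) -> int:
    """Number of differing bits between two hashes."""
    return bin(a ^ b).count("1")


def lookup_near_duplicate(image_hash: int, cache: Dict[int, object],
                          max_distance: int = 5) -> Optional[object]:
    """Return a cached result whose hash is within max_distance bits, if any."""
    for cached_hash, result in cache.items():
        if hamming(image_hash, cached_hash) <= max_distance:
            return result
    return None
```

A uniformly brightened copy of an image hashes identically, because every pixel and the mean shift together, so it hits the cache; an inverted image flips every bit and misses.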


The pattern across all three: nothing exotic. Standard operational practices applied to a model the same way you’d apply them to any production service. ML deployment isn’t a separate discipline; it’s regular deployment with a few extra failure modes — drift, version skew, the temptation to treat the model as magic — that you plan for from day one.