Conservative Q-Learning for Safe Sepsis Treatment: Learning Optimal ICU Policies from Historical Data

How offline reinforcement learning can save lives by learning safer treatment strategies without experimentation on patients.

January 18, 2026 12 min read

Introduction

Every year, sepsis kills up to 21 million people worldwide, more than all cancers combined. For patients in an Intensive Care Unit (ICU), the difference between survival and death often hinges on subtle, moment-to-moment treatment decisions, such as how much intravenous (IV) fluid to administer, and when to use vasopressors to raise blood pressure. Sepsis is the number one cause of death and readmissions in hospitals.

This is a sequential decision-making problem, which Reinforcement Learning (RL) excels at. Many RL methods are online, where they learn by exploring, i.e. trying random actions and observing the consequences. This works great for video games. However, for a patient on the edge of organ failure, randomly trying treatment options is not ideal.

What if we could learn an optimal treatment policy from historical patient data, without ever risking a single new patient? This is the promise of offline reinforcement learning.

In this tutorial, I discuss my implementation of Conservative Q-Learning (CQL), an influential offline RL algorithm, applied to the 2024 ICU-Sepsis benchmark environment. My results show that CQL achieves an 86.0% survival rate, outperforming both traditional Deep Q-Networks (82.0%) and a policy that simply imitates human clinicians (84.0%). This article explains the theory, shows the code, and analyses the results to demonstrate how we can build AI systems for healthcare that are not just effective, but safe.

Conservative Q-Learning for Sepsis Treatment
Partly generated by Gemini 3 Pro (trained via RLHF), gemini.google.com

The Sepsis Problem

Sepsis is the body's dysregulated response to infection. It triggers widespread inflammation that can lead to dangerously low blood pressure, tissue hypoxia, and multi-organ failure. Treatment involves several time-critical interventions: immediate administration of broad-spectrum antibiotics, source control (such as draining abscesses), fluid resuscitation to restore circulating volume, and vasopressors to maintain blood pressure if hypotension persists. Clinicians must continuously adjust these interventions based on a patient's evolving vital signs.

The ICU Sepsis benchmark (Choudhary et al., 2024) isolates this management problem as a tabular Markov Decision Process derived from real MIMIC-III patient data. The environment has:

  • Ns = 716 discrete states representing clusters of patient physiology
  • Na = 25 actions arranged as a 5×5 grid (5 vasopressor levels × 5 IV fluid levels)
  • Sparse terminal rewards: +1 for survival (state 714), 0 for death (state 713), and 0 for all intermediate transitions. This simulates real clinical practice, where you don't get continuous instant feedback on whether your treatment is effective.
  • Discount factor: γ=1

Each episode simulates the treatment of one sepsis patient in the ICU until they either survive or die.
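To make the 5×5 action grid concrete, here is a minimal sketch that converts between a flat action index (0 to 24) and a pair of dose levels. The helper names `decode_action` and `encode_action` are hypothetical, and the row-major layout (first level = index // 5) is an assumption on my part; the benchmark's own documentation is the authority on which drug maps to which axis.

```python
N_LEVELS = 5  # 5 dose levels per drug, giving 5 x 5 = 25 actions


def decode_action(action: int) -> tuple[int, int]:
    """Split a flat action index into (first_level, second_level).

    Assumption: row-major layout, i.e. action = first_level * 5 + second_level.
    """
    if not 0 <= action < N_LEVELS * N_LEVELS:
        raise ValueError(f"action must be in [0, 24], got {action}")
    return divmod(action, N_LEVELS)


def encode_action(first_level: int, second_level: int) -> int:
    """Inverse of decode_action under the same layout assumption."""
    return first_level * N_LEVELS + second_level
```

Under this layout, action 0 decodes to the lowest level of both drugs and action 24 to the highest level of both.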

This benchmark is ideal for studying offline RL: we have approximately 50,000 transitions collected from clinician behaviour, covering 51.7% of the state-action space (i.e., out of all the possibilities, clinicians tried about half of them). The challenge is to learn an improved treatment policy without any further interaction with patients.
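Coverage of this kind can be measured directly from the logged data. The sketch below counts the distinct (state, action) pairs in a dataset and divides by the size of the full table (716 × 25 = 17,900 pairs here). The function name and the input format are illustrative assumptions, and I have not checked whether the benchmark's reported figure includes terminal states in the denominator.

```python
def state_action_coverage(transitions, n_states, n_actions):
    """Fraction of distinct (state, action) pairs that appear in the dataset.

    `transitions` is assumed to be an iterable of (state, action) tuples.
    """
    seen = {(state, action) for state, action in transitions}
    return len(seen) / (n_states * n_actions)


# Toy example: 2 states x 2 actions, with one pair logged twice.
toy_coverage = state_action_coverage([(0, 0), (0, 1), (0, 0)], n_states=2, n_actions=2)
```

In the toy example two of the four possible pairs are observed, so the coverage is 0.5.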

From Q-Learning to Deep Q-Networks

To understand CQL, we first need to build up from the foundations.

The Bellman Equation and Q-Learning

Reinforcement Learning formalises sequential decision making as a Markov Decision Process (MDP). At each timestep t, the agent observes a state s, takes an action a, receives a reward r, and transitions to a new state s'. The goal is to learn a policy π(s) that maximises the cumulative discounted reward.

The key object is the Q function (action-value function), which estimates the expected future reward of taking action a in state s and then following policy π:

$$Q^{\pi}(s, a) = \mathbb{E}_{\pi}\left[\sum_{t} \gamma^{t} r_{t} \,\middle|\, s_0 = s,\ a_0 = a\right]$$

The optimal Q function, Q*, satisfies the Bellman Optimality Equation:

$$Q^{*}(s, a) = \mathbb{E}\left[r + \gamma \max_{a'} Q^{*}(s', a')\right]$$

This recursive equation states that the value of an action equals the immediate reward plus the discounted value of the best action in the next state.

Q-learning (Watkins, 1989) is a classic algorithm that iteratively updates Q values towards this target:

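$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[r + \gamma \max_{a'} Q(s', a') - Q(s, a)\right]$$

Here Q(s, a) on the right-hand side is the old estimate, α is the learning rate, r + γ max_{a'} Q(s', a') is the TD target, and the bracketed difference is the TD error. For tabular (finite) state spaces, this converges to Q* provided every state-action pair keeps being visited and the learning rate is decayed appropriately.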
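The update rule can be written in a few lines for a table stored as a list of lists. This is a generic textbook sketch rather than code from my repository; the function name, the `done` flag, and the default values are illustrative assumptions.

```python
def q_learning_update(Q, s, a, r, s_next, done, alpha=0.1, gamma=1.0):
    """Apply one tabular Q-learning update in place and return the TD error.

    Q is a list of lists indexed as Q[state][action].
    """
    # No bootstrapping from the next state once the episode has ended.
    bootstrap = 0.0 if done else gamma * max(Q[s_next])
    td_target = r + bootstrap
    td_error = td_target - Q[s][a]
    Q[s][a] += alpha * td_error
    return td_error


# Toy table with 2 states and 2 actions, initialised to zero.
Q_toy = [[0.0, 0.0], [0.0, 0.0]]
first_error = q_learning_update(Q_toy, s=0, a=1, r=1.0, s_next=1, done=True, alpha=0.5)
```

After this single terminal transition with reward 1 and α = 0.5, the entry Q_toy[0][1] moves halfway from 0 to the target of 1.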

Deep Q-Networks (DQN)

For large or continuous state spaces, we can't store a table: it would be so large that most entries would never be visited or updated. Instead, we approximate Q with a neural network Q(s, a; θ). Deep Q-Networks (DQN) (Mnih et al., 2015) train this network by minimising the Temporal Difference (TD) loss:

$$\mathcal{L}_{\text{TD}}(\theta) = \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}}\left[\left(Q_{\theta}(s, a) - y\right)^{2}\right]$$

Where the TD target is:

$$y = r + \gamma \max_{a'} Q_{\bar{\theta}}(s', a')$$

θ̄ refers to a slowly updated target network, which stabilises training.
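As a small worked example of the TD target, the sketch below computes y for a batch of transitions in plain Python. The function name and argument layout are hypothetical; `next_q_target` stands for the target network's Q values at each next state, and the terminal mask (no bootstrapping after the episode ends) is a standard convention that I am assuming here.

```python
def td_targets(rewards, next_q_target, dones, gamma=0.99):
    """Compute y = r + gamma * max_a' Q_target(s', a') for each transition.

    next_q_target[i] is the list of target-network Q values at next state i.
    """
    targets = []
    for r, q_next, done in zip(rewards, next_q_target, dones):
        # Terminal transitions use the reward alone as the target.
        bootstrap = 0.0 if done else gamma * max(q_next)
        targets.append(r + bootstrap)
    return targets


example_targets = td_targets(
    rewards=[0.0, 1.0],
    next_q_target=[[0.5, 0.2], [9.0, 9.0]],
    dones=[False, True],
)
```

The second transition is terminal, so its large next-state values are ignored and the target is just the reward.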

DQN also uses an Experience Replay Buffer to store transitions and sample mini-batches, breaking temporal correlations in training data.

Distribution Shift

However, standard DQN learns online — it collects new data by interacting with the environment. In fields like healthcare, we can't do that. We have a fixed dataset D collected by clinicians.

When we train DQN on this offline dataset, something goes wrong. The Bellman target requires computing $\max_{a'} Q(s', a')$. The network might estimate that some action, one never tried by clinicians, has a very high Q value. Since there's no data to contradict this optimistic estimate, the error compounds through bootstrapping, and the Q values for these "out of distribution" (OOD) actions can spiral upwards without bound.

This is called extrapolation error, and it is a consequence of distribution shift between the clinicians' behaviour policy and the learned policy. The policy tries to maximise Q values, so it gravitates towards these hallucinated OOD actions. In my experiments, this shows up as a loss explosion during training: without conservatism the TD loss grows without bound, and the learned policy recommends nonsensical (or dangerous) treatments.
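The upward bias of the max operator is easy to reproduce in isolation. In the toy simulation below every action has a true value of exactly 0, but each estimate carries independent Gaussian noise; the maximum over 25 noisy estimates is then positive on average. This is an illustrative sketch with made-up noise, not an experiment on the sepsis data.

```python
import random


def mean_max_of_noisy_estimates(n_actions=25, noise_std=1.0, n_trials=2000, seed=0):
    """Average of max_a (true_value + noise) when every true value is 0."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_trials):
        estimates = [rng.gauss(0.0, noise_std) for _ in range(n_actions)]
        total += max(estimates)
    return total / n_trials


# The true maximum is 0, so any positive result is pure overestimation.
overestimation_bias = mean_max_of_noisy_estimates()
```

With a single action there is nothing to maximise over and the bias vanishes; with 25 actions the bias is clearly positive, which is the error that bootstrapping then feeds back into the targets.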

Conservative Q-Learning: Pessimism is a Good Thing!

Conservative Q-Learning (CQL), proposed by Kumar et al. (2020), addresses this by adding a conservative penalty to the loss function. The core intuition is:

"If I haven't seen an action in the data, I should assume it's bad."

Rather than being optimistic about unknown actions (like other methods), CQL is pessimistic. It learns a Q function that lower-bounds the true value function, ensuring safe behaviour.

CQL Objective

The CQL loss consists of two terms:

Term 1: Push down OOD actions (Conservative Penalty)

We minimise the Q values across all possible actions for states in the dataset. For discrete action spaces, this is computed using the LogSumExp operator which is a smooth approximation of the maximum:

$$\mathbb{E}_{s \sim \mathcal{D}}\left[\log \sum_{a} \exp\left(Q_{\theta}(s, a)\right)\right]$$

This term penalises high Q values for any action, especially OOD ones where the network might hallucinate high values.
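To see why LogSumExp behaves like a smooth maximum, here is a minimal, numerically stable version in plain Python (the implementation in this article uses `torch.logsumexp` instead). Subtracting the maximum before exponentiating avoids overflow, and the result always lies between max(Q) and max(Q) + log(number of actions).

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a non-empty list of floats."""
    m = max(values)
    # Shifting by the maximum keeps every exponent at or below zero.
    return m + math.log(sum(math.exp(v - m) for v in values))


q_example = [1.0, 2.0, 3.0]
lse_example = logsumexp(q_example)
```

When one Q value dominates, the result is close to that value; when all 25 values are equal, it exceeds them by exactly log(25).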

Term 2: Push up actual dataset actions

We maximise the Q values for actions that clinicians actually took (the behaviour policy):

$$\mathbb{E}_{(s, a) \sim \mathcal{D}}\left[Q_{\theta}(s, a)\right]$$

This ensures the conservative penalty doesn't push down the Q values of good, in-distribution actions too much.

The Full CQL Loss:

Combining these with the standard TD loss:

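$$\mathcal{L}_{\text{CQL}}(\theta) = \alpha\left(\mathbb{E}_{s \sim \mathcal{D}}\left[\log \sum_{a} \exp\left(Q_{\theta}(s, a)\right)\right] - \mathbb{E}_{(s, a) \sim \mathcal{D}}\left[Q_{\theta}(s, a)\right]\right) + \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}}\left[\left(r + \gamma \max_{a'} Q_{\bar{\theta}}(s', a') - Q_{\theta}(s, a)\right)^{2}\right]$$

The complete CQL loss function. It minimises the conservative regulariser (Terms 1 & 2, scaled by α, which sets how strongly OOD actions are penalised), plus the standard TD error to ensure the policy achieves the task.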

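Where:

  • α is the conservatism coefficient (a hyperparameter; not the Q-learning learning rate, which used the same symbol earlier)
  • $y = r + \gamma \max_{a'} Q_{\bar{\theta}}(s', a')$ is the TD target inside the squared term
  • D is the offline dataset collected by the behaviour policy (clinicians)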

With α > 0 (and large enough, per the analysis in Kumar et al., 2020), we learn a Q function that is a lower bound on the true value. The agent becomes appropriately sceptical of actions it hasn't seen sufficient evidence for.
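It can help to look at the gradient of the conservative term for a single state. Differentiating log Σ exp(Q(s, a)) − Q(s, a_data) with respect to each Q(s, a) gives softmax(Q)(a) minus an indicator for the dataset action. The sketch below computes both the penalty and this gradient in plain Python; the function name is hypothetical and the code is a worked illustration, not part of my implementation.

```python
import math


def cql_penalty_and_grad(q_values, data_action):
    """Per-state CQL penalty logsumexp(Q) - Q[data_action] and its gradient w.r.t. Q."""
    m = max(q_values)
    exps = [math.exp(q - m) for q in q_values]
    z = sum(exps)
    softmax = [e / z for e in exps]
    penalty = m + math.log(z) - q_values[data_action]
    # d(penalty)/dQ[a] = softmax[a] - 1 if a is the dataset action, else softmax[a].
    grad = [p - (1.0 if a == data_action else 0.0) for a, p in enumerate(softmax)]
    return penalty, grad


penalty_flat, grad_flat = cql_penalty_and_grad([0.0, 0.0, 0.0, 0.0], data_action=1)
```

Gradient descent lowers Q wherever the gradient is positive, so every action that is not the dataset action gets pushed down (hardest where Q is already high), while the dataset action, whose gradient is negative, gets pushed up.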

Implementation

Here is the core of my CQL implementation, adapted for the discrete action ICU Sepsis environment. The full code is available in anishmariathasan/CQL-sepsis on GitHub.

My implementation follows the original CQL paper and reference implementation at github.com/aviralkumar2907/CQL.

Q-Network Architecture

The ICU Sepsis environment has 716 discrete states. Rather than using wasteful one-hot encodings (which are high dimensional and sparse, with no notion of similarity between states), I use an embedding layer to learn dense state representations:

import torch
import torch.nn as nn


class QNetwork(nn.Module):
    """Q-Network for discrete state and action spaces."""
    def __init__(self, state_dim, action_dim, hidden_dim=256, num_layers=2):
        super().__init__()
        self.state_dim = state_dim      # number of discrete states (716 here)
        self.action_dim = action_dim    # number of discrete actions (25 here)

        # Learn dense embeddings for discrete states
        self.state_embedding = nn.Embedding(state_dim, hidden_dim)

        # Build MLP layers: Linear -> ReLU -> Linear -> ReLU -> Linear
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        self.hidden_layers = nn.Sequential(*layers)
        self.output_layer = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Compute Q values for all actions given a batch of state indices.

        state: integer tensor of shape [batch]; returns [batch, action_dim].
        """
        x = self.state_embedding(state.long())
        x = self.hidden_layers(x)
        return self.output_layer(x)

Design choices:

  • nn.Embedding: Maps each of 716 states to a learnable 256 dimensional vector
  • 2 layer MLP with ReLU activations
  • Outputs Q values for all 25 actions simultaneously
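For a sense of scale, the number of trainable parameters follows directly from these design choices. The sketch below does the arithmetic in plain Python, assuming PyTorch defaults (an embedding table with no bias, and linear layers with a bias vector each); the function name is illustrative.

```python
def count_parameters(n_states=716, n_actions=25, hidden_dim=256, num_layers=2):
    """Parameter count for embedding -> num_layers hidden layers -> output layer."""
    embedding = n_states * hidden_dim                             # one vector per state
    hidden = num_layers * (hidden_dim * hidden_dim + hidden_dim)  # weights + biases
    output = hidden_dim * n_actions + n_actions                   # weights + biases
    return embedding + hidden + output


total_parameters = count_parameters()
```

With the defaults this comes to 321,305 parameters, of which 183,296 sit in the state embedding table.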

The CQL Loss

This method computes both the TD loss and the conservative penalty:

# Requires: import torch and import torch.nn.functional as F
def compute_cql_loss(self, states, actions, rewards, next_states, dones):
    """
    Compute CQL loss = TD loss + alpha * CQL penalty.

    Inputs are batched tensors with one entry per transition.
    """
    # ======== Compute TD loss (standard DQN term) ========

    # Q values for all actions in each state
    current_q_values = self.q_network(states.reshape(-1))  # [batch, 25]

    # Q values of the actions actually taken in the dataset
    q_values = current_q_values.gather(1, actions.long().reshape(-1, 1))  # [batch, 1]

    with torch.no_grad():
        # Bootstrap from the slowly updated target network
        next_q_values = self.target_network(next_states.reshape(-1))  # [batch, 25]
        max_next_q_values = next_q_values.max(dim=1, keepdim=True)[0]  # [batch, 1]

        # TD target; (1 - dones) removes the bootstrap term at terminal states
        rewards = rewards.float().reshape(-1, 1)
        dones = dones.float().reshape(-1, 1)
        target_q_values = rewards + (1 - dones) * self.gamma * max_next_q_values

    # Compute squared TD error
    td_loss = F.mse_loss(q_values, target_q_values)

    # ======== Compute conservative penalty ========

    # Logsumexp across all 25 actions (pushes down ALL Q values)
    logsumexp = torch.logsumexp(current_q_values, dim=1, keepdim=True)

    # Q value of the action in dataset
    data_q = q_values

    # Conservative penalty
    cql_penalty = (logsumexp - data_q).mean()

    total_loss = td_loss + self.alpha * cql_penalty
    return total_loss, td_loss, logsumexp.mean()

  • torch.logsumexp: A numerically stable computation of log(∑exp(Q)), which acts as a smooth maximum over all actions. Its gradient is the softmax of the Q values, so the actions with the highest Q values are pushed down hardest.
  • data_q: The Q values of actions clinicians actually took. Subtracting this from logsumexp offsets the penalty for in-distribution actions, so the net downward pressure falls on OOD actions rather than on good in-distribution ones.
  • self.alpha: Controls conservatism strength. α=0 recovers standard DQN; higher α means more pessimism.

Training Loop with Soft Target Updates

Each gradient step samples a batch from the offline replay buffer:

def update(self, batch):
    """Perform a single gradient update step."""
    states, actions, rewards, next_states, dones = batch

    # Compute the CQL + TD loss
    loss, td_loss, logsumexp = self.compute_cql_loss(
        states, actions, rewards, next_states, dones
    )

    # Gradient descent step
    self.q_optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), self.grad_clip)
    self.q_optimizer.step()

    # Soft target network update: θ' ← τ * θ + (1 - τ) * θ'
    for param, target_param in zip(
        self.q_network.parameters(), self.target_network.parameters()
    ):
        target_param.data.copy_(
            self.tau * param.data + (1 - self.tau) * target_param.data
        )

    return loss

The soft target update (with τ=0.005) provides stability by slowly moving the target network towards the online network, preventing oscillations in the TD targets.
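To get a feel for how slow τ=0.005 is, consider a frozen online network. Each soft update shrinks the gap between target and online parameters by a factor of (1 − τ), so the gap halves after roughly ln 2 / τ steps. The sketch below checks this in plain Python on a single scalar "parameter"; the function names are hypothetical and the frozen online network is a simplifying assumption.

```python
import math


def soft_update(target, online, tau=0.005):
    """Return the new target parameters after one Polyak (soft) update."""
    return [tau * o + (1.0 - tau) * t for t, o in zip(target, online)]


def steps_to_close_half_gap(tau=0.005):
    """Smallest n with (1 - tau)^n <= 0.5, assuming the online network is fixed."""
    return math.ceil(math.log(0.5) / math.log(1.0 - tau))


half_life = steps_to_close_half_gap(0.005)
```

With τ=0.005 the target needs 139 updates to cover half the distance to a fixed online network, which is why the TD targets change smoothly rather than jumping after every gradient step.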

Experimental Setup

Training Details

| Hyperparameter | Value | Description |
|---|---|---|
| Hidden Dimension | 256 | Size of the hidden layers in the network |
| Layers | 2 | Number of hidden layers in the network |
| Learning Rate | 3×10⁻⁴ | Step size for parameter updates |
| Batch Size | 256 | Number of samples processed per update |
| γ (Discount) | 0.99 | Future reward discount factor |
| τ (Soft Update) | 0.005 | Target network soft update coefficient |
| Gradient Clip | 1 | Maximum norm for gradient clipping |
| Training Steps | 100,000 | Total number of training iterations |
| Evaluation Seeds | 3 | Number of random seeds used for evaluation |

Baselines

  1. DQN: Equivalent to CQL with α=0 (no conservatism)
  2. Behaviour Cloning (BC): Supervised learning to directly imitate clinician actions
  3. Random: Uniform random action selection
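For intuition about the behaviour cloning baseline, here is a count-based stand-in: for each state it simply picks the action clinicians chose most often. This is not the supervised model used in my experiments, just a minimal tabular sketch; the function name, the input format, and the fallback action for unseen states are all assumptions.

```python
from collections import Counter, defaultdict


def fit_tabular_bc(transitions, default_action=0):
    """Return a policy mapping each state to its most frequent logged action.

    `transitions` is assumed to be an iterable of (state, action) tuples.
    States never seen in the data fall back to `default_action`.
    """
    counts = defaultdict(Counter)
    for state, action in transitions:
        counts[state][action] += 1
    policy = {s: c.most_common(1)[0][0] for s, c in counts.items()}
    return lambda state: policy.get(state, default_action)


bc_policy = fit_tabular_bc([(0, 5), (0, 5), (0, 10), (1, 20)])
```

In the toy data, state 0 was treated with action 5 twice and action 10 once, so the cloned policy chooses action 5 there.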

Results & Analysis

1. Training Stability

Training Stability Comparison
Figure 1: Training loss over 100,000 steps across three algorithms. Panel (a): CQL converges smoothly to ~0.2. Panel (b): BC converges via supervised learning. Panel (c): DQN's loss explodes past 200 after 60K iterations.

CQL exhibits rapid convergence followed by stable, bounded optimisation, whilst behaviour cloning shows smooth and predictable convergence typical of supervised learning. In contrast, DQN initially appears stable but later diverges sharply, with the TD loss exploding after extended training, highlighting the instability of offline Q-learning under distribution shift.

Explanation and clinical relevance

DQN's divergence is driven by overestimation in the max operator of the TD target, which amplifies errors for out-of-distribution actions through a positive feedback loop. CQL mitigates this by explicitly penalising high Q values across actions, preventing this runaway overestimation and ensuring stable learning essential for a clinical setting.

2. Evaluation Performance During Training

Evaluation Performance
Figure 2: Evaluation survival rate during training. Panel (a): CQL across 3 random seeds maintains 80-90% survival. Panel (b): Algorithm comparison shows CQL improving while DQN degrades from ~85% to ~73%.

Across random seeds, CQL consistently maintains high survival rates (80-90%) throughout training, indicating stable and reliable behaviour even before full convergence. When comparing algorithms, CQL shows a slight improvement over time, behaviour cloning remains stable, and DQN degrades steadily with additional training, reflecting the accumulation of overestimation errors under distribution shift. This shows that the "train longer is better" assumption carried over from online reinforcement learning fails in the offline setting, and underscores the importance of conservative methods for safety-critical medical applications.

3. Alpha Sensitivity: Finding the "Goldilocks Zone"

Alpha Sensitivity Analysis
Figure 3: Panel (a): Mean survival rate by α, note the peak at α=0.5. Panel (b): Box plots showing the variance-performance tradeoff — low α has high variance, high α has low variance but lower mean.

Sweeping α loosely shows an inverted-U relationship between conservatism and performance, with mean survival peaking at α=0.5. Low values of α lead to unstable performance due to insufficient control of Q-value overestimation, whilst large values reduce variance but degrade mean performance by over-penalising even well-supported actions. The box plots illustrate this variance–performance trade-off, with α=0.5 providing the best balance between high survival and robustness across seeds. This shows us the importance of tuning conservatism in offline reinforcement learning, particularly in clinical settings where both risk and under-treatment are costly.

4. Algorithm Comparison: CQL Beats All Baselines

Algorithm Comparison
Figure 4: Final survival rates: CQL (86.0%) > BC (84.0%) > DQN (82.0%) > Random (81.7%). Error bars show standard deviation across seeds. The red dashed line marks the 80% baseline.

CQL's outperformance of Behaviour Cloning (86% vs 84%) shows that rather than imitating any single clinician's complete treatment sequence, CQL identifies which individual decisions were effective and combines them into a superior policy. Notably, CQL approaches the environment's theoretical optimal of 88% (computed via value iteration with known transition dynamics), suggesting it recovers most of the available performance gain.

DQN's near-equivalence to random (82.0% vs 81.7%) illustrates how Q-value overestimation in offline settings can negate learning entirely. The 4 percentage point gap between CQL and DQN translates to approximately 40 additional survivors per 1,000 patients. Applied to UK ICU sepsis admissions (roughly 20,000 annually), this represents approximately 800 additional lives per year.
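The per-1,000 figure is a direct conversion from the gap in survival rates, expressed in percentage points. A minimal sketch of the arithmetic, with a hypothetical helper name:

```python
def extra_survivors(rate_new, rate_old, n_patients):
    """Expected additional survivors when survival improves from rate_old to rate_new."""
    return (rate_new - rate_old) * n_patients


# CQL (86.0%) versus DQN (82.0%) on a cohort of 1,000 patients.
per_thousand = round(extra_survivors(0.86, 0.82, 1000))
```

The same function scales linearly with cohort size, so any population estimate inherits both the benchmark gap and the uncertainty in the admission count.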

Random's ~78–82% survival rate reflects the benchmark design, where many patients survive regardless of treatment. CQL's improvement targets the marginal cases where intervention quality determines outcome.

5. Performance Variance Across Hyperparameters

Performance Heatmap
Figure 5: Survival rate heatmap across all 7 alpha values (rows) and 3 random seeds (columns). Colour intensity shows performance. This visualisation reveals both optimal settings AND variance patterns.

The heatmap reveals that at low α values (0.0–0.1), performance is highly seed-dependent, with identical configurations yielding outcomes ranging from 79% to 89%. This instability diminishes as α increases: at α≥2.0, results cluster within a narrower 82–87% band, sacrificing peak performance for consistency.

The setting α=0.5 occupies a middle ground, achieving the highest mean performance whilst avoiding the high variance between seeds at α=0.1.

Clinical relevance: In healthcare applications, a model fluctuating between 79% and 89% presents greater deployment risk than one reliably achieving 84%. This visualisation identifies configurations where all seeds perform acceptably, which is essential information for safe clinical use.

6. Action Distributions: What Treatments Does Each Policy Prefer?

Action Distribution Analysis
Figure 6: Action frequency by policy. The 25 actions form a 5×5 grid (vasopressor level × IV fluid level). Pink shading marks the "high-dose region" (actions 20–24).

The action space encodes treatment intensity: action 0 represents no intervention, while actions 20–24 (pink region) represent high-dose combinations.
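Distributions like these can be summarised from a list of chosen actions. The sketch below computes the relative frequency of each of the 25 actions and the share falling in actions 20–24; the function names and the input format (a flat list of action indices, for example a policy's greedy action in each visited state) are assumptions for illustration.

```python
from collections import Counter

HIGH_DOSE_ACTIONS = range(20, 25)  # the region shaded pink in the figure


def action_frequencies(actions, n_actions=25):
    """Relative frequency of each action index in a non-empty list of actions."""
    if not actions:
        raise ValueError("actions must be non-empty")
    counts = Counter(actions)
    return [counts[a] / len(actions) for a in range(n_actions)]


def high_dose_share(actions):
    """Fraction of chosen actions that fall in the high-dose region."""
    freqs = action_frequencies(actions)
    return sum(freqs[a] for a in HIGH_DOSE_ACTIONS)


toy_actions = [0, 0, 10, 20]
toy_share = high_dose_share(toy_actions)
```

In the toy list one of the four actions is in the high-dose region, giving a share of 0.25.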

Clinician and BC distributions are similar and spread across actions 0, 5, 10, 15, and 20, reflecting real-world variation in treatment decisions. BC faithfully reproduces this clinical distribution, as expected from imitation learning.

CQL concentrates sharply on actions 0, 10, and 20, exhibiting far more decisive behaviour than clinicians. This is somewhat surprising, as where clinicians distribute probability across five main actions, CQL commits strongly to three. At action 20 specifically, CQL significantly exceeds clinician frequency (~36% vs ~13%), suggesting it has identified aggressive treatment as beneficial for certain states.

DQN's failure mode is different from the high-dose concern: its extreme spikes occur at actions 6 (~36%) and 14 (~23%), both mid-range interventions. Far from overusing the high-dose region, DQN actually underuses actions 20–24 relative to the other policies. The issue is not excessive aggression but concentration on specific actions regardless of patient state.

Clinical interpretation: CQL's deviation from clinician behaviour, particularly its increased use of high-dose treatment, warrants attention. If CQL achieves higher survival rates whilst being more aggressive than clinicians, this may suggest systematic under-treatment in current practice. However, such a conclusion requires rigorous validation before informing clinical guidelines.

7. Q-Value Analysis by State and Across Algorithms

Q-Value Analysis
Figure 7: Q-values for CQL vs DQN. Left: Single state (559) showing Q-values across all actions. Right: Mean Q-values across 50 states with error bars. Grey shading marks low-coverage actions (~20–24) where clinicians rarely intervened.

We can see in the left panel DQN assigns near-uniform Q-values (~0–1) across all 25 actions, whilst CQL shows large variation, with low-coverage actions (grey region) receiving strongly negative values (down to -20).

This pattern for a single state holds across 50 states (right panel). DQN maintains flat, confident estimates with tight error bars regardless of data coverage. CQL produces negative Q-values with large uncertainty for actions rarely observed in the training data.

This is because CQL's penalty term pushes down Q-values for all actions, then selectively restores values for actions observed in the dataset. The strongly negative Q-values for low-coverage actions are this penalty functioning as intended.

Clinical interpretation: For treatments clinicians rarely administered, CQL defaults to pessimism and assumes the worst. DQN's uniform confidence across untested treatments is why naive offline RL poses risks in high-stakes domains like healthcare.

Conclusion: Towards Safe AI for Healthcare

By penalising OOD actions, CQL learns only what the data can reliably support. It combines:

  1. Data efficiency: Learns from existing hospital records — no new patient data needed.
  2. Stability: No training collapse, no runaway Q-values.
  3. Safety: Conservative estimates mean no dangerous extrapolation to untested treatments.
  4. Uncertainty awareness: CQL expresses high uncertainty for actions it lacks evidence about, rather than false confidence, and still outperforms the behaviour cloning baseline that imitates clinicians.

Future Work

At the scale of global sepsis incidence (166 million cases/year), even incremental improvements could save hundreds of thousands of lives. However, whilst this implementation showed promising results, real-world deployment remains a significant challenge requiring extensive clinical validation.

Limitations

  1. Hyperparameter sensitivity: The optimal α varies by dataset and domain. We currently lack reliable methods for selecting α purely offline.
  2. Simulator limitations: ICU-Sepsis is a simplified tabular MDP, not a full patient simulator with continuous dynamics.
  3. Temporal abstraction: Real treatment involves continuous dosing over time, not discrete 5×5 action grids.
  4. Generalisation: Performance may vary on patient populations different from MIMIC-III.

References

  1. Kumar, A., Zhou, A., Tucker, G., & Levine, S. (2020). Conservative Q-Learning for Offline Reinforcement Learning. Advances in Neural Information Processing Systems (NeurIPS). arXiv:2006.04779 | GitHub
  2. Mnih, V., et al. (2015). Human-level control through deep reinforcement learning. Nature, 518, 529–533.
  3. Choudhary, S., et al. (2024). ICU-Sepsis: A Benchmark MDP Built from Real Medical Data. Reinforcement Learning Journal, 4, 1546–1566. (Presented at the Reinforcement Learning Conference, RLC 2024.)
  4. Komorowski, M., et al. (2018). The Artificial Intelligence Clinician learns optimal treatment strategies for sepsis in intensive care. Nature Medicine, 24, 1716–1720.
  5. Singer, M., et al. (2016). The Third International Consensus Definitions for Sepsis and Septic Shock (Sepsis-3). JAMA, 315, 801–810.

Acknowledgments

Statistics and external sources not included in the references are hyperlinked directly in the text.

I acknowledge the use of Claude Opus 4.5 (Anthropic, https://claude.ai) for code adaptation and debugging. All experimental results, analysis, and conclusions are my own work.

See the full implementation in: github.com/anishmariathasan/CQL-sepsis