Models Reference

OneEHR ships 41 model architectures spanning tabular ML, deep learning, irregular-time modeling, KG-enhanced EHR modeling, and survival analysis. Each model is configured through a [[models]] entry that has a name and a params table.


Model overview

| Model | Config name | Type | Patient (N-1) | Time (N-N) | Static branch |
| --- | --- | --- | --- | --- | --- |
| XGBoost | xgboost | Tabular | Yes | Yes | N/A |
| CatBoost | catboost | Tabular | Yes | Yes | N/A |
| Random Forest | rf | Tabular | Yes | Yes | N/A |
| Decision Tree | dt | Tabular | Yes | Yes | N/A |
| Gradient Boosting | gbdt | Tabular | Yes | Yes | N/A |
| Logistic Regression | lr | Tabular | Yes | Yes | N/A |
| GRU | gru | DL | Yes | Yes | No |
| LSTM | lstm | DL | Yes | Yes | No |
| RNN | rnn | DL | Yes | Yes | No |
| GRU-D | grud | DL | Yes | Yes | No |
| CNN | cnn | DL | Yes | Yes | No |
| TCN | tcn | DL | Yes | Yes | No |
| Transformer | transformer | DL | Yes | Yes | No |
| SAnD | sand | DL | Yes | Yes | No |
| Dipole | dipole | DL | Yes | Yes | No |
| HiTANet | hitanet | DL | Yes | Yes | No |
| LSAN | lsan | DL | Yes | Yes | No |
| mTAND | mtand | DL | Yes | Yes | No |
| Raindrop | raindrop | DL | Yes | Yes | No |
| ContiFormer | contiformer | DL | Yes | Yes | No |
| TECO | teco | DL | Yes | Yes | Yes |
| MLP | mlp | DL | Yes | Yes | No |
| AdaCare | adacare | DL | Yes | Yes | No |
| StageNet | stagenet | DL | Yes | Yes | No |
| RETAIN | retain | DL | Yes | Yes | No |
| ConCare | concare | DL | Yes | Yes | Yes |
| GRASP | grasp | DL | Yes | Yes | Yes |
| MCGRU | mcgru | DL | Yes | Yes | Yes |
| DrAgent | dragent | DL | Yes | Yes | Yes |
| Deepr | deepr | DL | Yes | Yes | No |
| EHR-Mamba | mamba | DL | Yes | Yes | No |
| Jamba | jamba | DL | Yes | Yes | No |
| PRISM | prism | DL | Yes | Yes | Yes |
| M3Care | m3care | DL | Yes | Yes | No |
| SAFARI | safari | DL | Yes | Yes | Yes |
| PAI (GRU) | pai | DL | Yes | Yes | No |
| GraphCare | graphcare | DL / KG | Yes | Yes | No |
| KerPrint | kerprint | DL / KG | Yes | Yes | No |
| ProtoEHR | protoehr | DL / KG | Yes | Yes | No |
| DeepSurv | deepsurv | DL / Survival | Yes | No | No |
| DeepHit | deephit | DL / Survival | Yes | No | No |

Models with a static branch automatically receive patient-level static features as a separate input tensor when static.csv is provided. The static_dim parameter is auto-detected from the static feature count.
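As a rough picture of what auto-detecting static_dim could amount to, the sketch below counts the feature columns in a static.csv header. The file layout (one row per patient with a patient_id column) and the helper infer_static_dim are assumptions made for illustration; OneEHR's actual detection may differ, for example if categorical columns are expanded during preprocessing.

```python
import csv
import io

# Hypothetical static.csv contents: one row per patient, an id column,
# and numeric static features. The real file layout may differ.
STATIC_CSV = """patient_id,age,sex,bmi
p1,63,1,27.4
p2,51,0,31.0
"""


def infer_static_dim(csv_text: str, id_column: str = "patient_id") -> int:
    """Count the feature columns in the header, skipping the id column."""
    header = next(csv.reader(io.StringIO(csv_text)))
    return sum(1 for column in header if column != id_column)


static_dim = infer_static_dim(STATIC_CSV)
print(static_dim)
```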

Recent additions

The most recent additions target three gaps in longitudinal EHR benchmarking: missing-aware recurrent baselines, irregular-time encoders, and lightweight KG-enhanced architectures.

| Model | Config | Summary | Key params |
| --- | --- | --- | --- |
| GRU-D | grud | Missing-aware GRU with trainable decay and observed-feature means | hidden_dim, dropout |
| CNN | cnn | Lightweight temporal convolution baseline | hidden_dim, num_layers, kernel_size |
| SAnD | sand | Self-attention with causal convolution and dense interpolation pooling | d_model, nhead, interp_points |
| Dipole | dipole | Bidirectional GRU with location/general/concat attention | hidden_dim, attention_type |
| HiTANet | hitanet | Grouped visit encoder with hierarchical time-aware attention | hidden_dim |
| LSAN | lsan | Grouped visit encoder with long/short-term fusion | hidden_dim, nhead, kernel_size |
| mTAND | mtand | Relative-time attention for irregular sequences | hidden_dim, num_heads, num_layers |
| Raindrop | raindrop | Graph-guided sensor message passing over irregular observations | hidden_dim |
| ContiFormer | contiformer | Continuous-time state updates followed by time-biased attention | hidden_dim, num_heads, num_layers |
| TECO | teco | Encounter-level transformer with optional static token | hidden_dim, nhead, static_dim |
| GraphCare | graphcare | Lightweight patient-specific KG summarization with temporal fusion | hidden_dim, kg_source, kg_top_k, kg_ontology |
| KerPrint | kerprint | Local/global KG summaries with time-aware knowledge gating | hidden_dim, kg_source, kg_top_k, kg_ontology |
| ProtoEHR | protoehr | KG-enhanced patient modeling with concept/visit/patient prototypes | hidden_dim, num_prototypes, kg_source, kg_top_k, kg_ontology |

For KG-enhanced models, kg_source = "lightweight" builds an internal concept graph from train-split co-occurrence plus available ontology hints. kg_source = "external" reads a user-supplied graph from external_kg_path.
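The co-occurrence idea behind kg_source = "lightweight" can be sketched as follows. This is an illustration only: the function build_cooccurrence_graph, the toy concept names, and the reading of kg_top_k as "neighbours kept per concept" are assumptions, and the ontology hints mentioned above are left out entirely.

```python
from collections import Counter
from itertools import combinations


def build_cooccurrence_graph(visits, top_k):
    """Build a concept graph from co-occurrence counts.

    visits: iterable of sets of concept codes, taken from the train split only
            so that no test information leaks into the graph.
    top_k:  number of strongest neighbours kept per concept (assumed meaning).
    """
    pair_counts = Counter()
    for codes in visits:
        # Count every unordered pair of concepts seen in the same visit.
        for a, b in combinations(sorted(set(codes)), 2):
            pair_counts[(a, b)] += 1

    neighbours = {}
    for (a, b), count in pair_counts.items():
        neighbours.setdefault(a, []).append((b, count))
        neighbours.setdefault(b, []).append((a, count))

    # Keep the top_k heaviest edges per concept; ties are broken alphabetically.
    return {
        concept: sorted(edges, key=lambda e: (-e[1], e[0]))[:top_k]
        for concept, edges in neighbours.items()
    }


train_visits = [
    {"diabetes", "metformin", "hba1c"},
    {"diabetes", "metformin"},
    {"hypertension", "lisinopril"},
    {"diabetes", "hypertension"},
]
graph = build_cooccurrence_graph(train_visits, top_k=2)
print(graph["diabetes"])
```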


Tabular models

Tabular models flatten the time dimension and operate on a 2D feature matrix. They work with both patient and time prediction modes.
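A minimal sketch of turning per-patient sequences into a 2D feature matrix is shown below. The helper names are hypothetical, and the patient-mode choice of using the last time step is an assumption made for illustration; this page does not state how OneEHR aggregates over time for tabular models.

```python
def to_time_rows(sequences):
    """Time mode: every (patient, time step) pair becomes one row."""
    rows, index = [], []
    for patient_id, steps in sequences.items():
        for t, features in enumerate(steps):
            rows.append(list(features))
            index.append((patient_id, t))
    return rows, index


def to_patient_rows(sequences):
    """Patient mode: one row per patient (assumed here: the last time step)."""
    rows, index = [], []
    for patient_id, steps in sequences.items():
        rows.append(list(steps[-1]))
        index.append(patient_id)
    return rows, index


# Two patients with 2 features each; p1 has two time steps, p2 has one.
sequences = {
    "p1": [[1.0, 2.0], [3.0, 4.0]],
    "p2": [[5.0, 6.0]],
}
time_rows, time_index = to_time_rows(sequences)
patient_rows, patient_index = to_patient_rows(sequences)
print(len(time_rows), len(patient_rows))
```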

XGBoost

[[models]]
name = "xgboost"
[models.params]
max_depth = 6
n_estimators = 500
learning_rate = 0.05
subsample = 0.8
colsample_bytree = 0.8
reg_lambda = 1.0
min_child_weight = 1.0

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| max_depth | int | 6 | Maximum tree depth |
| n_estimators | int | 500 | Number of boosting rounds |
| learning_rate | float | 0.05 | Step size shrinkage |
| subsample | float | 0.8 | Row subsampling ratio |
| colsample_bytree | float | 0.8 | Column subsampling ratio per tree |
| reg_lambda | float | 1.0 | L2 regularization |
| min_child_weight | float | 1.0 | Minimum sum of instance weight in a child |

CatBoost

[[models]]
name = "catboost"
[models.params]
depth = 6
n_estimators = 500
learning_rate = 0.05

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| depth | int | 6 | Maximum tree depth |
| n_estimators | int | 500 | Number of boosting iterations |
| learning_rate | float | 0.05 | Step size shrinkage |

Random Forest

[[models]]
name = "rf"
[models.params]
n_estimators = 100
max_depth = 6

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| n_estimators | int | 100 | Number of trees |
| max_depth | int | None | Maximum tree depth (None for unlimited) |

Decision Tree

[[models]]
name = "dt"
[models.params]
max_depth = 6

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| max_depth | int | None | Maximum tree depth (None for unlimited) |

Gradient Boosting (GBDT)

Scikit-learn's GradientBoostingClassifier / GradientBoostingRegressor.

[[models]]
name = "gbdt"
[models.params]
n_estimators = 100
max_depth = 3
learning_rate = 0.1

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| n_estimators | int | 100 | Number of boosting stages |
| max_depth | int | 3 | Maximum tree depth |
| learning_rate | float | 0.1 | Step size shrinkage |

Logistic Regression

Scikit-learn's LogisticRegression (binary) or Ridge (regression).

[[models]]
name = "lr"
[models.params]
max_iter = 1000

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| max_iter | int | 1000 | Maximum iterations for solver convergence |

Recurrent models

GRU

[[models]]
name = "gru"
[models.params]
hidden_dim = 128
num_layers = 1
dropout = 0.0

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Hidden state size |
| num_layers | int | 1 | Number of stacked GRU layers |
| dropout | float | 0.0 | Dropout between layers |

LSTM

[[models]]
name = "lstm"
[models.params]
hidden_dim = 128
num_layers = 1
dropout = 0.0

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Hidden state size |
| num_layers | int | 1 | Number of stacked LSTM layers |
| dropout | float | 0.0 | Dropout between layers |

RNN

Vanilla (Elman) recurrent network.

[[models]]
name = "rnn"
[models.params]
hidden_dim = 128
num_layers = 1
dropout = 0.0

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Hidden state size |
| num_layers | int | 1 | Number of stacked RNN layers |
| dropout | float | 0.0 | Dropout between layers |

Non-recurrent models

TCN

Temporal Convolutional Network.

[[models]]
name = "tcn"
[models.params]
hidden_dim = 128
num_layers = 2
kernel_size = 3
dropout = 0.1

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Channel dimension |
| num_layers | int | 2 | Number of TCN blocks |
| kernel_size | int | 3 | Convolutional kernel size |
| dropout | float | 0.1 | Dropout rate |

Transformer

[[models]]
name = "transformer"
[models.params]
d_model = 128
nhead = 4
num_layers = 2
dim_feedforward = 256
dropout = 0.1
pooling = "last"

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| d_model | int | 128 | Model dimension |
| nhead | int | 4 | Number of attention heads |
| num_layers | int | 2 | Number of encoder layers |
| dim_feedforward | int | 256 | FFN inner dimension |
| dropout | float | 0.1 | Dropout rate |
| pooling | str | "last" | Pooling for patient mode: last or mean |
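The two pooling options for patient mode can be illustrated with plain Python lists. This is a sketch, not the OneEHR implementation: the function pool and the assumption that padded steps beyond the sequence length are excluded are both illustrative.

```python
def pool(hidden, length, pooling="last"):
    """Reduce per-step encoder outputs to one vector per patient.

    hidden:  list of per-step vectors, possibly padded at the end.
    length:  number of valid (non-padded) steps; assumed to be known.
    pooling: "last" takes the final valid step, "mean" averages valid steps.
    """
    valid = hidden[:length]
    if not valid:
        raise ValueError("sequence has no valid time steps")
    if pooling == "last":
        return list(valid[-1])
    if pooling == "mean":
        dim = len(valid[0])
        return [sum(step[i] for step in valid) / len(valid) for i in range(dim)]
    raise ValueError(f"unknown pooling: {pooling!r}")


# Three steps of 2-dimensional outputs; the third step is padding.
hidden = [[1.0, 2.0], [3.0, 6.0], [0.0, 0.0]]
print(pool(hidden, length=2, pooling="last"))
print(pool(hidden, length=2, pooling="mean"))
```

Last pooling emphasises the most recent state, while mean pooling spreads weight evenly over the whole history, so the better choice depends on how much early observations matter for the task.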

MLP

Feedforward network operating on the last time step.

[[models]]
name = "mlp"
[models.params]
hidden_dim = 128
dropout = 0.0

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Hidden layer size |
| dropout | float | 0.0 | Dropout rate |

Deepr

Embedding + CNN-based sequence model over discrete time windows.

[[models]]
name = "deepr"
[models.params]
hidden_dim = 128
window = 1
dropout = 0.0

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Embedding and hidden dimension |
| window | int | 1 | Convolution window size |
| dropout | float | 0.0 | Dropout rate |

EHR-Mamba

Selective state-space model (Mamba) adapted for EHR sequences.

[[models]]
name = "mamba"
[models.params]
hidden_dim = 128
num_layers = 2
state_size = 16
conv_kernel = 4
dropout = 0.1

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Model dimension |
| num_layers | int | 2 | Number of Mamba layers |
| state_size | int | 16 | SSM state dimension |
| conv_kernel | int | 4 | 1D convolution kernel size |
| dropout | float | 0.1 | Dropout rate |

Jamba

Hybrid architecture combining Transformer attention and Mamba SSM layers.

[[models]]
name = "jamba"
[models.params]
hidden_dim = 128
num_transformer_layers = 2
num_mamba_layers = 6
heads = 4
state_size = 16
conv_kernel = 4
dropout = 0.3

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Model dimension |
| num_transformer_layers | int | 2 | Number of Transformer layers |
| num_mamba_layers | int | 6 | Number of Mamba layers |
| heads | int | 4 | Attention heads in Transformer layers |
| state_size | int | 16 | Mamba SSM state dimension |
| conv_kernel | int | 4 | Mamba 1D convolution kernel size |
| dropout | float | 0.3 | Dropout rate |

EHR-specialised models

AdaCare

Adaptive clinical feature calibration with dilated convolutions.

Liantao Ma et al. AdaCare: Explainable Clinical Health Status Representation Learning via Scale-Adaptive Feature Extraction and Recalibration. AAAI 2020.

[[models]]
name = "adacare"
[models.params]
hidden_dim = 128
kernel_size = 2
kernel_num = 64
dropout = 0.5

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Hidden state size |
| kernel_size | int | 2 | Dilated convolution kernel size |
| kernel_num | int | 64 | Number of convolution channels |
| dropout | float | 0.5 | Dropout rate |

StageNet

Stage-aware LSTM with stage-adaptive convolution.

Junyi Gao et al. StageNet: Stage-Aware Neural Network for Health Risk Prediction. WWW 2020.

[[models]]
name = "stagenet"
[models.params]
chunk_size = 128
levels = 3
conv_size = 10
dropout = 0.3

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| chunk_size | int | 128 | Stage-aware hidden chunk size |
| levels | int | 3 | Number of hierarchical levels |
| conv_size | int | 10 | Convolution window size |
| dropout | float | 0.3 | Dropout rate |

RETAIN

Reverse Time Attention Network with interpretable alpha and beta attention.

Edward Choi et al. RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism. NIPS 2016.

[[models]]
name = "retain"
[models.params]
hidden_dim = 128
dropout = 0.5

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Hidden state size for both alpha and beta GRUs |
| dropout | float | 0.5 | Dropout rate |
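To make the role of the alpha and beta attention concrete, the sketch below combines them into a context vector as described in the RETAIN paper: alpha is a softmax over time steps, beta is a tanh gate per feature, and the context is the alpha-weighted sum of gated embeddings. It takes the pre-activation scores as given (in the model they come from the two GRUs); the function names are illustrative and this is not OneEHR's code.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def retain_context(embeddings, alpha_scores, beta_pre):
    """Context c = sum_t alpha_t * (tanh(beta_pre_t) * v_t), element-wise in v.

    embeddings:   list of T step embeddings v_t, each of length D.
    alpha_scores: list of T scalar scores (step-level attention logits).
    beta_pre:     list of T vectors of length D (feature-level gate inputs).
    """
    alpha = softmax(alpha_scores)  # one weight per time step, sums to 1
    dim = len(embeddings[0])
    context = [0.0] * dim
    for v, a, b in zip(embeddings, alpha, beta_pre):
        for j in range(dim):
            context[j] += a * math.tanh(b[j]) * v[j]
    return context, alpha


embeddings = [[1.0, 0.0], [0.0, 1.0]]
context, alpha = retain_context(
    embeddings,
    alpha_scores=[0.0, 0.0],              # equal step-level attention
    beta_pre=[[10.0, 10.0], [10.0, 10.0]],  # gates close to 1
)
print(alpha, context)
```

Because the context is a weighted sum of the inputs, each (time step, feature) pair's contribution can be read off from alpha and beta, which is where the model's interpretability comes from.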

ConCare

Context-aware temporal attention with self-attention over feature embeddings. Supports a dedicated static branch.

Liantao Ma et al. ConCare: Personalized Clinical Feature Embedding via Capturing the Healthcare Context. AAAI 2020.

[[models]]
name = "concare"
[models.params]
hidden_dim = 128
num_heads = 4
dropout = 0.5

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Hidden state size |
| num_heads | int | 4 | Multi-head attention heads |
| dropout | float | 0.5 | Dropout rate |
| static_dim | int | auto | Auto-detected from static features |

GRASP

GRU with K-means clustering and graph convolutional layers. Supports a dedicated static branch.

Chaohe Zhang et al. GRASP: Generic Framework for Health Status Representation Learning Based on Incorporating Knowledge from Similar Patients. AAAI 2021.

[[models]]
name = "grasp"
[models.params]
hidden_dim = 128
cluster_num = 12
dropout = 0.5

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Hidden state size |
| cluster_num | int | 12 | Number of patient clusters |
| dropout | float | 0.5 | Dropout rate |
| static_dim | int | auto | Auto-detected from static features |

MCGRU

Multi-Channel GRU with per-feature GRU cells. Supports a dedicated static branch.

[[models]]
name = "mcgru"
[models.params]
hidden_dim = 32
feat_dim = 8
dropout = 0.0

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 32 | Per-channel GRU hidden size |
| feat_dim | int | 8 | Feature embedding dimension |
| dropout | float | 0.0 | Dropout rate |
| static_dim | int | auto | Auto-detected from static features |

DrAgent

Dual-agent reinforcement learning with action selection for clinical prediction. Supports a dedicated static branch.

Junyi Gao et al. Dr. Agent: Clinical predictive model via mimicked second opinions. JAMIA.

[[models]]
name = "dragent"
[models.params]
hidden_dim = 128
n_actions = 10
n_units = 64
dropout = 0.5
lamda = 0.5

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | GRU hidden state size |
| n_actions | int | 10 | Number of agent actions |
| n_units | int | 64 | Agent MLP hidden size |
| dropout | float | 0.5 | Dropout rate |
| lamda | float | 0.5 | Mixing weight for agent-selected vs current hidden state |
| static_dim | int | auto | Auto-detected from static features |

M3Care

Transformer-style temporal encoder with sinusoidal positional encodings and in-batch neighbour graph refinement.

Implementation inspired by the KDD 2022 reference and adapted to OneEHR's sequence contract.

[[models]]
name = "m3care"
[models.params]
hidden_dim = 128
num_heads = 4
dim_feedforward = 256
dropout = 0.1
num_layers = 1

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Sequence embedding size |
| num_heads | int | 4 | Attention head count |
| dim_feedforward | int | 256 | Feed-forward inner dimension |
| dropout | float | 0.1 | Dropout rate |
| num_layers | int | 1 | Number of encoder blocks |
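For readers unfamiliar with sinusoidal positional encodings, the sketch below builds the standard table from the original Transformer formulation, where even dimensions use sine and odd dimensions use cosine. It is assumed, not confirmed, that the M3Care encoder in OneEHR uses this exact variant; the function name is illustrative.

```python
import math


def sinusoidal_encoding(num_positions: int, dim: int) -> list[list[float]]:
    """Return a num_positions x dim table of sinusoidal position encodings.

    PE[pos][2k]   = sin(pos / 10000^(2k / dim))
    PE[pos][2k+1] = cos(pos / 10000^(2k / dim))
    """
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    table = []
    for pos in range(num_positions):
        row = []
        for i in range(0, dim, 2):  # i is the even dimension index 2k
            angle = pos / (10000 ** (i / dim))
            row.extend([math.sin(angle), math.cos(angle)])
        table.append(row)
    return table


encoding = sinusoidal_encoding(num_positions=3, dim=4)
for row in encoding:
    print([round(value, 4) for value in row])
```

The table is added to the step embeddings before attention, so the otherwise order-agnostic encoder can tell time steps apart.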

SAFARI

MCGRU-style grouped feature encoder with feature clustering, graph refinement, and attention pooling. Supports a dedicated static branch.

Implementation inspired by the TKDE 2022 reference and adapted to OneEHR's grouped feature schema.

[[models]]
name = "safari"
[models.params]
hidden_dim = 32
n_clu = 8
dropout = 0.5

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 32 | Group encoder and attention hidden size |
| n_clu | int | 8 | Number of feature clusters used for the graph update |
| dropout | float | 0.5 | Dropout rate |
| dim_list | list[int] | auto | Auto-derived group widths from feature_schema.json |
| static_dim | int | auto | Auto-detected from static features |

PAI

Learnable Prompt as Pseudo-Imputation on top of the GRU backbone. Entries marked as missing in obs_mask.parquet are replaced by a learned feature-wise prompt.

Plugin-style implementation of the KDD 2025 method; in OneEHR it is restricted to the GRU base model.

[[models]]
name = "pai"
[models.params]
hidden_dim = 128
num_layers = 1
dropout = 0.0
prompt_init = "median"

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | GRU hidden state size |
| num_layers | int | 1 | Number of stacked GRU layers |
| dropout | float | 0.0 | Dropout between GRU layers |
| prompt_init | str | "median" | Prompt initialisation: median, zero, or random |
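The prompt substitution described above can be sketched with plain lists. In the model the prompt is a trainable parameter updated by backpropagation; here it is a fixed list initialised the way prompt_init = "median" suggests, and the helper names and inline mask (standing in for obs_mask.parquet) are assumptions made for illustration.

```python
from statistics import median


def init_prompt_median(values, mask):
    """Feature-wise median over observed entries only (mask value 1)."""
    num_features = len(values[0])
    prompt = []
    for j in range(num_features):
        observed = [row[j] for row, m in zip(values, mask) if m[j]]
        # Fall back to 0.0 when a feature is never observed (assumed behaviour).
        prompt.append(median(observed) if observed else 0.0)
    return prompt


def apply_prompt(values, mask, prompt):
    """Keep observed entries and replace missing ones with the feature's prompt."""
    return [
        [x if m else p for x, m, p in zip(row, mask_row, prompt)]
        for row, mask_row in zip(values, mask)
    ]


# Three time steps, two features; feature 1 is observed only at step 1.
values = [[1.0, 0.0], [3.0, 8.0], [5.0, 0.0]]
mask = [[1, 0], [1, 1], [1, 0]]

prompt = init_prompt_median(values, mask)
filled = apply_prompt(values, mask, prompt)
print(prompt)
print(filled)
```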

Survival models

Survival models predict time-to-event outcomes with censoring support. Use with task.kind = "survival".

DeepSurv

Cox proportional hazards deep neural network (Katzman et al., 2018). Outputs a single log-risk score per patient. Trained with the Cox partial likelihood loss.

[[models]]
name = "deepsurv"
[models.params]
hidden_dim = 128
num_layers = 2
dropout = 0.1

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Hidden layer dimension |
| num_layers | int | 2 | Number of hidden layers |
| dropout | float | 0.1 | Dropout rate |
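The Cox partial likelihood loss used to train DeepSurv can be written out in a few lines. The sketch below is a plain-Python version of the textbook negative log partial likelihood with Breslow-style handling of ties, averaged over observed events; the averaging and the function name are choices made here and may not match OneEHR's training code.

```python
import math


def cox_partial_likelihood_loss(log_risk, time, event):
    """Negative log Cox partial likelihood, averaged over observed events.

    log_risk: model output per patient (higher means higher hazard).
    time:     observed time per patient (event or censoring time).
    event:    1 if the event was observed, 0 if the patient was censored.
    """
    n_events = sum(event)
    if n_events == 0:
        return 0.0  # no observed events, so nothing to rank
    total = 0.0
    for i in range(len(time)):
        if not event[i]:
            continue  # censored patients only appear in risk sets
        # Risk set: everyone still under observation at time[i].
        risk_scores = [log_risk[j] for j in range(len(time)) if time[j] >= time[i]]
        top = max(risk_scores)
        log_sum_exp = top + math.log(sum(math.exp(s - top) for s in risk_scores))
        total += log_risk[i] - log_sum_exp
    return -total / n_events


time = [1.0, 2.0, 3.0]
event = [1, 1, 1]
good = cox_partial_likelihood_loss([2.0, 1.0, 0.0], time, event)  # risk matches order
bad = cox_partial_likelihood_loss([0.0, 1.0, 2.0], time, event)   # risk reversed
print(good, bad)
```

The loss depends only on the ordering of risk scores within each risk set, which is why the model outputs a single relative log-risk rather than an absolute survival time.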

DeepHit

Discrete-time competing risks survival model (Lee et al., 2018). Outputs a probability mass function over time bins.

[[models]]
name = "deephit"
[models.params]
hidden_dim = 128
num_time_bins = 20
num_layers = 2
dropout = 0.1

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 128 | Hidden layer dimension |
| num_time_bins | int | 10 | Number of discrete time bins |
| num_layers | int | 2 | Number of hidden layers |
| dropout | float | 0.1 | Dropout rate |
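To show how a probability mass function over time bins relates to a survival curve, the sketch below converts raw scores into a PMF with a softmax and then accumulates it. It is simplified to a single risk, whereas DeepHit in general covers competing risks, and the function names are illustrative rather than part of OneEHR.

```python
import math


def pmf_from_logits(logits):
    """Softmax over time bins: P(event falls in bin k)."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def survival_from_pmf(pmf):
    """S(t_k) = 1 - sum of pmf[0..k], the probability of surviving past bin k."""
    survival, cumulative = [], 0.0
    for p in pmf:
        cumulative += p
        survival.append(max(0.0, 1.0 - cumulative))
    return survival


# Four time bins with equal scores give a uniform PMF.
pmf = pmf_from_logits([0.0, 0.0, 0.0, 0.0])
survival = survival_from_pmf(pmf)
print(pmf)
print(survival)
```

With num_time_bins bins the model has that many outputs per patient, so a larger value gives a finer survival curve at the cost of fewer events per bin.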