ECG Arrhythmia Classification — CNN+MLP vs CNN+KAN
This repository provides two trained PyTorch models for ECG beat classification on the MIT-BIH Arrhythmia Database:
- CNN+MLP (baseline)
- CNN+KAN (proposed Kolmogorov–Arnold Network head)
Both models classify ECG beats into 5 AAMI-style classes and are evaluated on a held-out test set.
Models and Files
- CNN+MLP weights: checkpoints/cnn_mlp.pth
- CNN+KAN weights: checkpoints/cnn_kan.pth
Task
Time-series classification of single ECG beats.
Class Labels
| Label | Meaning |
|---|---|
| N | Normal / Non-ectopic beat |
| S | Supraventricular ectopic beat |
| V | Ventricular ectopic beat |
| F | Fusion beat |
| Q | Unknown / Unclassifiable beat |
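The card does not say how MIT-BIH annotation symbols were grouped into these five classes. The sketch below assumes the widely used AAMI EC57 grouping. `AAMI_MAP` and `to_aami` are hypothetical names, not code from this repository.

```python
# Assumption: the common AAMI EC57 grouping of MIT-BIH beat annotation symbols.
# The repository's exact mapping is not stated in this card.
AAMI_MAP = {
    # N: normal, bundle branch block, and escape beats
    "N": "N", "L": "N", "R": "N", "e": "N", "j": "N",
    # S: supraventricular ectopic beats
    "A": "S", "a": "S", "J": "S", "S": "S",
    # V: ventricular ectopic beats
    "V": "V", "E": "V",
    # F: fusion of ventricular and normal beat
    "F": "F",
    # Q: paced, fusion of paced and normal, and unclassifiable beats
    "/": "Q", "f": "Q", "Q": "Q",
}


def to_aami(symbol):
    """Return the AAMI class for a MIT-BIH beat symbol, or None for non-beat annotations."""
    return AAMI_MAP.get(symbol)


print([to_aami(s) for s in ["N", "L", "A", "V", "F", "/", "+"]])
# ['N', 'N', 'S', 'V', 'F', 'Q', None]
```

Non-beat annotations such as rhythm changes (`+`) map to `None` and would simply be skipped during beat extraction.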
Data
- Dataset: MIT-BIH Arrhythmia Database (PhysioNet)
- Lead: MLII (lead 0)
- Sampling rate: 360 Hz
- Beat window: 256 samples centered on each beat annotation (128 before + 128 after), about 0.71 s at 360 Hz
- Normalization: per-beat z-score
- Split: stratified train/val/test = 70% / 15% / 15%
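The preprocessing steps above can be sketched in plain Python as follows. This is a minimal sketch under stated assumptions: beats are cut as `[r - 128, r + 128)` around each annotation index `r`, beats whose window would cross a record boundary are dropped, the z-score uses the population standard deviation plus a small `eps`, and the split is a per-class shuffle with a fixed seed. `extract_beats`, `zscore`, and `stratified_split` are hypothetical helpers; reading the records themselves (for example with the `wfdb` package) is left out.

```python
import math
import random
import statistics

HALF_WINDOW = 128  # assumed: samples before and after the annotation (256 total)


def zscore(window, eps=1e-8):
    """Per-beat z-score: subtract the window mean, divide by its population std."""
    mu = statistics.fmean(window)
    sigma = statistics.pstdev(window)
    return [(v - mu) / (sigma + eps) for v in window]


def extract_beats(signal, r_peaks, labels, half=HALF_WINDOW):
    """Cut [r - half, r + half) around each annotation, skipping beats near the edges."""
    beats, kept_labels = [], []
    for r, y in zip(r_peaks, labels):
        start, end = r - half, r + half
        if start < 0 or end > len(signal):
            continue  # window would run past the record boundary
        beats.append(zscore(signal[start:end]))
        kept_labels.append(y)
    return beats, kept_labels


def stratified_split(labels, fractions=(0.70, 0.15, 0.15), seed=0):
    """Return (train, val, test) index lists that preserve the class mix."""
    rng = random.Random(seed)
    by_class = {}
    for i, y in enumerate(labels):
        by_class.setdefault(y, []).append(i)
    train, val, test = [], [], []
    for idx in by_class.values():
        rng.shuffle(idx)
        n_train = round(fractions[0] * len(idx))
        n_val = round(fractions[1] * len(idx))
        train += idx[:n_train]
        val += idx[n_train:n_train + n_val]
        test += idx[n_train + n_val:]  # remainder goes to test
    return train, val, test


# Synthetic toy record (not MIT-BIH data): the beat at r=50 is dropped (50 - 128 < 0)
sig = [math.sin(2 * math.pi * t / 90) for t in range(1000)]
beats, ys = extract_beats(sig, [50, 400, 700], ["N", "V", "N"])
print(len(beats), ys)  # 2 ['V', 'N']
```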
Training
- Optimizer: Adam
- Learning rate: 1e-3
- Weight decay: 1e-4
- Batch size: 128
- Epochs: up to 50, early stopping (patience 10)
- Loss: class-weighted cross-entropy
- Gradient clipping: max norm 1.0 (for KAN stability)
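The card lists class-weighted cross-entropy and early stopping without giving the weighting formula or the monitored quantity. The sketch below assumes inverse-frequency ("balanced") weights and early stopping on validation loss; `class_weights` and `EarlyStopping` are hypothetical helpers, not the repository's code.

```python
from collections import Counter


def class_weights(labels, classes=("N", "S", "V", "F", "Q")):
    """Assumed balanced weights: w_c = n_samples / (n_classes * count_c); 0.0 if a class is absent."""
    counts = Counter(labels)
    n, k = len(labels), len(classes)
    return [n / (k * counts[c]) if counts[c] else 0.0 for c in classes]


class EarlyStopping:
    """Signal a stop once validation loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


print(class_weights(["N"] * 8 + ["S"] * 2))  # [0.25, 1.0, 0.0, 0.0, 0.0]
```

In a PyTorch loop, the weights would typically be passed as `torch.nn.CrossEntropyLoss(weight=torch.tensor(w, dtype=torch.float32))`, the optimizer built as `torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)`, and clipping applied with `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)` between `loss.backward()` and `optimizer.step()`, for at most 50 epochs with `EarlyStopping(patience=10)`.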
Results (Test Set)
Overall Metrics
| Model | Accuracy | Macro F1 | Weighted F1 | Macro AUC | Params | Inference (ms/sample) |
|---|---|---|---|---|---|---|
| CNN+MLP | 0.9800 | 0.9019 | 0.9809 | 0.9965 | 175,973 | 0.3664 |
| CNN+KAN | 0.9476 | 0.8167 | 0.9540 | 0.9924 | 285,280 | 0.6308 |
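Accuracy and weighted F1 sit close together because MIT-BIH is dominated by N beats, while macro F1 gives each class equal weight and therefore exposes errors on the rarer classes. The pure-Python sketch below shows how the three label-based metrics relate. It is illustrative only: the card does not say which tool produced the reported numbers (scikit-learn's `f1_score(..., average="macro")`, `average="weighted"`, and `roc_auc_score(..., multi_class="ovr", average="macro")` are common equivalents), and `f1_scores` and `summary` are hypothetical names.

```python
def f1_scores(y_true, y_pred, classes):
    """Per-class F1 = 2*TP / (2*TP + FP + FN), defined as 0.0 when the class never appears."""
    scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return scores


def summary(y_true, y_pred, classes=("N", "S", "V", "F", "Q")):
    per_class = f1_scores(y_true, y_pred, classes)
    support = [sum(t == c for t in y_true) for c in classes]
    n = len(y_true)
    return {
        "accuracy": sum(t == p for t, p in zip(y_true, y_pred)) / n,
        # Macro F1: every class counts equally, so rare classes weigh as much as N
        "macro_f1": sum(per_class) / len(classes),
        # Weighted F1: each class weighted by its true-label count, so N dominates
        "weighted_f1": sum(f * s for f, s in zip(per_class, support)) / n,
    }


print(summary(["N", "N", "N", "S"], ["N", "N", "S", "S"], classes=("N", "S")))
```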
Per-Class F1
| Class | CNN+MLP | CNN+KAN |
|---|---|---|
| N | 0.9889 | 0.9690 |
| S | 0.8111 | 0.5968 |
| V | 0.9577 | 0.9070 |
| F | 0.7589 | 0.6270 |
| Q | 0.9930 | 0.9836 |
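As a consistency check, the macro F1 in the overall table should equal the unweighted mean of the per-class F1 scores. The snippet below recomputes it from the table values above and shows which classes account for most of the gap between the two heads.

```python
CLASSES = ["N", "S", "V", "F", "Q"]
# Per-class F1 copied from the table above
PER_CLASS_F1 = {
    "CNN+MLP": [0.9889, 0.8111, 0.9577, 0.7589, 0.9930],
    "CNN+KAN": [0.9690, 0.5968, 0.9070, 0.6270, 0.9836],
}
REPORTED_MACRO_F1 = {"CNN+MLP": 0.9019, "CNN+KAN": 0.8167}


def macro_f1(scores):
    """Macro F1 is the unweighted mean of the per-class F1 scores."""
    return sum(scores) / len(scores)


for model, scores in PER_CLASS_F1.items():
    print(f"{model}: recomputed {macro_f1(scores):.4f}, reported {REPORTED_MACRO_F1[model]:.4f}")
# CNN+MLP: recomputed 0.9019, reported 0.9019
# CNN+KAN: recomputed 0.8167, reported 0.8167

# Per-class F1 drop from CNN+MLP to CNN+KAN; S (0.2143) and F (0.1319) drop the most
drops = {c: m - k for c, m, k in zip(CLASSES, PER_CLASS_F1["CNN+MLP"], PER_CLASS_F1["CNN+KAN"])}
largest = max(drops, key=drops.get)
print(largest, round(drops[largest], 4))  # S 0.2143
```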
How to Use (PyTorch)
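```python
import torch

from src.models.cnn import ECGCNN
from src.models.cnn_kan import ECGCNNWithKAN

# Choose a model and its matching checkpoint
model = ECGCNN(num_classes=5)  # or ECGCNNWithKAN(num_classes=5)
CHECKPOINT_PATH = "checkpoints/cnn_mlp.pth"  # or "checkpoints/cnn_kan.pth"

# Load weights on CPU and switch to inference mode
ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(ckpt["model_state"])
model.eval()

# Example input: [batch, 1, 256], one z-scored 256-sample beat per row
x = torch.randn(1, 1, 256)
with torch.no_grad():
    proba = torch.softmax(model(x), dim=1)
pred = proba.argmax(dim=1).item()
```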
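The usage code returns a class index. To turn raw logits into one of the labels in the class table, something like the following can be used. It assumes output index `i` corresponds to the table order N, S, V, F, Q, which the card does not confirm; `CLASS_NAMES`, `softmax`, and `decode` are hypothetical names.

```python
import math

# Assumption: output index i maps to CLASS_NAMES[i]; verify against the training code.
CLASS_NAMES = ["N", "S", "V", "F", "Q"]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode(logits):
    """Return (label, probability) for the highest-scoring class."""
    proba = softmax(logits)
    idx = max(range(len(proba)), key=proba.__getitem__)
    return CLASS_NAMES[idx], proba[idx]


label, p = decode([0.1, 2.5, -1.0, 0.3, -0.5])
print(label)  # S
```

With a model loaded as in the usage code, `decode(model(x)[0].detach().tolist())` would label the first beat in the batch.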