ConvNeXt-CheXpert: CBAM-Augmented Thoracic Classifier
Model Description
This repository contains the weights for a ConvNeXt-Base architecture fine-tuned for multi-label classification of chest radiographs. The model is augmented with Convolutional Block Attention Modules (CBAM) to enhance feature localization and interpretability.
Try the live demo here: GRADCAM-Convnext-Chexpert-Attention Space
- Architecture: ConvNeXt-Base + CBAM
- Validation AUC: 0.81 (Iteration 6)
- Input Resolution: 384x384
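This card does not show how CBAM is wired into ConvNeXt-Base, so the following is only a generic sketch of a CBAM block as it is commonly described: channel attention followed by spatial attention. The class names, the reduction ratio of 16, the 7x7 spatial kernel, and the 1024-channel 12x12 feature map (what a ConvNeXt-Base final stage would produce for a 384x384 input) are assumptions for illustration, not the author's implementation.

```python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Reweights channels using average- and max-pooled descriptors."""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        # Shared bottleneck MLP (as 1x1 convs) applied to both descriptors.
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels, kernel_size=1, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg = self.mlp(torch.mean(x, dim=(2, 3), keepdim=True))
        mx = self.mlp(torch.amax(x, dim=(2, 3), keepdim=True))
        return x * torch.sigmoid(avg + mx)


class SpatialAttention(nn.Module):
    """Reweights spatial positions using channel-wise mean and max maps."""

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg = torch.mean(x, dim=1, keepdim=True)
        mx = torch.amax(x, dim=1, keepdim=True)
        attn = torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))
        return x * attn


class CBAM(nn.Module):
    """Channel attention followed by spatial attention; output shape equals input shape."""

    def __init__(self, channels: int, reduction: int = 16, kernel_size: int = 7):
        super().__init__()
        self.channel = ChannelAttention(channels, reduction)
        self.spatial = SpatialAttention(kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.spatial(self.channel(x))


# Hypothetical feature map: batch of 1, 1024 channels, 12x12 spatial grid.
features = torch.randn(1, 1024, 12, 12)
refined = CBAM(1024)(features)
print(refined.shape)  # torch.Size([1, 1024, 12, 12])
```

Because the block preserves the feature-map shape, it can in principle be placed after any convolutional stage. Where it sits in this model is not stated here.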
Detectable Classes
The model outputs probabilities for the following 14 classes:
- No Finding
- Enlarged Cardiomediastinum
- Cardiomegaly
- Lung Opacity
- Lung Lesion
- Edema
- Consolidation
- Pneumonia
- Atelectasis
- Pneumothorax
- Pleural Effusion
- Pleural Other
- Fracture
- Support Devices
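Since this is a multi-label model, each class gets its own sigmoid probability and several findings can be positive at once. The sketch below shows one way to turn raw logits into a list of positive findings. The 0.5 threshold, the helper names, and the example logits are assumptions made up for illustration; the card does not specify a decision threshold.

```python
import math

PATHOLOGIES = [
    "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly",
    "Lung Opacity", "Lung Lesion", "Edema", "Consolidation",
    "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion",
    "Pleural Other", "Fracture", "Support Devices",
]


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def positive_findings(logits, threshold=0.5):
    """Return (name, probability) pairs whose probability meets the threshold."""
    if len(logits) != len(PATHOLOGIES):
        raise ValueError(f"expected {len(PATHOLOGIES)} logits, got {len(logits)}")
    probs = [sigmoid(z) for z in logits]
    return [(name, p) for name, p in zip(PATHOLOGIES, probs) if p >= threshold]


# Made-up logits, one per class, in the order listed above.
example_logits = [-2.0, -1.0, 1.5, 0.3, -3.0, 0.8, -0.5,
                  -1.2, -0.1, -4.0, 2.2, -3.5, -2.8, 1.0]

findings = positive_findings(example_logits)
for name, p in findings:
    print(f"{name}: {p:.3f}")
```

In practice a per-class threshold tuned on validation data may work better than a single global value, since class prevalence differs widely across these labels.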
Interpretability (Grad-CAM)
Grad-CAM heatmaps show which image regions most influence each predicted class, and the CBAM attention modules are intended to sharpen this localization.
Figure 1: Multi-label detection visualized via Grad-CAM.
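For readers unfamiliar with Grad-CAM, the core arithmetic is small: each channel of a chosen feature map is weighted by the spatial mean of the class score's gradient with respect to that channel, the weighted maps are summed, and negative values are clipped. The dependency-free sketch below illustrates only that arithmetic on toy nested lists. The function name and the toy values are assumptions, and the card does not state which layer or Grad-CAM implementation the demo uses.

```python
def grad_cam(activations, gradients):
    """Compute a normalized Grad-CAM heatmap.

    activations: K feature maps, each an H x W nested list.
    gradients:   gradients of one class score w.r.t. those maps, same shape.
    """
    height = len(activations[0])
    width = len(activations[0][0])

    # Channel weights: spatial mean of each channel's gradient.
    weights = [sum(sum(row) for row in g) / (height * width) for g in gradients]

    # Weighted sum over channels, followed by ReLU.
    cam = [
        [
            max(0.0, sum(w * a[i][j] for w, a in zip(weights, activations)))
            for j in range(width)
        ]
        for i in range(height)
    ]

    # Scale to [0, 1] for display, unless the map is all zeros.
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam


# Toy example: 2 channels on a 2x2 grid.
toy_activations = [
    [[1.0, 0.0], [0.0, 2.0]],
    [[0.0, 3.0], [1.0, 0.0]],
]
toy_gradients = [
    [[0.5, 0.5], [0.5, 0.5]],      # channel 0 supports the class
    [[-1.0, -1.0], [-1.0, -1.0]],  # channel 1 counts against it
]

heatmap = grad_cam(toy_activations, toy_gradients)
print(heatmap)  # [[0.5, 0.0], [0.0, 1.0]]
```

With a real model, the activations and gradients would typically be captured with forward and backward hooks on a late convolutional stage, and the heatmap upsampled to the input resolution before overlaying it on the radiograph.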
Usage
Prerequisites
pip install torch torchvision timm pillow
Inference Code
This script loads the model and performs inference using the normalization statistics from training.
import torch
from PIL import Image
from torchvision import transforms
import timm
# 1. Configuration
# Ensure you have downloaded model.pth from this repository
MODEL_PATH = "model.pth"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pathologies = [
    "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly",
    "Lung Opacity", "Lung Lesion", "Edema", "Consolidation",
    "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion",
    "Pleural Other", "Fracture", "Support Devices"
]
# 2. Load Model
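# Note: this builds a plain timm ConvNeXt-Base. If the checkpoint also contains
# CBAM parameters, load_state_dict will raise an error about unexpected keys, and
# the CBAM-augmented model definition is needed instead.
model = timm.create_model('convnext_base', pretrained=False, num_classes=14)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))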
model.to(DEVICE)
model.eval()
# 3. Preprocessing
# Note: These specific mean/std values are critical for accuracy
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5029414296150208] * 3,
        std=[0.2892409563064575] * 3
    )
])
# 4. Run Prediction
image = Image.open('chest_xray.jpg')
input_tensor = transform(image).unsqueeze(0).to(DEVICE)
with torch.no_grad():
    logits = model(input_tensor)
    # Sigmoid (not softmax): each class is scored independently
    probs = torch.sigmoid(logits)[0]
print("Predictions:")
for name, score in zip(pathologies, probs):
    print(f"{name}: {score.item():.3f}")
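To make the preprocessing note concrete, the sketch below reproduces the per-pixel arithmetic that `ToTensor` followed by `Normalize` applies: an 8-bit intensity is scaled to [0, 1], then shifted by the mean and divided by the standard deviation. The helper name is an assumption for illustration; the constants are the ones used in the script above.

```python
MEAN = 0.5029414296150208
STD = 0.2892409563064575


def normalize_pixel(value_8bit: int) -> float:
    """Map an 8-bit grayscale intensity to the model's input scale."""
    if not 0 <= value_8bit <= 255:
        raise ValueError("expected an 8-bit intensity in [0, 255]")
    scaled = value_8bit / 255.0    # what ToTensor does for 8-bit images
    return (scaled - MEAN) / STD   # what Normalize does per channel


for value in (0, 128, 255):
    print(value, round(normalize_pixel(value), 3))
```

Black pixels land near -1.74 and white pixels near 1.72. Feeding the model inputs normalized with different statistics (for example the ImageNet defaults) shifts this range and can degrade predictions.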
Citation
```bibtex
@misc{convnext_chexpert_2025,
  author = {Calendar, S.},
  title = {ConvNeXt-CheXpert: CBAM-Augmented Thoracic Classifier},
  year = {2025},
  publisher = {Hugging Face},
  url = {https://huggingface.co/calender/GRADCAM-Convnext-Chexpert-Attention}
}
```