Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import cv2
|
| 3 |
import torch
|
| 4 |
import torchvision
|
| 5 |
import numpy as np
|
|
@@ -7,6 +6,7 @@ from PIL import Image
|
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import matplotlib.patches as patches
|
| 9 |
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
|
|
|
| 10 |
import os
|
| 11 |
import io
|
| 12 |
|
|
@@ -16,12 +16,16 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
|
| 16 |
CLASS_NAMES = {1: 'Nipple', 2: 'Lump'}
|
| 17 |
CLASS_COLORS = {1: 'white', 2: 'white'}
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def preprocess_image(image):
|
| 20 |
"""Load and preprocess image for Faster R-CNN."""
|
| 21 |
# Convert PIL Image to numpy array
|
| 22 |
image = np.array(image)
|
| 23 |
|
| 24 |
-
#
|
| 25 |
image = image.astype(np.float32) / 255.0 # Normalize to [0,1]
|
| 26 |
|
| 27 |
# Normalize using ImageNet mean and std
|
|
@@ -40,13 +44,41 @@ def load_model(checkpoint_path, device):
|
|
| 40 |
model.to(device).eval()
|
| 41 |
return model
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
def predict(image, score_thresh=0.5):
|
| 44 |
"""Run inference and return image with bounding boxes."""
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
# Load model
|
| 48 |
-
checkpoint_path = "lumps.pth" # This will be downloaded from the model repo
|
| 49 |
-
model = load_model(checkpoint_path, device)
|
| 50 |
|
| 51 |
# Preprocess image
|
| 52 |
image_tensor = preprocess_image(image)
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
import torch
|
| 3 |
import torchvision
|
| 4 |
import numpy as np
|
|
|
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
import matplotlib.patches as patches
|
| 8 |
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
import os
|
| 11 |
import io
|
| 12 |
|
|
|
|
| 16 |
CLASS_NAMES = {1: 'Nipple', 2: 'Lump'}
|
| 17 |
CLASS_COLORS = {1: 'white', 2: 'white'}
|
| 18 |
|
| 19 |
+
# Global variables for model (load once at startup)
|
| 20 |
+
MODEL = None
|
| 21 |
+
DEVICE = None
|
| 22 |
+
|
| 23 |
def preprocess_image(image):
|
| 24 |
"""Load and preprocess image for Faster R-CNN."""
|
| 25 |
# Convert PIL Image to numpy array
|
| 26 |
image = np.array(image)
|
| 27 |
|
| 28 |
+
# Already in RGB format from Gradio
|
| 29 |
image = image.astype(np.float32) / 255.0 # Normalize to [0,1]
|
| 30 |
|
| 31 |
# Normalize using ImageNet mean and std
|
|
|
|
| 44 |
model.to(device).eval()
|
| 45 |
return model
|
| 46 |
|
| 47 |
+
def initialize_model():
|
| 48 |
+
"""Initialize model at startup by downloading from HF Hub."""
|
| 49 |
+
global MODEL, DEVICE
|
| 50 |
+
|
| 51 |
+
if MODEL is None:
|
| 52 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 53 |
+
|
| 54 |
+
print(f"Downloading model from IFMedTech/Lumps repository...")
|
| 55 |
+
# Download model from private HuggingFace repository
|
| 56 |
+
# Token is automatically read from HF_TOKEN environment variable (Spaces secrets)
|
| 57 |
+
try:
|
| 58 |
+
checkpoint_path = hf_hub_download(
|
| 59 |
+
repo_id="IFMedTech/Lumps",
|
| 60 |
+
filename="lumps.pth",
|
| 61 |
+
repo_type="model",
|
| 62 |
+
token=os.environ.get("HF_TOKEN") # Use token from Spaces secrets
|
| 63 |
+
)
|
| 64 |
+
print(f"Model downloaded to: {checkpoint_path}")
|
| 65 |
+
print(f"Loading model on {DEVICE}...")
|
| 66 |
+
MODEL = load_model(checkpoint_path, DEVICE)
|
| 67 |
+
print(f"Model loaded successfully!")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print(f"Error loading model: {e}")
|
| 70 |
+
raise RuntimeError(
|
| 71 |
+
f"Failed to load model from IFMedTech/Lumps. "
|
| 72 |
+
f"Please ensure HF_TOKEN is set in Spaces secrets with read access to the private repository. "
|
| 73 |
+
f"Error: {e}"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
return MODEL, DEVICE
|
| 77 |
+
|
| 78 |
def predict(image, score_thresh=0.5):
|
| 79 |
"""Run inference and return image with bounding boxes."""
|
| 80 |
+
# Ensure model is loaded
|
| 81 |
+
model, device = initialize_model()
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
# Preprocess image
|
| 84 |
image_tensor = preprocess_image(image)
|