IFMedTechdemo commited on
Commit
20ca3ee
·
verified ·
1 Parent(s): 3ce3ccf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -7
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- import cv2
3
  import torch
4
  import torchvision
5
  import numpy as np
@@ -7,6 +6,7 @@ from PIL import Image
7
  import matplotlib.pyplot as plt
8
  import matplotlib.patches as patches
9
  from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
 
10
  import os
11
  import io
12
 
@@ -16,12 +16,16 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
16
  CLASS_NAMES = {1: 'Nipple', 2: 'Lump'}
17
  CLASS_COLORS = {1: 'white', 2: 'white'}
18
 
 
 
 
 
19
  def preprocess_image(image):
20
  """Load and preprocess image for Faster R-CNN."""
21
  # Convert PIL Image to numpy array
22
  image = np.array(image)
23
 
24
- # Convert RGB to RGB (already in correct format from Gradio)
25
  image = image.astype(np.float32) / 255.0 # Normalize to [0,1]
26
 
27
  # Normalize using ImageNet mean and std
@@ -40,13 +44,41 @@ def load_model(checkpoint_path, device):
40
  model.to(device).eval()
41
  return model
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def predict(image, score_thresh=0.5):
44
  """Run inference and return image with bounding boxes."""
45
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
-
47
- # Load model
48
- checkpoint_path = "lumps.pth" # This will be downloaded from the model repo
49
- model = load_model(checkpoint_path, device)
50
 
51
  # Preprocess image
52
  image_tensor = preprocess_image(image)
 
1
  import gradio as gr
 
2
  import torch
3
  import torchvision
4
  import numpy as np
 
6
  import matplotlib.pyplot as plt
7
  import matplotlib.patches as patches
8
  from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
9
+ from huggingface_hub import hf_hub_download
10
  import os
11
  import io
12
 
 
16
  CLASS_NAMES = {1: 'Nipple', 2: 'Lump'}
17
  CLASS_COLORS = {1: 'white', 2: 'white'}
18
 
19
+ # Global variables for model (load once at startup)
20
+ MODEL = None
21
+ DEVICE = None
22
+
23
  def preprocess_image(image):
24
  """Load and preprocess image for Faster R-CNN."""
25
  # Convert PIL Image to numpy array
26
  image = np.array(image)
27
 
28
+ # Already in RGB format from Gradio
29
  image = image.astype(np.float32) / 255.0 # Normalize to [0,1]
30
 
31
  # Normalize using ImageNet mean and std
 
44
  model.to(device).eval()
45
  return model
46
 
47
+ def initialize_model():
48
+ """Initialize model at startup by downloading from HF Hub."""
49
+ global MODEL, DEVICE
50
+
51
+ if MODEL is None:
52
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+
54
+ print(f"Downloading model from IFMedTech/Lumps repository...")
55
+ # Download model from private HuggingFace repository
56
+ # Token is automatically read from HF_TOKEN environment variable (Spaces secrets)
57
+ try:
58
+ checkpoint_path = hf_hub_download(
59
+ repo_id="IFMedTech/Lumps",
60
+ filename="lumps.pth",
61
+ repo_type="model",
62
+ token=os.environ.get("HF_TOKEN") # Use token from Spaces secrets
63
+ )
64
+ print(f"Model downloaded to: {checkpoint_path}")
65
+ print(f"Loading model on {DEVICE}...")
66
+ MODEL = load_model(checkpoint_path, DEVICE)
67
+ print(f"Model loaded successfully!")
68
+ except Exception as e:
69
+ print(f"Error loading model: {e}")
70
+ raise RuntimeError(
71
+ f"Failed to load model from IFMedTech/Lumps. "
72
+ f"Please ensure HF_TOKEN is set in Spaces secrets with read access to the private repository. "
73
+ f"Error: {e}"
74
+ )
75
+
76
+ return MODEL, DEVICE
77
+
78
  def predict(image, score_thresh=0.5):
79
  """Run inference and return image with bounding boxes."""
80
+ # Ensure model is loaded
81
+ model, device = initialize_model()
 
 
 
82
 
83
  # Preprocess image
84
  image_tensor = preprocess_image(image)