update model and code for handling TSM layers
Browse files- .gitignore +2 -1
- app.py +47 -5
- rwthmaterials_dp800_network2_damage.tgz +2 -2
- rwthmaterials_dp800_network2_damage/fingerprint.pb +2 -2
- rwthmaterials_dp800_network2_damage/saved_model.pb +2 -2
- rwthmaterials_dp800_network2_damage/variables/variables.data-00000-of-00001 +2 -2
- rwthmaterials_dp800_network2_damage/variables/variables.index +0 -0
- utils.py +419 -44
.gitignore
CHANGED
|
@@ -1 +1,2 @@
|
|
| 1 |
-
*.pyc*
|
|
|
|
|
|
| 1 |
+
*.pyc*
|
| 2 |
+
classified_damage_sites.*
|
app.py
CHANGED
|
@@ -68,14 +68,44 @@ def damage_classification(SEM_image,image_threshold, model1_threshold, model2_th
|
|
| 68 |
##
|
| 69 |
logging.debug('---------------: prepare model 1 :=====================')
|
| 70 |
images_model1 = utils.prepare_classifier_input(SEM_image, all_centroids, window_size=model1_windowsize)
|
|
|
|
|
|
|
| 71 |
|
| 72 |
logging.debug('---------------: run model 1 :=====================')
|
| 73 |
#y1_pred = model1.predict(np.asarray(images_model1, float))
|
| 74 |
-
y1_pred = model1(np.asarray(images_model1, float))
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
inclusions =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
inclusions = np.where(inclusions > model1_threshold)
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
logging.debug('---------------: model 1 update dict :=====================')
|
| 81 |
for i in range(len(inclusions[0])):
|
|
@@ -105,7 +135,19 @@ def damage_classification(SEM_image,image_threshold, model1_threshold, model2_th
|
|
| 105 |
|
| 106 |
logging.debug('---------------: run model 2 :=====================')
|
| 107 |
#y2_pred = model2.predict(np.asarray(images_model2, float))
|
| 108 |
-
y2_pred = model2(np.asarray(images_model2, float))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
damage_index = np.asarray(y2_pred > model2_threshold).nonzero()
|
| 111 |
|
|
|
|
| 68 |
##
|
| 69 |
logging.debug('---------------: prepare model 1 :=====================')
|
| 70 |
images_model1 = utils.prepare_classifier_input(SEM_image, all_centroids, window_size=model1_windowsize)
|
| 71 |
+
from utils import debug_classification_input
|
| 72 |
+
debug_classification_input(images_model1)
|
| 73 |
|
| 74 |
logging.debug('---------------: run model 1 :=====================')
|
| 75 |
#y1_pred = model1.predict(np.asarray(images_model1, float))
|
| 76 |
+
#y1_pred = model1(np.asarray(images_model1, float))
|
| 77 |
+
# logging.debug('---------------: model1 threshold :=====================')
|
| 78 |
+
# inclusions = y1_pred[:,0].reshape(len(y1_pred),1)
|
| 79 |
+
# inclusions = np.where(inclusions > model1_threshold)
|
| 80 |
+
|
| 81 |
+
batch_model1 = np.array(images_model1, dtype=np.float32)
|
| 82 |
+
logging.info(f"Model 1 input shape: {batch_model1.shape}")
|
| 83 |
+
# Get predictions from model 1
|
| 84 |
+
y1_pred_raw = model1(batch_model1)
|
| 85 |
+
logging.info(f"Model 1 raw output type: {type(y1_pred_raw)}")
|
| 86 |
+
|
| 87 |
+
# Extract actual predictions from the model output
|
| 88 |
+
y1_pred = utils.extract_predictions_from_tfsm(y1_pred_raw)
|
| 89 |
+
logging.info(f"Model 1 predictions shape: {y1_pred.shape}")
|
| 90 |
+
logging.info(f"Model 1 predictions sample: {y1_pred[:3] if len(y1_pred) > 0 else 'Empty'}")
|
| 91 |
+
|
| 92 |
+
logging.info('---------------: model1 threshold :=====================')
|
| 93 |
+
# Handle predictions based on their shape
|
| 94 |
+
if len(y1_pred.shape) == 2:
|
| 95 |
+
# Predictions are 2D: (batch_size, num_classes)
|
| 96 |
+
inclusions = y1_pred[:, 0] # Get first column (inclusion probability)
|
| 97 |
+
elif len(y1_pred.shape) == 1:
|
| 98 |
+
# Predictions are 1D: (batch_size,)
|
| 99 |
+
inclusions = y1_pred
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f"Unexpected prediction shape: {y1_pred.shape}")
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
logging.info('---------------: model1 threshold :=====================')
|
| 105 |
inclusions = np.where(inclusions > model1_threshold)
|
| 106 |
+
logging.info('Inclusions found at indices:')
|
| 107 |
+
logging.info(inclusions)
|
| 108 |
+
|
| 109 |
|
| 110 |
logging.debug('---------------: model 1 update dict :=====================')
|
| 111 |
for i in range(len(inclusions[0])):
|
|
|
|
| 135 |
|
| 136 |
logging.debug('---------------: run model 2 :=====================')
|
| 137 |
#y2_pred = model2.predict(np.asarray(images_model2, float))
|
| 138 |
+
#y2_pred = model2(np.asarray(images_model2, float))
|
| 139 |
+
batch_model2 = np.array(images_model2, dtype=np.float32)
|
| 140 |
+
logging.info(f"Model 2 input shape: {batch_model2.shape}")
|
| 141 |
+
# Get predictions from model 2
|
| 142 |
+
y2_pred_raw = model2(batch_model2)
|
| 143 |
+
logging.info(f"Model 2 raw output type: {type(y2_pred_raw)}")
|
| 144 |
+
# Extract actual predictions from the model output
|
| 145 |
+
y2_pred = utils.extract_predictions_from_tfsm(y2_pred_raw)
|
| 146 |
+
logging.info(f"Model 2 predictions shape: {y2_pred.shape}")
|
| 147 |
+
logging.info(f"Model 2 predictions sample: {y2_pred[:3] if len(y2_pred) > 0 else 'Empty'}")
|
| 148 |
+
logging.info(y2_pred)
|
| 149 |
+
|
| 150 |
+
logging.debug('---------------: model2 threshold :=====================')
|
| 151 |
|
| 152 |
damage_index = np.asarray(y2_pred > model2_threshold).nonzero()
|
| 153 |
|
rwthmaterials_dp800_network2_damage.tgz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9db591833c36177cb865e13a6332eaa15db98a3104860d9c3749214da6e06e0e
|
| 3 |
+
size 3943261
|
rwthmaterials_dp800_network2_damage/fingerprint.pb
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ca1d16acc9246b42769f826cb6b63586f8d922ed0c58ce42f9ce64d8b2680725
|
| 3 |
+
size 56
|
rwthmaterials_dp800_network2_damage/saved_model.pb
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6c243d6b76e1fe71cbbe1bf7312f7148409b04d44dd796c0b1ea1219f1aaec06
|
| 3 |
+
size 620656
|
rwthmaterials_dp800_network2_damage/variables/variables.data-00000-of-00001
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a6faa8b229462f1cb0e7d57ccf3df66ab3935bff97d685a7154cd4ec381b5a34
|
| 3 |
+
size 3977928
|
rwthmaterials_dp800_network2_damage/variables/variables.index
CHANGED
|
Binary files a/rwthmaterials_dp800_network2_damage/variables/variables.index and b/rwthmaterials_dp800_network2_damage/variables/variables.index differ
|
|
|
utils.py
CHANGED
|
@@ -124,11 +124,110 @@ def show_boxes(image : np.ndarray, damage_sites : dict, box_size = [250,250],
|
|
| 124 |
|
| 125 |
return data
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
###
|
| 129 |
-
###
|
| 130 |
###
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
"""
|
| 133 |
Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.
|
| 134 |
|
|
@@ -154,61 +253,337 @@ def prepare_classifier_input(panorama, centroids: list, window_size=[250, 250])
|
|
| 154 |
List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
|
| 155 |
centroids that allow full window extraction within image bounds are used.
|
| 156 |
"""
|
| 157 |
-
logging.info(f"prepare_classifier_input: Input panorama type: {type(panorama)}")
|
| 158 |
|
| 159 |
-
#
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
logging.info("prepare_classifier_input: Converted multi-channel NumPy array to grayscale.")
|
| 174 |
else:
|
| 175 |
-
|
| 176 |
-
|
| 177 |
else:
|
| 178 |
-
logging.error("prepare_classifier_input:
|
| 179 |
-
raise ValueError("
|
| 180 |
|
| 181 |
-
# Now, ensure panorama_array has a channel dimension if it's 2D for consistency
|
| 182 |
-
if panorama_array.ndim == 2:
|
| 183 |
-
panorama_array = np.expand_dims(panorama_array, axis=-1) # (H, W, 1)
|
| 184 |
-
logging.info("prepare_classifier_input: Expanded 2D panorama to 3D (H,W,1).")
|
| 185 |
-
# --- MINIMAL FIX END ---
|
| 186 |
-
|
| 187 |
-
H, W, _ = panorama_array.shape # Use panorama_array here
|
| 188 |
win_h, win_w = window_size
|
| 189 |
images = []
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
-
for (cy, cx) in centroids:
|
| 192 |
# Ensure coordinates are integers
|
| 193 |
cy, cx = int(round(cy)), int(round(cx))
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
y2 = y1 + win_h
|
|
|
|
|
|
|
| 199 |
|
| 200 |
-
#
|
| 201 |
-
if
|
| 202 |
-
logging.warning(
|
|
|
|
|
|
|
|
|
|
| 203 |
continue
|
| 204 |
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
return images
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
return data
|
| 126 |
|
| 127 |
+
##
|
| 128 |
+
## orig
|
| 129 |
+
##
|
| 130 |
+
|
| 131 |
+
# ###
|
| 132 |
+
# ### cut out small images from panorama, append colour information
|
| 133 |
+
# ###
|
| 134 |
+
# def prepare_classifier_input(panorama, centroids: list, window_size=[250, 250]) -> list: # Removed np.ndarray type hint for panorama
|
| 135 |
+
# """
|
| 136 |
+
# Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.
|
| 137 |
+
|
| 138 |
+
# Each extracted patch is resized to the specified window size and converted into a 3-channel (RGB-like)
|
| 139 |
+
# normalized image suitable for use with classification neural networks that expect color input.
|
| 140 |
+
|
| 141 |
+
# Parameters
|
| 142 |
+
# ----------
|
| 143 |
+
# panorama : PIL.Image.Image or np.ndarray
|
| 144 |
+
# Input SEM image. Should be a 2D array (H, W) or a 3D array (H, W, 1) representing grayscale data,
|
| 145 |
+
# or a PIL Image object.
|
| 146 |
+
|
| 147 |
+
# centroids : list of [int, int]
|
| 148 |
+
# List of (y, x) coordinates marking the centers of regions of interest. These are typically damage sites
|
| 149 |
+
# identified in preprocessing (e.g., clustering).
|
| 150 |
+
|
| 151 |
+
# window_size : list of int, optional
|
| 152 |
+
# Size [height, width] of each extracted image patch. Defaults to [250, 250].
|
| 153 |
+
|
| 154 |
+
# Returns
|
| 155 |
+
# -------
|
| 156 |
+
# list of np.ndarray
|
| 157 |
+
# List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
|
| 158 |
+
# centroids that allow full window extraction within image bounds are used.
|
| 159 |
+
# """
|
| 160 |
+
# logging.info(f"prepare_classifier_input: Input panorama type: {type(panorama)}") # Added logging
|
| 161 |
+
|
| 162 |
+
# # --- MINIMAL FIX START ---
|
| 163 |
+
# # Convert PIL Image to NumPy array if necessary
|
| 164 |
+
# if isinstance(panorama, Image.Image):
|
| 165 |
+
# # Convert to grayscale NumPy array as your original code expects this structure for processing
|
| 166 |
+
# if panorama.mode == 'RGB':
|
| 167 |
+
# panorama_array = np.array(panorama.convert('L'))
|
| 168 |
+
# logging.info("prepare_classifier_input: Converted RGB PIL Image to grayscale NumPy array.")
|
| 169 |
+
# else:
|
| 170 |
+
# panorama_array = np.array(panorama)
|
| 171 |
+
# logging.info("prepare_classifier_input: Converted PIL Image to grayscale NumPy array.")
|
| 172 |
+
# elif isinstance(panorama, np.ndarray):
|
| 173 |
+
# # Ensure it's treated as a grayscale array for consistency with original logic
|
| 174 |
+
# if panorama.ndim == 3 and panorama.shape[2] in [3, 4]: # RGB or RGBA NumPy array
|
| 175 |
+
# panorama_array = np.mean(panorama, axis=2).astype(panorama.dtype) # Convert to grayscale
|
| 176 |
+
# logging.info("prepare_classifier_input: Converted multi-channel NumPy array to grayscale.")
|
| 177 |
+
# else:
|
| 178 |
+
# panorama_array = panorama # Assume it's already grayscale 2D or (H,W,1)
|
| 179 |
+
# logging.info("prepare_classifier_input: Panorama is already a suitable NumPy array.")
|
| 180 |
+
# else:
|
| 181 |
+
# logging.error("prepare_classifier_input: Unsupported panorama format received. Expected PIL Image or NumPy array.")
|
| 182 |
+
# raise ValueError("Unsupported panorama format for classifier input.")
|
| 183 |
+
|
| 184 |
+
# # Now, ensure panorama_array has a channel dimension if it's 2D for consistency
|
| 185 |
+
# if panorama_array.ndim == 2:
|
| 186 |
+
# panorama_array = np.expand_dims(panorama_array, axis=-1) # (H, W, 1)
|
| 187 |
+
# logging.info("prepare_classifier_input: Expanded 2D panorama to 3D (H,W,1).")
|
| 188 |
+
# # --- MINIMAL FIX END ---
|
| 189 |
+
|
| 190 |
+
# H, W, _ = panorama_array.shape # Use panorama_array here
|
| 191 |
+
# win_h, win_w = window_size
|
| 192 |
+
# images = []
|
| 193 |
+
|
| 194 |
+
# for (cy, cx) in centroids:
|
| 195 |
+
# # Ensure coordinates are integers
|
| 196 |
+
# cy, cx = int(round(cy)), int(round(cx))
|
| 197 |
+
|
| 198 |
+
# x1 = int(cx - win_w / 2)
|
| 199 |
+
# y1 = int(cy - win_h / 2)
|
| 200 |
+
# x2 = x1 + win_w
|
| 201 |
+
# y2 = y1 + win_h
|
| 202 |
+
|
| 203 |
+
# # Skip if patch would go out of bounds
|
| 204 |
+
# if x1 < 0 or y1 < 0 or x2 > W or y2 > H:
|
| 205 |
+
# logging.warning(f"prepare_classifier_input: Skipping centroid ({cy},{cx}) as patch is out of bounds.") # Added warning
|
| 206 |
+
# continue
|
| 207 |
+
|
| 208 |
+
# # Extract and normalize patch
|
| 209 |
+
# patch = panorama_array[y1:y2, x1:x2, 0].astype(np.float32) # Use panorama_array
|
| 210 |
+
# patch = patch * 2. / 255. - 1. # Keep your original normalization
|
| 211 |
+
|
| 212 |
+
# # Replicate grayscale channel to simulate RGB
|
| 213 |
+
# patch_color = np.repeat(patch[:, :, np.newaxis], 3, axis=2)
|
| 214 |
+
# images.append(patch_color)
|
| 215 |
+
|
| 216 |
+
# return images
|
| 217 |
|
| 218 |
###
|
| 219 |
+
### refactored
|
| 220 |
###
|
| 221 |
+
import numpy as np
|
| 222 |
+
from PIL import Image
|
| 223 |
+
import logging
|
| 224 |
+
from typing import List, Union, Tuple
|
| 225 |
+
|
| 226 |
+
def prepare_classifier_input(
|
| 227 |
+
panorama: Union[Image.Image, np.ndarray],
|
| 228 |
+
centroids: List[Tuple[int, int]],
|
| 229 |
+
window_size: List[int] = [250, 250]
|
| 230 |
+
) -> List[np.ndarray]:
|
| 231 |
"""
|
| 232 |
Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.
|
| 233 |
|
|
|
|
| 253 |
List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
|
| 254 |
centroids that allow full window extraction within image bounds are used.
|
| 255 |
"""
|
| 256 |
+
logging.info(f"prepare_classifier_input: Input panorama type: {type(panorama)}")
|
| 257 |
|
| 258 |
+
# Convert input to standardized NumPy array format
|
| 259 |
+
panorama_array = _convert_to_grayscale_array(panorama)
|
| 260 |
+
|
| 261 |
+
# Ensure we have the correct dimensions
|
| 262 |
+
if panorama_array.ndim == 2:
|
| 263 |
+
H, W = panorama_array.shape
|
| 264 |
+
logging.info("prepare_classifier_input: Working with 2D grayscale array.")
|
| 265 |
+
elif panorama_array.ndim == 3:
|
| 266 |
+
H, W, C = panorama_array.shape
|
| 267 |
+
if C == 1:
|
| 268 |
+
# Squeeze the single channel dimension for easier processing
|
| 269 |
+
panorama_array = panorama_array.squeeze(axis=2)
|
| 270 |
+
H, W = panorama_array.shape
|
| 271 |
+
logging.info("prepare_classifier_input: Squeezed single channel dimension.")
|
|
|
|
| 272 |
else:
|
| 273 |
+
logging.error(f"prepare_classifier_input: Unexpected number of channels: {C}")
|
| 274 |
+
raise ValueError(f"Expected 1 channel, got {C}")
|
| 275 |
else:
|
| 276 |
+
logging.error(f"prepare_classifier_input: Unexpected array dimensions: {panorama_array.ndim}")
|
| 277 |
+
raise ValueError(f"Expected 2D or 3D array, got {panorama_array.ndim}D")
|
| 278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
win_h, win_w = window_size
|
| 280 |
images = []
|
| 281 |
+
|
| 282 |
+
logging.info(f"prepare_classifier_input: Image dimensions: {H}x{W}, Window size: {win_h}x{win_w}")
|
| 283 |
+
logging.info(f"prepare_classifier_input: Processing {len(centroids)} centroids")
|
| 284 |
|
| 285 |
+
for i, (cy, cx) in enumerate(centroids):
|
| 286 |
# Ensure coordinates are integers
|
| 287 |
cy, cx = int(round(cy)), int(round(cx))
|
| 288 |
+
|
| 289 |
+
# Calculate patch boundaries
|
| 290 |
+
half_h, half_w = win_h // 2, win_w // 2
|
| 291 |
+
y1 = cy - half_h
|
| 292 |
y2 = y1 + win_h
|
| 293 |
+
x1 = cx - half_w
|
| 294 |
+
x2 = x1 + win_w
|
| 295 |
|
| 296 |
+
# Check bounds more explicitly
|
| 297 |
+
if y1 < 0 or x1 < 0 or y2 > H or x2 > W:
|
| 298 |
+
logging.warning(
|
| 299 |
+
f"prepare_classifier_input: Skipping centroid {i+1}/{len(centroids)} "
|
| 300 |
+
f"at ({cy},{cx}) - patch bounds ({y1}:{y2}, {x1}:{x2}) exceed image bounds (0:{H}, 0:{W})"
|
| 301 |
+
)
|
| 302 |
continue
|
| 303 |
|
| 304 |
+
try:
|
| 305 |
+
# Extract patch with explicit bounds checking
|
| 306 |
+
patch = panorama_array[y1:y2, x1:x2].astype(np.float32)
|
| 307 |
+
|
| 308 |
+
# Verify patch dimensions
|
| 309 |
+
if patch.shape != (win_h, win_w):
|
| 310 |
+
logging.warning(
|
| 311 |
+
f"prepare_classifier_input: Patch {i+1} has unexpected shape {patch.shape}, "
|
| 312 |
+
f"expected ({win_h}, {win_w}). Skipping."
|
| 313 |
+
)
|
| 314 |
+
continue
|
| 315 |
+
|
| 316 |
+
# Normalize patch: [0, 255] -> [-1, 1]
|
| 317 |
+
patch_normalized = (patch * 2.0 / 255.0) - 1.0
|
| 318 |
+
|
| 319 |
+
# Convert to 3-channel RGB-like format
|
| 320 |
+
patch_rgb = np.stack([patch_normalized] * 3, axis=2)
|
| 321 |
+
|
| 322 |
+
images.append(patch_rgb)
|
| 323 |
+
logging.debug(f"prepare_classifier_input: Successfully processed centroid {i+1} at ({cy},{cx})")
|
| 324 |
+
|
| 325 |
+
except Exception as e:
|
| 326 |
+
logging.error(
|
| 327 |
+
f"prepare_classifier_input: Error processing centroid {i+1} at ({cy},{cx}): {e}"
|
| 328 |
+
)
|
| 329 |
+
continue
|
| 330 |
|
| 331 |
+
logging.info(f"prepare_classifier_input: Successfully extracted {len(images)} patches from {len(centroids)} centroids")
|
| 332 |
+
|
| 333 |
+
# Add diagnostic information about the output
|
| 334 |
+
if images:
|
| 335 |
+
sample_shape = images[0].shape
|
| 336 |
+
sample_dtype = images[0].dtype
|
| 337 |
+
sample_min = images[0].min()
|
| 338 |
+
sample_max = images[0].max()
|
| 339 |
+
logging.info(f"prepare_classifier_input: Output patches - Shape: {sample_shape}, Dtype: {sample_dtype}, Range: [{sample_min:.3f}, {sample_max:.3f}]")
|
| 340 |
+
|
| 341 |
+
# Verify all patches have consistent shapes
|
| 342 |
+
shapes = [img.shape for img in images]
|
| 343 |
+
if not all(shape == sample_shape for shape in shapes):
|
| 344 |
+
logging.warning("prepare_classifier_input: Inconsistent patch shapes detected!")
|
| 345 |
+
for i, shape in enumerate(shapes):
|
| 346 |
+
if shape != sample_shape:
|
| 347 |
+
logging.warning(f" Patch {i}: {shape} (expected {sample_shape})")
|
| 348 |
+
else:
|
| 349 |
+
logging.warning("prepare_classifier_input: No valid patches were extracted!")
|
| 350 |
+
|
| 351 |
return images
|
| 352 |
|
| 353 |
+
|
| 354 |
+
def _convert_to_grayscale_array(panorama: Union[Image.Image, np.ndarray]) -> np.ndarray:
|
| 355 |
+
"""
|
| 356 |
+
Helper function to convert various input formats to a standardized grayscale NumPy array.
|
| 357 |
+
|
| 358 |
+
Parameters
|
| 359 |
+
----------
|
| 360 |
+
panorama : PIL.Image.Image or np.ndarray
|
| 361 |
+
Input image in various formats
|
| 362 |
+
|
| 363 |
+
Returns
|
| 364 |
+
-------
|
| 365 |
+
np.ndarray
|
| 366 |
+
Standardized grayscale array
|
| 367 |
+
"""
|
| 368 |
+
if isinstance(panorama, Image.Image):
|
| 369 |
+
if panorama.mode in ['RGB', 'RGBA']:
|
| 370 |
+
# Convert to grayscale
|
| 371 |
+
panorama_array = np.array(panorama.convert('L'))
|
| 372 |
+
logging.info("_convert_to_grayscale_array: Converted RGB/RGBA PIL Image to grayscale.")
|
| 373 |
+
elif panorama.mode == 'L':
|
| 374 |
+
panorama_array = np.array(panorama)
|
| 375 |
+
logging.info("_convert_to_grayscale_array: Converted grayscale PIL Image to NumPy array.")
|
| 376 |
+
else:
|
| 377 |
+
# Handle other modes by converting to grayscale
|
| 378 |
+
panorama_array = np.array(panorama.convert('L'))
|
| 379 |
+
logging.info(f"_convert_to_grayscale_array: Converted PIL Image mode '{panorama.mode}' to grayscale.")
|
| 380 |
+
|
| 381 |
+
elif isinstance(panorama, np.ndarray):
|
| 382 |
+
if panorama.ndim == 2:
|
| 383 |
+
# Already grayscale
|
| 384 |
+
panorama_array = panorama.copy()
|
| 385 |
+
logging.info("_convert_to_grayscale_array: Using existing 2D grayscale array.")
|
| 386 |
+
elif panorama.ndim == 3:
|
| 387 |
+
if panorama.shape[2] in [3, 4]: # RGB or RGBA
|
| 388 |
+
# Convert to grayscale using luminance weights
|
| 389 |
+
if panorama.shape[2] == 3: # RGB
|
| 390 |
+
panorama_array = np.dot(panorama, [0.299, 0.587, 0.114]).astype(panorama.dtype)
|
| 391 |
+
else: # RGBA
|
| 392 |
+
panorama_array = np.dot(panorama[:, :, :3], [0.299, 0.587, 0.114]).astype(panorama.dtype)
|
| 393 |
+
logging.info("_convert_to_grayscale_array: Converted multi-channel NumPy array to grayscale using luminance weights.")
|
| 394 |
+
elif panorama.shape[2] == 1:
|
| 395 |
+
# Already single channel
|
| 396 |
+
panorama_array = panorama.copy()
|
| 397 |
+
logging.info("_convert_to_grayscale_array: Using existing single-channel array.")
|
| 398 |
+
else:
|
| 399 |
+
raise ValueError(f"Unsupported number of channels: {panorama.shape[2]}")
|
| 400 |
+
else:
|
| 401 |
+
raise ValueError(f"Unsupported array dimensions: {panorama.ndim}")
|
| 402 |
+
else:
|
| 403 |
+
raise ValueError(f"Unsupported panorama type: {type(panorama)}")
|
| 404 |
+
|
| 405 |
+
return panorama_array
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
##
|
| 409 |
+
## debug
|
| 410 |
+
##
|
| 411 |
+
import numpy as np
|
| 412 |
+
import logging
|
| 413 |
+
from typing import List, Any
|
| 414 |
+
|
| 415 |
+
def debug_classification_input(patches: List[np.ndarray], model: Any = None) -> None:
|
| 416 |
+
"""
|
| 417 |
+
Debug function to help identify issues in the classification pipeline.
|
| 418 |
+
Call this right before your classification step.
|
| 419 |
+
|
| 420 |
+
Parameters
|
| 421 |
+
----------
|
| 422 |
+
patches : List[np.ndarray]
|
| 423 |
+
List of image patches from prepare_classifier_input
|
| 424 |
+
model : Any, optional
|
| 425 |
+
Your classification model (for additional debugging)
|
| 426 |
+
"""
|
| 427 |
+
logging.info("=== CLASSIFICATION DEBUG INFO ===")
|
| 428 |
+
logging.info(f"Number of patches: {len(patches)}")
|
| 429 |
+
|
| 430 |
+
if not patches:
|
| 431 |
+
logging.error("No patches provided for classification!")
|
| 432 |
+
return
|
| 433 |
+
|
| 434 |
+
for i, patch in enumerate(patches):
|
| 435 |
+
logging.info(f"Patch {i}:")
|
| 436 |
+
logging.info(f" Shape: {patch.shape}")
|
| 437 |
+
logging.info(f" Dtype: {patch.dtype}")
|
| 438 |
+
logging.info(f" Range: [{patch.min():.3f}, {patch.max():.3f}]")
|
| 439 |
+
logging.info(f" Memory layout: {patch.flags}")
|
| 440 |
+
|
| 441 |
+
# Check for common issues
|
| 442 |
+
if np.isnan(patch).any():
|
| 443 |
+
logging.warning(f" Contains NaN values: {np.isnan(patch).sum()}")
|
| 444 |
+
if np.isinf(patch).any():
|
| 445 |
+
logging.warning(f" Contains infinite values: {np.isinf(patch).sum()}")
|
| 446 |
+
|
| 447 |
+
# Check if patch is contiguous (some models require this)
|
| 448 |
+
if not patch.flags.c_contiguous:
|
| 449 |
+
logging.warning(f" Patch {i} is not C-contiguous")
|
| 450 |
+
|
| 451 |
+
# Test conversion to common formats
|
| 452 |
+
try:
|
| 453 |
+
patches_array = np.array(patches)
|
| 454 |
+
logging.info(f"Stacked array shape: {patches_array.shape}")
|
| 455 |
+
logging.info(f"Stacked array dtype: {patches_array.dtype}")
|
| 456 |
+
except Exception as e:
|
| 457 |
+
logging.error(f"Failed to stack patches into array: {e}")
|
| 458 |
+
|
| 459 |
+
# Test batch preparation (common source of slice errors)
|
| 460 |
+
try:
|
| 461 |
+
if len(patches) > 0:
|
| 462 |
+
# Common preprocessing steps that might cause issues
|
| 463 |
+
test_batch = np.stack(patches, axis=0) # Shape: (batch_size, height, width, channels)
|
| 464 |
+
logging.info(f"Test batch shape: {test_batch.shape}")
|
| 465 |
+
|
| 466 |
+
# Test various indexing operations that might cause slice errors
|
| 467 |
+
test_slice = test_batch[0] # Should work
|
| 468 |
+
logging.info(f"Single item slice shape: {test_slice.shape}")
|
| 469 |
+
|
| 470 |
+
test_batch_slice = test_batch[:] # Should work
|
| 471 |
+
logging.info(f"Full batch slice shape: {test_batch_slice.shape}")
|
| 472 |
+
|
| 473 |
+
except Exception as e:
|
| 474 |
+
logging.error(f"Error during batch preparation testing: {e}")
|
| 475 |
+
logging.error(f"Error type: {type(e)}")
|
| 476 |
+
import traceback
|
| 477 |
+
logging.error(f"Traceback: {traceback.format_exc()}")
|
| 478 |
+
|
| 479 |
+
logging.info("=== END CLASSIFICATION DEBUG ===")
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def safe_classify_patches(patches: List[np.ndarray], classify_func, **kwargs) -> Any:
|
| 483 |
+
"""
|
| 484 |
+
Wrapper function to safely run classification with better error handling.
|
| 485 |
+
|
| 486 |
+
Parameters
|
| 487 |
+
----------
|
| 488 |
+
patches : List[np.ndarray]
|
| 489 |
+
List of image patches
|
| 490 |
+
classify_func : callable
|
| 491 |
+
Your classification function
|
| 492 |
+
**kwargs
|
| 493 |
+
Additional arguments for classify_func
|
| 494 |
+
|
| 495 |
+
Returns
|
| 496 |
+
-------
|
| 497 |
+
Any
|
| 498 |
+
Classification results or None if error occurred
|
| 499 |
+
"""
|
| 500 |
+
try:
|
| 501 |
+
logging.info("Starting safe classification...")
|
| 502 |
+
|
| 503 |
+
# Debug the input
|
| 504 |
+
debug_classification_input(patches)
|
| 505 |
+
|
| 506 |
+
# Ensure patches are properly formatted
|
| 507 |
+
if not patches:
|
| 508 |
+
logging.error("No patches to classify")
|
| 509 |
+
return None
|
| 510 |
+
|
| 511 |
+
# Make sure all patches are contiguous arrays
|
| 512 |
+
patches_clean = []
|
| 513 |
+
for i, patch in enumerate(patches):
|
| 514 |
+
if not patch.flags.c_contiguous:
|
| 515 |
+
patch_clean = np.ascontiguousarray(patch)
|
| 516 |
+
logging.info(f"Made patch {i} contiguous")
|
| 517 |
+
else:
|
| 518 |
+
patch_clean = patch
|
| 519 |
+
patches_clean.append(patch_clean)
|
| 520 |
+
|
| 521 |
+
# Call the actual classification function
|
| 522 |
+
logging.info("Calling classification function...")
|
| 523 |
+
result = classify_func(patches_clean, **kwargs)
|
| 524 |
+
logging.info("Classification completed successfully")
|
| 525 |
+
|
| 526 |
+
return result
|
| 527 |
+
|
| 528 |
+
except Exception as e:
|
| 529 |
+
logging.error(f"Error in safe_classify_patches: {e}")
|
| 530 |
+
logging.error(f"Error type: {type(e)}")
|
| 531 |
+
import traceback
|
| 532 |
+
logging.error(f"Full traceback: {traceback.format_exc()}")
|
| 533 |
+
return None
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
# Example usage function
|
| 537 |
+
def example_usage():
|
| 538 |
+
"""
|
| 539 |
+
Example of how to use the debug functions in your pipeline
|
| 540 |
+
"""
|
| 541 |
+
# Your existing code that calls prepare_classifier_input
|
| 542 |
+
# patches = prepare_classifier_input(panorama, centroids, window_size)
|
| 543 |
+
|
| 544 |
+
# Add debugging before classification
|
| 545 |
+
# debug_classification_input(patches)
|
| 546 |
+
|
| 547 |
+
# Use safe wrapper for classification
|
| 548 |
+
# results = safe_classify_patches(patches, your_classify_function, model=your_model)
|
| 549 |
+
|
| 550 |
+
pass
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
########################################
|
| 554 |
+
##
|
| 555 |
+
##
|
| 556 |
+
########################################
|
| 557 |
+
def extract_predictions_from_tfsm(model_output):
|
| 558 |
+
"""
|
| 559 |
+
Helper function to extract predictions from TFSMLayer output.
|
| 560 |
+
TFSMLayer often returns a dictionary with multiple outputs.
|
| 561 |
+
"""
|
| 562 |
+
logging.info(f"Model output type: {type(model_output)}")
|
| 563 |
+
logging.info(f"Model output keys: {model_output.keys() if isinstance(model_output, dict) else 'Not a dict'}")
|
| 564 |
+
|
| 565 |
+
if isinstance(model_output, dict):
|
| 566 |
+
# Try common output key names
|
| 567 |
+
possible_keys = ['output', 'predictions', 'dense', 'logits', 'probabilities']
|
| 568 |
+
|
| 569 |
+
# First, log all available keys
|
| 570 |
+
available_keys = list(model_output.keys())
|
| 571 |
+
logging.info(f"Available output keys: {available_keys}")
|
| 572 |
+
|
| 573 |
+
# Try to find the right output
|
| 574 |
+
for key in possible_keys:
|
| 575 |
+
if key in model_output:
|
| 576 |
+
logging.info(f"Using output key: {key}")
|
| 577 |
+
return model_output[key].numpy()
|
| 578 |
+
|
| 579 |
+
# If no standard key found, use the first available key
|
| 580 |
+
if available_keys:
|
| 581 |
+
first_key = available_keys[0]
|
| 582 |
+
logging.info(f"Using first available key: {first_key}")
|
| 583 |
+
return model_output[first_key].numpy()
|
| 584 |
+
else:
|
| 585 |
+
raise ValueError("No output keys found in model response")
|
| 586 |
+
else:
|
| 587 |
+
# If it's not a dictionary, assume it's already the tensor we need
|
| 588 |
+
logging.info("Model output is not a dictionary, using directly")
|
| 589 |
+
return model_output.numpy() if hasattr(model_output, 'numpy') else np.array(model_output)
|