kerzel committed on
Commit
3b63b33
·
1 Parent(s): 50dff88

update model and code for handling TSM layers

Browse files
.gitignore CHANGED
@@ -1 +1,2 @@
1
- *.pyc*
 
 
1
+ *.pyc*
2
+ classified_damage_sites.*
app.py CHANGED
@@ -68,14 +68,44 @@ def damage_classification(SEM_image,image_threshold, model1_threshold, model2_th
68
  ##
69
  logging.debug('---------------: prepare model 1 :=====================')
70
  images_model1 = utils.prepare_classifier_input(SEM_image, all_centroids, window_size=model1_windowsize)
 
 
71
 
72
  logging.debug('---------------: run model 1 :=====================')
73
  #y1_pred = model1.predict(np.asarray(images_model1, float))
74
- y1_pred = model1(np.asarray(images_model1, float))
75
-
76
- logging.debug('---------------: model1 threshold :=====================')
77
- inclusions = y1_pred[:,0].reshape(len(y1_pred),1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  inclusions = np.where(inclusions > model1_threshold)
 
 
 
79
 
80
  logging.debug('---------------: model 1 update dict :=====================')
81
  for i in range(len(inclusions[0])):
@@ -105,7 +135,19 @@ def damage_classification(SEM_image,image_threshold, model1_threshold, model2_th
105
 
106
  logging.debug('---------------: run model 2 :=====================')
107
  #y2_pred = model2.predict(np.asarray(images_model2, float))
108
- y2_pred = model2(np.asarray(images_model2, float))
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  damage_index = np.asarray(y2_pred > model2_threshold).nonzero()
111
 
 
68
  ##
69
  logging.debug('---------------: prepare model 1 :=====================')
70
  images_model1 = utils.prepare_classifier_input(SEM_image, all_centroids, window_size=model1_windowsize)
71
+ from utils import debug_classification_input
72
+ debug_classification_input(images_model1)
73
 
74
  logging.debug('---------------: run model 1 :=====================')
75
  #y1_pred = model1.predict(np.asarray(images_model1, float))
76
+ #y1_pred = model1(np.asarray(images_model1, float))
77
+ # logging.debug('---------------: model1 threshold :=====================')
78
+ # inclusions = y1_pred[:,0].reshape(len(y1_pred),1)
79
+ # inclusions = np.where(inclusions > model1_threshold)
80
+
81
+ batch_model1 = np.array(images_model1, dtype=np.float32)
82
+ logging.info(f"Model 1 input shape: {batch_model1.shape}")
83
+ # Get predictions from model 1
84
+ y1_pred_raw = model1(batch_model1)
85
+ logging.info(f"Model 1 raw output type: {type(y1_pred_raw)}")
86
+
87
+ # Extract actual predictions from the model output
88
+ y1_pred = utils.extract_predictions_from_tfsm(y1_pred_raw)
89
+ logging.info(f"Model 1 predictions shape: {y1_pred.shape}")
90
+ logging.info(f"Model 1 predictions sample: {y1_pred[:3] if len(y1_pred) > 0 else 'Empty'}")
91
+
92
+ logging.info('---------------: model1 threshold :=====================')
93
+ # Handle predictions based on their shape
94
+ if len(y1_pred.shape) == 2:
95
+ # Predictions are 2D: (batch_size, num_classes)
96
+ inclusions = y1_pred[:, 0] # Get first column (inclusion probability)
97
+ elif len(y1_pred.shape) == 1:
98
+ # Predictions are 1D: (batch_size,)
99
+ inclusions = y1_pred
100
+ else:
101
+ raise ValueError(f"Unexpected prediction shape: {y1_pred.shape}")
102
+
103
+
104
+ logging.info('---------------: model1 threshold :=====================')
105
  inclusions = np.where(inclusions > model1_threshold)
106
+ logging.info('Inclusions found at indices:')
107
+ logging.info(inclusions)
108
+
109
 
110
  logging.debug('---------------: model 1 update dict :=====================')
111
  for i in range(len(inclusions[0])):
 
135
 
136
  logging.debug('---------------: run model 2 :=====================')
137
  #y2_pred = model2.predict(np.asarray(images_model2, float))
138
+ #y2_pred = model2(np.asarray(images_model2, float))
139
+ batch_model2 = np.array(images_model2, dtype=np.float32)
140
+ logging.info(f"Model 2 input shape: {batch_model2.shape}")
141
+ # Get predictions from model 2
142
+ y2_pred_raw = model2(batch_model2)
143
+ logging.info(f"Model 2 raw output type: {type(y2_pred_raw)}")
144
+ # Extract actual predictions from the model output
145
+ y2_pred = utils.extract_predictions_from_tfsm(y2_pred_raw)
146
+ logging.info(f"Model 2 predictions shape: {y2_pred.shape}")
147
+ logging.info(f"Model 2 predictions sample: {y2_pred[:3] if len(y2_pred) > 0 else 'Empty'}")
148
+ logging.info(y2_pred)
149
+
150
+ logging.debug('---------------: model2 threshold :=====================')
151
 
152
  damage_index = np.asarray(y2_pred > model2_threshold).nonzero()
153
 
rwthmaterials_dp800_network2_damage.tgz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c0b85a55ba6f970661ae81eb14063dd157538d310ca4aa3288e9f263f58b6749
3
- size 80796486
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9db591833c36177cb865e13a6332eaa15db98a3104860d9c3749214da6e06e0e
3
+ size 3943261
rwthmaterials_dp800_network2_damage/fingerprint.pb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:42b75267132948545aba844b272093f741daea371271c5adca0122be1bfb91cf
3
- size 59
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca1d16acc9246b42769f826cb6b63586f8d922ed0c58ce42f9ce64d8b2680725
3
+ size 56
rwthmaterials_dp800_network2_damage/saved_model.pb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8644d5da2015e6a2ce6fd4181ede406e98171e7fa170ab256411a639b1bbb014
3
- size 4482600
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c243d6b76e1fe71cbbe1bf7312f7148409b04d44dd796c0b1ea1219f1aaec06
3
+ size 620656
rwthmaterials_dp800_network2_damage/variables/variables.data-00000-of-00001 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f65b218b3718aac76d545ebef2d6a6ef9b4c2d7e22e7885bc3dd91836563f5ca
3
- size 87475223
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6faa8b229462f1cb0e7d57ccf3df66ab3935bff97d685a7154cd4ec381b5a34
3
+ size 3977928
rwthmaterials_dp800_network2_damage/variables/variables.index CHANGED
Binary files a/rwthmaterials_dp800_network2_damage/variables/variables.index and b/rwthmaterials_dp800_network2_damage/variables/variables.index differ
 
utils.py CHANGED
@@ -124,11 +124,110 @@ def show_boxes(image : np.ndarray, damage_sites : dict, box_size = [250,250],
124
 
125
  return data
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  ###
129
- ### cut out small images from panorama, append colour information
130
  ###
131
- def prepare_classifier_input(panorama, centroids: list, window_size=[250, 250]) -> list: # Removed np.ndarray type hint for panorama
 
 
 
 
 
 
 
 
 
132
  """
133
  Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.
134
 
@@ -154,61 +253,337 @@ def prepare_classifier_input(panorama, centroids: list, window_size=[250, 250])
154
  List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
155
  centroids that allow full window extraction within image bounds are used.
156
  """
157
- logging.info(f"prepare_classifier_input: Input panorama type: {type(panorama)}") # Added logging
158
 
159
- # --- MINIMAL FIX START ---
160
- # Convert PIL Image to NumPy array if necessary
161
- if isinstance(panorama, Image.Image):
162
- # Convert to grayscale NumPy array as your original code expects this structure for processing
163
- if panorama.mode == 'RGB':
164
- panorama_array = np.array(panorama.convert('L'))
165
- logging.info("prepare_classifier_input: Converted RGB PIL Image to grayscale NumPy array.")
166
- else:
167
- panorama_array = np.array(panorama)
168
- logging.info("prepare_classifier_input: Converted PIL Image to grayscale NumPy array.")
169
- elif isinstance(panorama, np.ndarray):
170
- # Ensure it's treated as a grayscale array for consistency with original logic
171
- if panorama.ndim == 3 and panorama.shape[2] in [3, 4]: # RGB or RGBA NumPy array
172
- panorama_array = np.mean(panorama, axis=2).astype(panorama.dtype) # Convert to grayscale
173
- logging.info("prepare_classifier_input: Converted multi-channel NumPy array to grayscale.")
174
  else:
175
- panorama_array = panorama # Assume it's already grayscale 2D or (H,W,1)
176
- logging.info("prepare_classifier_input: Panorama is already a suitable NumPy array.")
177
  else:
178
- logging.error("prepare_classifier_input: Unsupported panorama format received. Expected PIL Image or NumPy array.")
179
- raise ValueError("Unsupported panorama format for classifier input.")
180
 
181
- # Now, ensure panorama_array has a channel dimension if it's 2D for consistency
182
- if panorama_array.ndim == 2:
183
- panorama_array = np.expand_dims(panorama_array, axis=-1) # (H, W, 1)
184
- logging.info("prepare_classifier_input: Expanded 2D panorama to 3D (H,W,1).")
185
- # --- MINIMAL FIX END ---
186
-
187
- H, W, _ = panorama_array.shape # Use panorama_array here
188
  win_h, win_w = window_size
189
  images = []
 
 
 
190
 
191
- for (cy, cx) in centroids:
192
  # Ensure coordinates are integers
193
  cy, cx = int(round(cy)), int(round(cx))
194
-
195
- x1 = int(cx - win_w / 2)
196
- y1 = int(cy - win_h / 2)
197
- x2 = x1 + win_w
198
  y2 = y1 + win_h
 
 
199
 
200
- # Skip if patch would go out of bounds
201
- if x1 < 0 or y1 < 0 or x2 > W or y2 > H:
202
- logging.warning(f"prepare_classifier_input: Skipping centroid ({cy},{cx}) as patch is out of bounds.") # Added warning
 
 
 
203
  continue
204
 
205
- # Extract and normalize patch
206
- patch = panorama_array[y1:y2, x1:x2, 0].astype(np.float32) # Use panorama_array
207
- patch = patch * 2. / 255. - 1. # Keep your original normalization
208
-
209
- # Replicate grayscale channel to simulate RGB
210
- patch_color = np.repeat(patch[:, :, np.newaxis], 3, axis=2)
211
- images.append(patch_color)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  return images
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  return data
126
 
127
+ ##
128
+ ## orig
129
+ ##
130
+
131
+ # ###
132
+ # ### cut out small images from panorama, append colour information
133
+ # ###
134
+ # def prepare_classifier_input(panorama, centroids: list, window_size=[250, 250]) -> list: # Removed np.ndarray type hint for panorama
135
+ # """
136
+ # Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.
137
+
138
+ # Each extracted patch is resized to the specified window size and converted into a 3-channel (RGB-like)
139
+ # normalized image suitable for use with classification neural networks that expect color input.
140
+
141
+ # Parameters
142
+ # ----------
143
+ # panorama : PIL.Image.Image or np.ndarray
144
+ # Input SEM image. Should be a 2D array (H, W) or a 3D array (H, W, 1) representing grayscale data,
145
+ # or a PIL Image object.
146
+
147
+ # centroids : list of [int, int]
148
+ # List of (y, x) coordinates marking the centers of regions of interest. These are typically damage sites
149
+ # identified in preprocessing (e.g., clustering).
150
+
151
+ # window_size : list of int, optional
152
+ # Size [height, width] of each extracted image patch. Defaults to [250, 250].
153
+
154
+ # Returns
155
+ # -------
156
+ # list of np.ndarray
157
+ # List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
158
+ # centroids that allow full window extraction within image bounds are used.
159
+ # """
160
+ # logging.info(f"prepare_classifier_input: Input panorama type: {type(panorama)}") # Added logging
161
+
162
+ # # --- MINIMAL FIX START ---
163
+ # # Convert PIL Image to NumPy array if necessary
164
+ # if isinstance(panorama, Image.Image):
165
+ # # Convert to grayscale NumPy array as your original code expects this structure for processing
166
+ # if panorama.mode == 'RGB':
167
+ # panorama_array = np.array(panorama.convert('L'))
168
+ # logging.info("prepare_classifier_input: Converted RGB PIL Image to grayscale NumPy array.")
169
+ # else:
170
+ # panorama_array = np.array(panorama)
171
+ # logging.info("prepare_classifier_input: Converted PIL Image to grayscale NumPy array.")
172
+ # elif isinstance(panorama, np.ndarray):
173
+ # # Ensure it's treated as a grayscale array for consistency with original logic
174
+ # if panorama.ndim == 3 and panorama.shape[2] in [3, 4]: # RGB or RGBA NumPy array
175
+ # panorama_array = np.mean(panorama, axis=2).astype(panorama.dtype) # Convert to grayscale
176
+ # logging.info("prepare_classifier_input: Converted multi-channel NumPy array to grayscale.")
177
+ # else:
178
+ # panorama_array = panorama # Assume it's already grayscale 2D or (H,W,1)
179
+ # logging.info("prepare_classifier_input: Panorama is already a suitable NumPy array.")
180
+ # else:
181
+ # logging.error("prepare_classifier_input: Unsupported panorama format received. Expected PIL Image or NumPy array.")
182
+ # raise ValueError("Unsupported panorama format for classifier input.")
183
+
184
+ # # Now, ensure panorama_array has a channel dimension if it's 2D for consistency
185
+ # if panorama_array.ndim == 2:
186
+ # panorama_array = np.expand_dims(panorama_array, axis=-1) # (H, W, 1)
187
+ # logging.info("prepare_classifier_input: Expanded 2D panorama to 3D (H,W,1).")
188
+ # # --- MINIMAL FIX END ---
189
+
190
+ # H, W, _ = panorama_array.shape # Use panorama_array here
191
+ # win_h, win_w = window_size
192
+ # images = []
193
+
194
+ # for (cy, cx) in centroids:
195
+ # # Ensure coordinates are integers
196
+ # cy, cx = int(round(cy)), int(round(cx))
197
+
198
+ # x1 = int(cx - win_w / 2)
199
+ # y1 = int(cy - win_h / 2)
200
+ # x2 = x1 + win_w
201
+ # y2 = y1 + win_h
202
+
203
+ # # Skip if patch would go out of bounds
204
+ # if x1 < 0 or y1 < 0 or x2 > W or y2 > H:
205
+ # logging.warning(f"prepare_classifier_input: Skipping centroid ({cy},{cx}) as patch is out of bounds.") # Added warning
206
+ # continue
207
+
208
+ # # Extract and normalize patch
209
+ # patch = panorama_array[y1:y2, x1:x2, 0].astype(np.float32) # Use panorama_array
210
+ # patch = patch * 2. / 255. - 1. # Keep your original normalization
211
+
212
+ # # Replicate grayscale channel to simulate RGB
213
+ # patch_color = np.repeat(patch[:, :, np.newaxis], 3, axis=2)
214
+ # images.append(patch_color)
215
+
216
+ # return images
217
 
218
  ###
219
+ ### refactored
220
  ###
221
+ import numpy as np
222
+ from PIL import Image
223
+ import logging
224
+ from typing import List, Union, Tuple
225
+
226
+ def prepare_classifier_input(
227
+ panorama: Union[Image.Image, np.ndarray],
228
+ centroids: List[Tuple[int, int]],
229
+ window_size: List[int] = [250, 250]
230
+ ) -> List[np.ndarray]:
231
  """
232
  Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.
233
 
 
253
  List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
254
  centroids that allow full window extraction within image bounds are used.
255
  """
256
+ logging.info(f"prepare_classifier_input: Input panorama type: {type(panorama)}")
257
 
258
+ # Convert input to standardized NumPy array format
259
+ panorama_array = _convert_to_grayscale_array(panorama)
260
+
261
+ # Ensure we have the correct dimensions
262
+ if panorama_array.ndim == 2:
263
+ H, W = panorama_array.shape
264
+ logging.info("prepare_classifier_input: Working with 2D grayscale array.")
265
+ elif panorama_array.ndim == 3:
266
+ H, W, C = panorama_array.shape
267
+ if C == 1:
268
+ # Squeeze the single channel dimension for easier processing
269
+ panorama_array = panorama_array.squeeze(axis=2)
270
+ H, W = panorama_array.shape
271
+ logging.info("prepare_classifier_input: Squeezed single channel dimension.")
 
272
  else:
273
+ logging.error(f"prepare_classifier_input: Unexpected number of channels: {C}")
274
+ raise ValueError(f"Expected 1 channel, got {C}")
275
  else:
276
+ logging.error(f"prepare_classifier_input: Unexpected array dimensions: {panorama_array.ndim}")
277
+ raise ValueError(f"Expected 2D or 3D array, got {panorama_array.ndim}D")
278
 
 
 
 
 
 
 
 
279
  win_h, win_w = window_size
280
  images = []
281
+
282
+ logging.info(f"prepare_classifier_input: Image dimensions: {H}x{W}, Window size: {win_h}x{win_w}")
283
+ logging.info(f"prepare_classifier_input: Processing {len(centroids)} centroids")
284
 
285
+ for i, (cy, cx) in enumerate(centroids):
286
  # Ensure coordinates are integers
287
  cy, cx = int(round(cy)), int(round(cx))
288
+
289
+ # Calculate patch boundaries
290
+ half_h, half_w = win_h // 2, win_w // 2
291
+ y1 = cy - half_h
292
  y2 = y1 + win_h
293
+ x1 = cx - half_w
294
+ x2 = x1 + win_w
295
 
296
+ # Check bounds more explicitly
297
+ if y1 < 0 or x1 < 0 or y2 > H or x2 > W:
298
+ logging.warning(
299
+ f"prepare_classifier_input: Skipping centroid {i+1}/{len(centroids)} "
300
+ f"at ({cy},{cx}) - patch bounds ({y1}:{y2}, {x1}:{x2}) exceed image bounds (0:{H}, 0:{W})"
301
+ )
302
  continue
303
 
304
+ try:
305
+ # Extract patch with explicit bounds checking
306
+ patch = panorama_array[y1:y2, x1:x2].astype(np.float32)
307
+
308
+ # Verify patch dimensions
309
+ if patch.shape != (win_h, win_w):
310
+ logging.warning(
311
+ f"prepare_classifier_input: Patch {i+1} has unexpected shape {patch.shape}, "
312
+ f"expected ({win_h}, {win_w}). Skipping."
313
+ )
314
+ continue
315
+
316
+ # Normalize patch: [0, 255] -> [-1, 1]
317
+ patch_normalized = (patch * 2.0 / 255.0) - 1.0
318
+
319
+ # Convert to 3-channel RGB-like format
320
+ patch_rgb = np.stack([patch_normalized] * 3, axis=2)
321
+
322
+ images.append(patch_rgb)
323
+ logging.debug(f"prepare_classifier_input: Successfully processed centroid {i+1} at ({cy},{cx})")
324
+
325
+ except Exception as e:
326
+ logging.error(
327
+ f"prepare_classifier_input: Error processing centroid {i+1} at ({cy},{cx}): {e}"
328
+ )
329
+ continue
330
 
331
+ logging.info(f"prepare_classifier_input: Successfully extracted {len(images)} patches from {len(centroids)} centroids")
332
+
333
+ # Add diagnostic information about the output
334
+ if images:
335
+ sample_shape = images[0].shape
336
+ sample_dtype = images[0].dtype
337
+ sample_min = images[0].min()
338
+ sample_max = images[0].max()
339
+ logging.info(f"prepare_classifier_input: Output patches - Shape: {sample_shape}, Dtype: {sample_dtype}, Range: [{sample_min:.3f}, {sample_max:.3f}]")
340
+
341
+ # Verify all patches have consistent shapes
342
+ shapes = [img.shape for img in images]
343
+ if not all(shape == sample_shape for shape in shapes):
344
+ logging.warning("prepare_classifier_input: Inconsistent patch shapes detected!")
345
+ for i, shape in enumerate(shapes):
346
+ if shape != sample_shape:
347
+ logging.warning(f" Patch {i}: {shape} (expected {sample_shape})")
348
+ else:
349
+ logging.warning("prepare_classifier_input: No valid patches were extracted!")
350
+
351
  return images
352
 
353
+
354
+ def _convert_to_grayscale_array(panorama: Union[Image.Image, np.ndarray]) -> np.ndarray:
355
+ """
356
+ Helper function to convert various input formats to a standardized grayscale NumPy array.
357
+
358
+ Parameters
359
+ ----------
360
+ panorama : PIL.Image.Image or np.ndarray
361
+ Input image in various formats
362
+
363
+ Returns
364
+ -------
365
+ np.ndarray
366
+ Standardized grayscale array
367
+ """
368
+ if isinstance(panorama, Image.Image):
369
+ if panorama.mode in ['RGB', 'RGBA']:
370
+ # Convert to grayscale
371
+ panorama_array = np.array(panorama.convert('L'))
372
+ logging.info("_convert_to_grayscale_array: Converted RGB/RGBA PIL Image to grayscale.")
373
+ elif panorama.mode == 'L':
374
+ panorama_array = np.array(panorama)
375
+ logging.info("_convert_to_grayscale_array: Converted grayscale PIL Image to NumPy array.")
376
+ else:
377
+ # Handle other modes by converting to grayscale
378
+ panorama_array = np.array(panorama.convert('L'))
379
+ logging.info(f"_convert_to_grayscale_array: Converted PIL Image mode '{panorama.mode}' to grayscale.")
380
+
381
+ elif isinstance(panorama, np.ndarray):
382
+ if panorama.ndim == 2:
383
+ # Already grayscale
384
+ panorama_array = panorama.copy()
385
+ logging.info("_convert_to_grayscale_array: Using existing 2D grayscale array.")
386
+ elif panorama.ndim == 3:
387
+ if panorama.shape[2] in [3, 4]: # RGB or RGBA
388
+ # Convert to grayscale using luminance weights
389
+ if panorama.shape[2] == 3: # RGB
390
+ panorama_array = np.dot(panorama, [0.299, 0.587, 0.114]).astype(panorama.dtype)
391
+ else: # RGBA
392
+ panorama_array = np.dot(panorama[:, :, :3], [0.299, 0.587, 0.114]).astype(panorama.dtype)
393
+ logging.info("_convert_to_grayscale_array: Converted multi-channel NumPy array to grayscale using luminance weights.")
394
+ elif panorama.shape[2] == 1:
395
+ # Already single channel
396
+ panorama_array = panorama.copy()
397
+ logging.info("_convert_to_grayscale_array: Using existing single-channel array.")
398
+ else:
399
+ raise ValueError(f"Unsupported number of channels: {panorama.shape[2]}")
400
+ else:
401
+ raise ValueError(f"Unsupported array dimensions: {panorama.ndim}")
402
+ else:
403
+ raise ValueError(f"Unsupported panorama type: {type(panorama)}")
404
+
405
+ return panorama_array
406
+
407
+
408
+ ##
409
+ ## debug
410
+ ##
411
+ import numpy as np
412
+ import logging
413
+ from typing import List, Any
414
+
415
+ def debug_classification_input(patches: List[np.ndarray], model: Any = None) -> None:
416
+ """
417
+ Debug function to help identify issues in the classification pipeline.
418
+ Call this right before your classification step.
419
+
420
+ Parameters
421
+ ----------
422
+ patches : List[np.ndarray]
423
+ List of image patches from prepare_classifier_input
424
+ model : Any, optional
425
+ Your classification model (for additional debugging)
426
+ """
427
+ logging.info("=== CLASSIFICATION DEBUG INFO ===")
428
+ logging.info(f"Number of patches: {len(patches)}")
429
+
430
+ if not patches:
431
+ logging.error("No patches provided for classification!")
432
+ return
433
+
434
+ for i, patch in enumerate(patches):
435
+ logging.info(f"Patch {i}:")
436
+ logging.info(f" Shape: {patch.shape}")
437
+ logging.info(f" Dtype: {patch.dtype}")
438
+ logging.info(f" Range: [{patch.min():.3f}, {patch.max():.3f}]")
439
+ logging.info(f" Memory layout: {patch.flags}")
440
+
441
+ # Check for common issues
442
+ if np.isnan(patch).any():
443
+ logging.warning(f" Contains NaN values: {np.isnan(patch).sum()}")
444
+ if np.isinf(patch).any():
445
+ logging.warning(f" Contains infinite values: {np.isinf(patch).sum()}")
446
+
447
+ # Check if patch is contiguous (some models require this)
448
+ if not patch.flags.c_contiguous:
449
+ logging.warning(f" Patch {i} is not C-contiguous")
450
+
451
+ # Test conversion to common formats
452
+ try:
453
+ patches_array = np.array(patches)
454
+ logging.info(f"Stacked array shape: {patches_array.shape}")
455
+ logging.info(f"Stacked array dtype: {patches_array.dtype}")
456
+ except Exception as e:
457
+ logging.error(f"Failed to stack patches into array: {e}")
458
+
459
+ # Test batch preparation (common source of slice errors)
460
+ try:
461
+ if len(patches) > 0:
462
+ # Common preprocessing steps that might cause issues
463
+ test_batch = np.stack(patches, axis=0) # Shape: (batch_size, height, width, channels)
464
+ logging.info(f"Test batch shape: {test_batch.shape}")
465
+
466
+ # Test various indexing operations that might cause slice errors
467
+ test_slice = test_batch[0] # Should work
468
+ logging.info(f"Single item slice shape: {test_slice.shape}")
469
+
470
+ test_batch_slice = test_batch[:] # Should work
471
+ logging.info(f"Full batch slice shape: {test_batch_slice.shape}")
472
+
473
+ except Exception as e:
474
+ logging.error(f"Error during batch preparation testing: {e}")
475
+ logging.error(f"Error type: {type(e)}")
476
+ import traceback
477
+ logging.error(f"Traceback: {traceback.format_exc()}")
478
+
479
+ logging.info("=== END CLASSIFICATION DEBUG ===")
480
+
481
+
482
+ def safe_classify_patches(patches: List[np.ndarray], classify_func, **kwargs) -> Any:
483
+ """
484
+ Wrapper function to safely run classification with better error handling.
485
+
486
+ Parameters
487
+ ----------
488
+ patches : List[np.ndarray]
489
+ List of image patches
490
+ classify_func : callable
491
+ Your classification function
492
+ **kwargs
493
+ Additional arguments for classify_func
494
+
495
+ Returns
496
+ -------
497
+ Any
498
+ Classification results or None if error occurred
499
+ """
500
+ try:
501
+ logging.info("Starting safe classification...")
502
+
503
+ # Debug the input
504
+ debug_classification_input(patches)
505
+
506
+ # Ensure patches are properly formatted
507
+ if not patches:
508
+ logging.error("No patches to classify")
509
+ return None
510
+
511
+ # Make sure all patches are contiguous arrays
512
+ patches_clean = []
513
+ for i, patch in enumerate(patches):
514
+ if not patch.flags.c_contiguous:
515
+ patch_clean = np.ascontiguousarray(patch)
516
+ logging.info(f"Made patch {i} contiguous")
517
+ else:
518
+ patch_clean = patch
519
+ patches_clean.append(patch_clean)
520
+
521
+ # Call the actual classification function
522
+ logging.info("Calling classification function...")
523
+ result = classify_func(patches_clean, **kwargs)
524
+ logging.info("Classification completed successfully")
525
+
526
+ return result
527
+
528
+ except Exception as e:
529
+ logging.error(f"Error in safe_classify_patches: {e}")
530
+ logging.error(f"Error type: {type(e)}")
531
+ import traceback
532
+ logging.error(f"Full traceback: {traceback.format_exc()}")
533
+ return None
534
+
535
+
536
+ # Example usage function
537
+ def example_usage():
538
+ """
539
+ Example of how to use the debug functions in your pipeline
540
+ """
541
+ # Your existing code that calls prepare_classifier_input
542
+ # patches = prepare_classifier_input(panorama, centroids, window_size)
543
+
544
+ # Add debugging before classification
545
+ # debug_classification_input(patches)
546
+
547
+ # Use safe wrapper for classification
548
+ # results = safe_classify_patches(patches, your_classify_function, model=your_model)
549
+
550
+ pass
551
+
552
+
553
+ ########################################
554
+ ##
555
+ ##
556
+ ########################################
557
+ def extract_predictions_from_tfsm(model_output):
558
+ """
559
+ Helper function to extract predictions from TFSMLayer output.
560
+ TFSMLayer often returns a dictionary with multiple outputs.
561
+ """
562
+ logging.info(f"Model output type: {type(model_output)}")
563
+ logging.info(f"Model output keys: {model_output.keys() if isinstance(model_output, dict) else 'Not a dict'}")
564
+
565
+ if isinstance(model_output, dict):
566
+ # Try common output key names
567
+ possible_keys = ['output', 'predictions', 'dense', 'logits', 'probabilities']
568
+
569
+ # First, log all available keys
570
+ available_keys = list(model_output.keys())
571
+ logging.info(f"Available output keys: {available_keys}")
572
+
573
+ # Try to find the right output
574
+ for key in possible_keys:
575
+ if key in model_output:
576
+ logging.info(f"Using output key: {key}")
577
+ return model_output[key].numpy()
578
+
579
+ # If no standard key found, use the first available key
580
+ if available_keys:
581
+ first_key = available_keys[0]
582
+ logging.info(f"Using first available key: {first_key}")
583
+ return model_output[first_key].numpy()
584
+ else:
585
+ raise ValueError("No output keys found in model response")
586
+ else:
587
+ # If it's not a dictionary, assume it's already the tensor we need
588
+ logging.info("Model output is not a dictionary, using directly")
589
+ return model_output.numpy() if hasattr(model_output, 'numpy') else np.array(model_output)