aknapitsch user commited on
Commit
3d74194
·
1 Parent(s): 6e553ac

fallback for from_pretrained model loading

Browse files
Files changed (1) hide show
  1. app.py +62 -51
app.py CHANGED
@@ -30,6 +30,7 @@ from hf_utils.css_and_html import (
30
  get_header_html,
31
  )
32
  from hf_utils.visual_util import predictions_to_glb
 
33
  from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
34
  from mapanything.utils.image import load_images, rgb
35
 
@@ -149,68 +150,78 @@ def run_model(target_dir, model_placeholder, apply_mask=True, mask_edges=True):
149
  high_level_config["path"], overrides=high_level_config["config_overrides"]
150
  )
151
 
152
- print("Loading MapAnything model...")
153
- # Create model from local configuration instead of using from_pretrained
154
- from mapanything.models import init_model
155
-
156
- model = init_model(
157
- model_str=cfg.model.model_str,
158
- model_config=cfg.model.model_config,
159
- torch_hub_force_reload=high_level_config.get(
160
- "torch_hub_force_reload", False
161
- ),
162
- )
163
-
164
- # Load the pretrained weights from HuggingFace Hub
165
  try:
166
- from huggingface_hub import hf_hub_download, list_repo_files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
- # First, let's see what files are available in the repository
169
  try:
170
- repo_files = list_repo_files(
171
- repo_id=high_level_config["hf_model_name"], token=load_hf_token()
172
- )
173
- print(f"Available files in repository: {repo_files}")
174
-
175
- checkpoint_filename = "model.safetensors"
176
 
177
- # Download the model weights
178
- checkpoint_path = hf_hub_download(
179
- repo_id=high_level_config["hf_model_name"],
180
- filename=checkpoint_filename,
181
- token=load_hf_token(),
182
- )
183
 
184
- # Load the weights
185
- print("start loading checkpoint")
186
- if checkpoint_filename.endswith(".safetensors"):
187
- from safetensors.torch import load_file
188
 
189
- checkpoint = load_file(checkpoint_path)
190
- else:
191
- checkpoint = torch.load(
192
- checkpoint_path, map_location="cpu", weights_only=True
 
193
  )
194
 
195
- print("start loading state_dict")
196
- if "model" in checkpoint:
197
- model.load_state_dict(checkpoint["model"])
198
- elif "state_dict" in checkpoint:
199
- model.load_state_dict(checkpoint["state_dict"])
200
- else:
201
- model.load_state_dict(checkpoint)
202
 
203
- print(
204
- f"Successfully loaded pretrained weights from HuggingFace Hub ({checkpoint_filename})"
205
- )
 
 
 
 
 
 
 
 
 
 
206
 
207
- except Exception as inner_e:
208
- print(f"Error listing repository files or loading weights: {inner_e}")
209
- raise inner_e
210
 
211
- except Exception as e:
212
- print(f"Warning: Could not load pretrained weights: {e}")
213
- print("Proceeding with randomly initialized model...")
 
 
 
 
214
 
215
  model = model.to(device)
216
 
 
30
  get_header_html,
31
  )
32
  from hf_utils.visual_util import predictions_to_glb
33
+ from mapanything.models import MapAnything
34
  from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
35
  from mapanything.utils.image import load_images, rgb
36
 
 
150
  high_level_config["path"], overrides=high_level_config["config_overrides"]
151
  )
152
 
153
+ # Try using from_pretrained first
 
 
 
 
 
 
 
 
 
 
 
 
154
  try:
155
+ print("Loading MapAnything model from_pretrained...")
156
+ model = MapAnything.from_pretrained(high_level_config["hf_model_name"]).to(
157
+ device
158
+ )
159
+ print("Loading MapAnything model from_pretrained succeeded...")
160
+ except Exception as e:
161
+ print(f"from_pretrained failed: {e}")
162
+ print("Falling back to local configuration approach...")
163
+
164
+ # Create model from local configuration instead of using from_pretrained
165
+ from mapanything.models import init_model
166
+
167
+ model = init_model(
168
+ model_str=cfg.model.model_str,
169
+ model_config=cfg.model.model_config,
170
+ torch_hub_force_reload=high_level_config.get(
171
+ "torch_hub_force_reload", False
172
+ ),
173
+ )
174
 
175
+ # Load the pretrained weights from HuggingFace Hub
176
  try:
177
+ from huggingface_hub import hf_hub_download, list_repo_files
 
 
 
 
 
178
 
179
+ # First, let's see what files are available in the repository
180
+ try:
181
+ repo_files = list_repo_files(
182
+ repo_id=high_level_config["hf_model_name"], token=load_hf_token()
183
+ )
184
+ print(f"Available files in repository: {repo_files}")
185
 
186
+ checkpoint_filename = "model.safetensors"
 
 
 
187
 
188
+ # Download the model weights
189
+ checkpoint_path = hf_hub_download(
190
+ repo_id=high_level_config["hf_model_name"],
191
+ filename=checkpoint_filename,
192
+ token=load_hf_token(),
193
  )
194
 
195
+ # Load the weights
196
+ print("start loading checkpoint")
197
+ if checkpoint_filename.endswith(".safetensors"):
198
+ from safetensors.torch import load_file
 
 
 
199
 
200
+ checkpoint = load_file(checkpoint_path)
201
+ else:
202
+ checkpoint = torch.load(
203
+ checkpoint_path, map_location="cpu", weights_only=True
204
+ )
205
+
206
+ print("start loading state_dict")
207
+ if "model" in checkpoint:
208
+ model.load_state_dict(checkpoint["model"])
209
+ elif "state_dict" in checkpoint:
210
+ model.load_state_dict(checkpoint["state_dict"])
211
+ else:
212
+ model.load_state_dict(checkpoint)
213
 
214
+ print(
215
+ f"Successfully loaded pretrained weights from HuggingFace Hub ({checkpoint_filename})"
216
+ )
217
 
218
+ except Exception as inner_e:
219
+ print(f"Error listing repository files or loading weights: {inner_e}")
220
+ raise inner_e
221
+
222
+ except Exception as e:
223
+ print(f"Warning: Could not load pretrained weights: {e}")
224
+ print("Proceeding with randomly initialized model...")
225
 
226
  model = model.to(device)
227