"""Inspect a saved Keras model and build a positional-argument predict wrapper.

The wrapper produced by ``generate_wrapper`` takes one positional argument per
model input and returns plain Python lists, so it can be plugged into a UI
layer such as Gradio (see the commented example at the bottom of this file).
"""
from tensorflow import keras
import numpy as np

# import tensorflow as tf
# loaded = tf.saved_model.load('rwthmaterials_dp800_network1_inclusion')
# print("Available endpoints:", list(loaded.signatures.keys()))

# Load the model
model = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.keras')

# Inspect model inputs and outputs
print("Model Summary:")
model.summary()

print("Inputs:")
for i, input_tensor in enumerate(model.inputs, start=1):
    print(f"Input {i}: name={input_tensor.name}, shape={input_tensor.shape}")

print("Outputs:")
for i, output_tensor in enumerate(model.outputs, start=1):
    print(f"Output {i}: name={output_tensor.name}, shape={output_tensor.shape}")


def _prepare_input(value, shape):
    """Convert one positional argument into an array compatible with *shape*.

    Args:
        value: Array-like data for a single model input.
        shape: The model input's declared shape; unknown dims are ``None``.

    Returns:
        A NumPy array to feed to ``model.predict``.

    Raises:
        ValueError: If the data cannot be reshaped to the declared shape.
    """
    arr = np.asarray(value)
    dims = list(shape)
    if sum(dim is None for dim in dims) <= 1:
        # At most one unknown dim (normally the batch axis): let NumPy infer
        # it with -1, so both a single sample (-> batch of 1, as before) and
        # a batch of several samples are accepted.
        return arr.reshape([-1 if dim is None else dim for dim in dims])
    # Several unknown dims: their sizes cannot be derived from the declared
    # shape, so keep the caller's own shape and only add a missing batch axis.
    if arr.ndim == len(dims) - 1:
        return arr[np.newaxis, ...]
    return arr


def _to_python(prediction):
    """Convert a ``model.predict`` result into JSON-friendly Python lists.

    Handles a single array, a list/tuple of arrays, and a dict of arrays
    (multi-output models return one of the latter two).
    """
    if isinstance(prediction, dict):
        return {key: np.asarray(val).tolist() for key, val in prediction.items()}
    if isinstance(prediction, (list, tuple)):
        return [np.asarray(val).tolist() for val in prediction]
    return np.asarray(prediction).tolist()


# Generate a wrapper function based on input count
def generate_wrapper(model):
    """Build a function that runs *model* on positional NumPy-compatible inputs.

    Args:
        model: A loaded Keras model; one positional argument is expected per
            entry in ``model.inputs``, in the same order.

    Returns:
        ``wrapper(*args)`` which reshapes each argument to the corresponding
        input shape, calls ``model.predict`` and returns the prediction as
        nested Python lists (a list of them, or a dict, for multi-output
        models).

    Raises (from ``wrapper``):
        TypeError: If the number of arguments differs from the model's inputs.
        ValueError: If an argument cannot be reshaped to its input shape.
    """
    # Shapes are fixed once the model is built, so read them only once.
    input_shapes = [tuple(input_tensor.shape) for input_tensor in model.inputs]

    def wrapper(*args):
        if len(args) != len(input_shapes):
            raise TypeError(
                f"expected {len(input_shapes)} input(s), got {len(args)}"
            )
        processed_inputs = [
            _prepare_input(arg, shape) for arg, shape in zip(args, input_shapes)
        ]
        # A single-input model takes the bare array rather than a 1-item list.
        model_input = (
            processed_inputs[0] if len(processed_inputs) == 1 else processed_inputs
        )
        prediction = model.predict(model_input)
        return _to_python(prediction)

    return wrapper


# Create the wrapper
predict_fn = generate_wrapper(model)

# Example usage with dummy data
# Replace with actual input data when integrating with Gradio.
# NOTE(review): the (1, 6, 6, 2048) shapes below are assumed to match this
# model's two inputs; confirm against the "Inputs:" printout above.
# dummy_input1 = np.random.rand(1, 6, 6, 2048)
# dummy_input2 = np.random.rand(1, 6, 6, 2048)
# print(predict_fn(dummy_input1, dummy_input2))