Spaces:
Runtime error
Runtime error
Updating app file
Browse files
app.py
CHANGED
@@ -61,13 +61,13 @@ def infer(model,data, notes):
|
|
61 |
data= torch.tensor(data)
|
62 |
if model == "CNN":
|
63 |
model = MMCNN_CAT()
|
64 |
-
checkpoint = torch.load(MMCNN_CAT_ckpt_path)
|
65 |
model.load_state_dict(checkpoint['model_state_dict'])
|
66 |
data = data.transpose(1,2).float()
|
67 |
|
68 |
elif model == "RNN":
|
69 |
model = MMRNN(device='cpu')
|
70 |
-
model.load_state_dict(torch.load(MMRNN_ckpt_path)['model_state_dict'])
|
71 |
data = data.float()
|
72 |
model.eval()
|
73 |
outputs, predicted = predict(model, data, embed_notes, device='cpu')
|
|
|
61 |
data= torch.tensor(data)
|
62 |
if model == "CNN":
|
63 |
model = MMCNN_CAT()
|
64 |
+
checkpoint = torch.load(MMCNN_CAT_ckpt_path, map_location="cpu")
|
65 |
model.load_state_dict(checkpoint['model_state_dict'])
|
66 |
data = data.transpose(1,2).float()
|
67 |
|
68 |
elif model == "RNN":
|
69 |
model = MMRNN(device='cpu')
|
70 |
+
model.load_state_dict(torch.load(MMRNN_ckpt_path, map_location="cpu")['model_state_dict'])
|
71 |
data = data.float()
|
72 |
model.eval()
|
73 |
outputs, predicted = predict(model, data, embed_notes, device='cpu')
|