imaginAIry/imaginairy/modules/midas/midas/base_model.py

25 lines
666 B
Python
Raw Normal View History

"""Classes for loading model weights"""
import torch
from imaginairy.utils.downloads import get_cached_url_path
class BaseModel(torch.nn.Module):
    """Base class for MiDaS networks that adds checkpoint loading via ``load``."""

    def load(self, path):
        """
        Load model weights from a checkpoint into this module.

        Args:
            path (str): checkpoint location. It is resolved through
                ``get_cached_url_path(path, category="weights")``; presumably
                this accepts a URL and returns a locally cached file path
                (TODO confirm against imaginairy.utils.downloads).
        """
        ckpt_path = get_cached_url_path(path, category="weights")
        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from trusted sources. Consider weights_only=True once the minimum
        # supported torch version allows it.
        checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
        self.load_state_dict(self._extract_state_dict(checkpoint))

    @staticmethod
    def _extract_state_dict(checkpoint):
        """
        Return the state dict to load from a raw ``torch.load`` result.

        If the checkpoint carries an ``"optimizer"`` entry, the model weights
        are taken from its ``"model"`` entry. Then every key containing
        ``"relative_position_index"`` is dropped. Presumably these buffers are
        rebuilt by the model itself and are not expected by
        ``load_state_dict`` (TODO confirm for the backbones in use).

        Args:
            checkpoint (dict): object returned by ``torch.load``.

        Returns:
            dict: a new mapping of parameter name to value.

        Raises:
            KeyError: if ``"optimizer"`` is present but ``"model"`` is not.
        """
        # Unwrap first: the key filter below must see the real parameter
        # names, not the top-level keys of a training checkpoint.
        if "optimizer" in checkpoint:
            checkpoint = checkpoint["model"]
        return {
            k: v for k, v in checkpoint.items() if "relative_position_index" not in k
        }