2023-12-15 20:31:28 +00:00
|
|
|
"""Classes for loading model weights"""
|
|
|
|
|
2022-12-20 09:43:04 +00:00
|
|
|
import torch
|
|
|
|
|
2024-01-15 00:50:17 +00:00
|
|
|
from imaginairy.utils.downloads import get_cached_url_path
|
2022-12-20 09:43:04 +00:00
|
|
|
|
|
|
|
|
2023-09-29 08:13:50 +00:00
|
|
|
class BaseModel(torch.nn.Module):
    """``torch.nn.Module`` base class with a helper for loading saved weights."""

    def load(self, path):
        """
        Load model weights from a checkpoint into this module.

        The checkpoint location is resolved with ``get_cached_url_path``
        (category ``"weights"``) and deserialized onto the CPU. If the loaded
        object has an ``"optimizer"`` entry, the weights are taken from its
        ``"model"`` entry. Any key containing ``"relative_position_index"``
        is dropped before the weights are passed to ``load_state_dict``.

        Args:
            path (str): file path of the checkpoint (presumably a URL is also
                accepted, since it goes through ``get_cached_url_path`` —
                TODO confirm)
        """
        ckpt_path = get_cached_url_path(path, category="weights")
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        parameters = torch.load(ckpt_path, map_location=torch.device("cpu"))

        # Checkpoints that carry optimizer state keep the weights under
        # "model". Unwrap *before* filtering so the filter below applies to
        # the actual parameter names, not to the outer "model"/"optimizer"
        # keys.
        if "optimizer" in parameters:
            parameters = parameters["model"]

        # Drop stored "relative_position_index" entries; presumably the model
        # computes these itself rather than loading them — TODO confirm.
        parameters = {
            k: v for k, v in parameters.items() if "relative_position_index" not in k
        }

        self.load_state_dict(parameters)
|