|
|
|
@ -386,6 +386,7 @@ class CrossModalEncoderModel(HuggingFaceModel):
|
|
|
|
|
).from_pretrained(
|
|
|
|
|
self.model_path,
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
trust_remote_code=True,
|
|
|
|
|
)
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
@ -496,6 +497,7 @@ class TextGenerationModel(HuggingFaceModel):
|
|
|
|
|
load_in_8bit=True,
|
|
|
|
|
device_map="auto",
|
|
|
|
|
max_memory=max_memory,
|
|
|
|
|
trust_remote_code=True,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
try:
|
|
|
|
@ -507,12 +509,16 @@ class TextGenerationModel(HuggingFaceModel):
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
revision="float16",
|
|
|
|
|
torch_dtype=torch.float16,
|
|
|
|
|
trust_remote_code=True,
|
|
|
|
|
)
|
|
|
|
|
except Exception:
|
|
|
|
|
model = MODEL_REGISTRY.get(
|
|
|
|
|
self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
|
|
|
|
|
).from_pretrained( # type: ignore
|
|
|
|
|
self.model_path, cache_dir=cache_dir, torch_dtype=dtype
|
|
|
|
|
self.model_path,
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
torch_dtype=dtype,
|
|
|
|
|
trust_remote_code=True,
|
|
|
|
|
)
|
|
|
|
|
model.eval()
|
|
|
|
|
print(f"Loaded Model DType {model.dtype}")
|
|
|
|
|