fix: add trust remote code HF models (#102)

pull/104/head
Laurel Orr 12 months ago committed by GitHub
parent 7285fee140
commit b745617045
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -386,6 +386,7 @@ class CrossModalEncoderModel(HuggingFaceModel):
).from_pretrained(
self.model_path,
cache_dir=cache_dir,
trust_remote_code=True,
)
model.eval()
@ -496,6 +497,7 @@ class TextGenerationModel(HuggingFaceModel):
load_in_8bit=True,
device_map="auto",
max_memory=max_memory,
trust_remote_code=True,
)
else:
try:
@ -507,12 +509,16 @@ class TextGenerationModel(HuggingFaceModel):
cache_dir=cache_dir,
revision="float16",
torch_dtype=torch.float16,
trust_remote_code=True,
)
except Exception:
model = MODEL_REGISTRY.get(
self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
).from_pretrained( # type: ignore
self.model_path, cache_dir=cache_dir, torch_dtype=dtype
self.model_path,
cache_dir=cache_dir,
torch_dtype=dtype,
trust_remote_code=True,
)
model.eval()
print(f"Loaded Model DType {model.dtype}")

Loading…
Cancel
Save