mirror of
https://github.com/HazyResearch/manifest
synced 2024-10-31 15:20:26 +00:00
fix: add trust remote code HF models
This commit is contained in:
parent
7285fee140
commit
0dfcab728a
@ -386,6 +386,7 @@ class CrossModalEncoderModel(HuggingFaceModel):
|
||||
).from_pretrained(
|
||||
self.model_path,
|
||||
cache_dir=cache_dir,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
model.eval()
|
||||
|
||||
@ -496,6 +497,7 @@ class TextGenerationModel(HuggingFaceModel):
|
||||
load_in_8bit=True,
|
||||
device_map="auto",
|
||||
max_memory=max_memory,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
@ -507,12 +509,16 @@ class TextGenerationModel(HuggingFaceModel):
|
||||
cache_dir=cache_dir,
|
||||
revision="float16",
|
||||
torch_dtype=torch.float16,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
except Exception:
|
||||
model = MODEL_REGISTRY.get(
|
||||
self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
|
||||
).from_pretrained( # type: ignore
|
||||
self.model_path, cache_dir=cache_dir, torch_dtype=dtype
|
||||
self.model_path,
|
||||
cache_dir=cache_dir,
|
||||
torch_dtype=dtype,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
model.eval()
|
||||
print(f"Loaded Model DType {model.dtype}")
|
||||
|
Loading…
Reference in New Issue
Block a user