@ -60,7 +60,7 @@ def main():
revision=args.revision,
torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
load_in_8bit=args.torch_dtype == "int8",
device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"},
device_map="auto" if args.torch_dtype == "int8" else None,
)
if args.torch_dtype == "int8":
# trigger weight quantization