Set device_map only for int8

pull/273/head
Max Ryabinin 1 year ago
parent d70019f2b6
commit 556f0fabe0

@@ -60,7 +60,7 @@ def main():
revision=args.revision,
torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
load_in_8bit=args.torch_dtype == "int8",
-device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"},
+device_map="auto" if args.torch_dtype == "int8" else None,
)
if args.torch_dtype == "int8":
# trigger weight quantization

Loading…
Cancel
Save