fix: map eos,bos hopefully stops

eval
Zach Nussbaum 2 years ago
parent f51c5c8109
commit cb43f53f7a

@@ -11,6 +11,7 @@ def generate(tokenizer, prompt, model, config):
     outputs = model.generate(input_ids=input_ids, max_new_tokens=config["max_new_tokens"], temperature=config["temperature"])
+    print(outputs)
     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
     return decoded[len(prompt):]
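
For context on why mapping the EOS token helps generation stop (a reader's note, not part of this commit): once the tokenizer knows the real id for `</s>`, the stopping condition can also be made explicit in the `generate` call. A minimal sketch, assuming the same `config` keys as above; `eos_token_id` and `pad_token_id` are standard Hugging Face `generate` parameters, not code from this repo:

# Sketch: make the stop token explicit instead of relying on model defaults.
outputs = model.generate(
    input_ids=input_ids,
    max_new_tokens=config["max_new_tokens"],
    temperature=config["temperature"],
    eos_token_id=tokenizer.eos_token_id,  # stop once </s> is generated
    pad_token_id=tokenizer.eos_token_id,  # avoids the "pad token not set" warning
)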
@@ -19,6 +20,7 @@ def generate(tokenizer, prompt, model, config):
 def setup_model(config):
     model = AutoModelForCausalLM.from_pretrained(config["model_name"], device_map="auto", torch_dtype=torch.float16)
     tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
+    tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>"})
     if config["lora"]:
         model = PeftModelForCausalLM.from_pretrained(model, config["lora_path"], device_map="auto", torch_dtype=torch.float16)
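
One caveat worth noting (an observation on the added line, not part of the diff): `add_special_tokens` returns the number of tokens that were genuinely new to the vocabulary. If `<s>` or `</s>` did not already exist, the model's embedding matrix needs to grow to match, or lookups for the new ids will be out of range. A hedged sketch using the standard Hugging Face call:

# Sketch: keep model embeddings in sync if the special tokens were truly new.
num_added = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>"})
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer))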
@@ -33,17 +35,22 @@ def setup_model(config):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("--config", type=str, required=True)
-    parser.add_argument("--prompt", type=str, required=True)
+    parser.add_argument("--prompt", type=str)
     args = parser.parse_args()
     config = read_config(args.config)
-    print("setting up model")
+    if config["prompt"] is None and args.prompt is None:
+        raise ValueError("Prompt is required either in config or as argument")
+    prompt = config["prompt"] if args.prompt is None else args.prompt
+    print("Setting up model")
     model, tokenizer = setup_model(config)
-    print("generating")
+    print("Generating")
     start = time.time()
-    generation = generate(tokenizer, args.prompt, model, config)
-    print(f"done in {time.time() - start:.2f}s")
+    generation = generate(tokenizer, prompt, model, config)
+    print(f"Done in {time.time() - start:.2f}s")
     print(generation)
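
A small robustness note on the new prompt fallback (a suggestion, not in the commit): `config["prompt"]` raises `KeyError` when the config file omits the key entirely, so `dict.get` makes the either/or check safe. A sketch under that assumption:

# Sketch: tolerate configs that have no "prompt" key at all.
prompt = args.prompt if args.prompt is not None else config.get("prompt")
if prompt is None:
    raise ValueError("Prompt is required either in config or as argument")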