From f748264b35924af12eb224e7d58cfc6605893336 Mon Sep 17 00:00:00 2001 From: cassanof Date: Mon, 28 Aug 2023 15:56:24 -0700 Subject: [PATCH] allow to get smol boi or big boi --- programming_runs/generators/factory.py | 8 ++++++-- programming_runs/generators/model.py | 6 +++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/programming_runs/generators/factory.py b/programming_runs/generators/factory.py index 5d41326..207aad2 100644 --- a/programming_runs/generators/factory.py +++ b/programming_runs/generators/factory.py @@ -20,8 +20,12 @@ def model_factory(model_name: str) -> ModelBase: return GPT35() elif model_name == "starchat": return StarChat() - elif model_name == "codellama": - return CodeLlama() + elif model_name.startswith("codellama"): + # if it has `-` in the name, version was specified + kwargs = {} + if "-" in model_name: + kwargs["version"] = model_name.split("-")[1] + return CodeLlama(**kwargs) elif model_name.startswith("text-davinci"): return GPTDavinci(model_name) else: diff --git a/programming_runs/generators/model.py b/programming_runs/generators/model.py index 48c7360..878db2d 100644 --- a/programming_runs/generators/model.py +++ b/programming_runs/generators/model.py @@ -198,17 +198,17 @@ You are a helpful, respectful and honest assistant. Always answer as helpfully a If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""" - def __init__(self): + def __init__(self, version: Literal["34b", "13b", "7b"] = "34b"): import torch from transformers import AutoModelForCausalLM, AutoTokenizer tokenizer = AutoTokenizer.from_pretrained( - "codellama/CodeLlama-34b-Instruct-hf", + f"codellama/CodeLlama-{version}-Instruct-hf", add_eos_token=True, add_bos_token=True, padding_side='left' ) model = AutoModelForCausalLM.from_pretrained( - "codellama/CodeLlama-34b-Instruct-hf", + f"codellama/CodeLlama-{version}-Instruct-hf", torch_dtype=torch.bfloat16, device_map="auto", )