Allow selecting the CodeLlama model size (7b, 13b, or 34b) via a version suffix in the model name

main
cassanof 9 months ago
parent 0a9ebab46b
commit f748264b35

@ -20,8 +20,12 @@ def model_factory(model_name: str) -> ModelBase:
return GPT35()
elif model_name == "starchat":
return StarChat()
elif model_name == "codellama":
return CodeLlama()
elif model_name.startswith("codellama"):
# if it has `-` in the name, version was specified
kwargs = {}
if "-" in model_name:
kwargs["version"] = model_name.split("-")[1]
return CodeLlama(**kwargs)
elif model_name.startswith("text-davinci"):
return GPTDavinci(model_name)
else:

@ -198,17 +198,17 @@ You are a helpful, respectful and honest assistant. Always answer as helpfully a
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
def __init__(self):
def __init__(self, version: Literal["34b", "13b", "7b"] = "34b"):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"codellama/CodeLlama-34b-Instruct-hf",
f"codellama/CodeLlama-{version}-Instruct-hf",
add_eos_token=True,
add_bos_token=True,
padding_side='left'
)
model = AutoModelForCausalLM.from_pretrained(
"codellama/CodeLlama-34b-Instruct-hf",
f"codellama/CodeLlama-{version}-Instruct-hf",
torch_dtype=torch.bfloat16,
device_map="auto",
)

Loading…
Cancel
Save