Allow selecting the CodeLlama model size (7b, 13b, or 34b) via a `codellama-<version>` model name

main
cassanof 10 months ago
parent 0a9ebab46b
commit f748264b35

@ -20,8 +20,12 @@ def model_factory(model_name: str) -> ModelBase:
return GPT35() return GPT35()
elif model_name == "starchat": elif model_name == "starchat":
return StarChat() return StarChat()
elif model_name == "codellama": elif model_name.startswith("codellama"):
return CodeLlama() # if it has `-` in the name, version was specified
kwargs = {}
if "-" in model_name:
kwargs["version"] = model_name.split("-")[1]
return CodeLlama(**kwargs)
elif model_name.startswith("text-davinci"): elif model_name.startswith("text-davinci"):
return GPTDavinci(model_name) return GPTDavinci(model_name)
else: else:

@ -198,17 +198,17 @@ You are a helpful, respectful and honest assistant. Always answer as helpfully a
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""" If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
def __init__(self): def __init__(self, version: Literal["34b", "13b", "7b"] = "34b"):
import torch import torch
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
"codellama/CodeLlama-34b-Instruct-hf", f"codellama/CodeLlama-{version}-Instruct-hf",
add_eos_token=True, add_eos_token=True,
add_bos_token=True, add_bos_token=True,
padding_side='left' padding_side='left'
) )
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
"codellama/CodeLlama-34b-Instruct-hf", f"codellama/CodeLlama-{version}-Instruct-hf",
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map="auto", device_map="auto",
) )

Loading…
Cancel
Save