main
cassanof 9 months ago
parent afcfd427a2
commit 0365be2c6e

@@ -1,7 +1,7 @@
 from .py_generate import PyGenerator
 from .rs_generate import RsGenerator
 from .generator_types import Generator
-from .model import ModelBase, GPT4, GPT35, StarChat, GPTDavinci
+from .model import CodeLlama, ModelBase, GPT4, GPT35, StarChat, GPTDavinci
 
 
 def generator_factory(lang: str) -> Generator:
@@ -20,6 +20,8 @@ def model_factory(model_name: str) -> ModelBase:
         return GPT35()
     elif model_name == "starchat":
         return StarChat()
+    elif model_name == "codellama":
+        return CodeLlama()
    elif model_name.startswith("text-davinci"):
         return GPTDavinci(model_name)
     else:
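
Review note: model_factory dispatches on exact string keys, so the "codellama" key added here must match both what callers pass in and the names the renamed models below report. A minimal usage sketch, assuming the factory module is importable (import path hypothetical):

    from generators.factory import model_factory  # hypothetical import path

    model = model_factory("codellama")  # dispatches to the new CodeLlama() branch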

@@ -170,7 +170,7 @@ class StarChat(HFModelBase):
         tokenizer = AutoTokenizer.from_pretrained(
             "HuggingFaceH4/starchat-beta",
         )
-        super().__init__("star-chat", model, tokenizer, eos_token_id=49155)
+        super().__init__("starchat", model, tokenizer, eos_token_id=49155)
 
     def prepare_prompt(self, messages: List[Message]) -> List[int]:
         prompt = ""
@@ -212,7 +212,7 @@ If a question does not make any sense, or is not factually coherent, explain why
             torch_dtype=torch.bfloat16,
             device_map="auto",
         )
-        super().__init__("code-llama", model, tokenizer)
+        super().__init__("codellama", model, tokenizer)
 
     def prepare_prompt(self, messages: List[Message]) -> str:
         if messages[0].role != "system":
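
Review note: same consistency fix as above, "code-llama" → "codellama". The trailing context suggests prepare_prompt injects a default system message when the caller omits one; a hedged sketch of the Llama-2 chat format this class presumably targets (the template and DEFAULT_SYSTEM_PROMPT are assumptions, not part of this diff):

    DEFAULT_SYSTEM_PROMPT = "..."  # hypothetical placeholder; the repo defines its own text

    def codellama_prompt(messages: list[dict]) -> str:
        # messages: [{"role": ..., "content": ...}]; assumed Llama-2 [INST]/<<SYS>> format.
        if messages[0]["role"] != "system":
            # Fall back to a default system prompt when none is supplied.
            messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}] + messages
        return (f"[INST] <<SYS>>\n{messages[0]['content']}\n<</SYS>>\n\n"
                f"{messages[1]['content']} [/INST]")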
