diff --git a/programming_runs/generators/factory.py b/programming_runs/generators/factory.py
index 3772bed..5d41326 100644
--- a/programming_runs/generators/factory.py
+++ b/programming_runs/generators/factory.py
@@ -1,7 +1,7 @@
 from .py_generate import PyGenerator
 from .rs_generate import RsGenerator
 from .generator_types import Generator
-from .model import ModelBase, GPT4, GPT35, StarChat, GPTDavinci
+from .model import CodeLlama, ModelBase, GPT4, GPT35, StarChat, GPTDavinci
 
 
 def generator_factory(lang: str) -> Generator:
@@ -20,6 +20,8 @@ def model_factory(model_name: str) -> ModelBase:
         return GPT35()
     elif model_name == "starchat":
         return StarChat()
+    elif model_name == "codellama":
+        return CodeLlama()
     elif model_name.startswith("text-davinci"):
         return GPTDavinci(model_name)
     else:
diff --git a/programming_runs/generators/model.py b/programming_runs/generators/model.py
index e5c5c46..f7434d7 100644
--- a/programming_runs/generators/model.py
+++ b/programming_runs/generators/model.py
@@ -170,7 +170,7 @@ class StarChat(HFModelBase):
         tokenizer = AutoTokenizer.from_pretrained(
             "HuggingFaceH4/starchat-beta",
         )
-        super().__init__("star-chat", model, tokenizer, eos_token_id=49155)
+        super().__init__("starchat", model, tokenizer, eos_token_id=49155)
 
     def prepare_prompt(self, messages: List[Message]) -> List[int]:
         prompt = ""
@@ -212,7 +212,7 @@ If a question does not make any sense, or is not factually coherent, explain why
             torch_dtype=torch.bfloat16,
             device_map="auto",
         )
-        super().__init__("code-llama", model, tokenizer)
+        super().__init__("codellama", model, tokenizer)
 
     def prepare_prompt(self, messages: List[Message]) -> str:
         if messages[0].role != "system":
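
For reference, a minimal usage sketch of the new factory branch. The import path and the "py" language key below are assumptions based on the file layout in the diff, not shown in the patch itself:

# Minimal sketch: exercising the new "codellama" branch of model_factory.
# Assumes `programming_runs` is importable as a package from the repo root
# and that "py" is an accepted language key for generator_factory.
from programming_runs.generators.factory import generator_factory, model_factory

generator = generator_factory("py")   # language-specific Generator implementation
model = model_factory("codellama")    # now returns a CodeLlama instance

# The "codellama" string matches the identifier passed to the HFModelBase
# constructor in model.py, so the CLI model name and the wrapper's internal
# name stay consistent (same fix as the "starchat" rename above).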