From 2b9f1cea4e1fd1052f8216a92236404d25789568 Mon Sep 17 00:00:00 2001
From: Beau Horenberger <36315656+horenbergerb@users.noreply.github.com>
Date: Mon, 24 Apr 2023 21:31:14 -0400
Subject: [PATCH] add LoRA loading for the LlamaCpp LLM (#3363)

This is my first PR, so let me know if it needs anything like unit tests or
reformatting. The implementation was straightforward; the only hitch is that
mmap needs to be disabled when loading a LoRA, or llama.cpp segfaults.
---
 langchain/llms/llamacpp.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/langchain/llms/llamacpp.py b/langchain/llms/llamacpp.py
index 8078b48d77..b74160841e 100644
--- a/langchain/llms/llamacpp.py
+++ b/langchain/llms/llamacpp.py
@@ -27,6 +27,12 @@ class LlamaCpp(LLM):
     model_path: str
     """The path to the Llama model file."""
 
+    lora_base: Optional[str] = None
+    """The path to the Llama LoRA base model."""
+
+    lora_path: Optional[str] = None
+    """The path to the Llama LoRA. If None, no LoRA is loaded."""
+
     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
 
@@ -87,6 +93,9 @@ class LlamaCpp(LLM):
     last_n_tokens_size: Optional[int] = 64
     """The number of tokens to look back when applying the repeat_penalty."""
 
+    use_mmap: Optional[bool] = True
+    """Whether to use memory-mapping for the model file."""
+
     streaming: bool = True
     """Whether to stream the results, token by token."""
 
@@ -94,6 +103,8 @@ class LlamaCpp(LLM):
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that llama-cpp-python library is installed."""
         model_path = values["model_path"]
+        lora_path = values["lora_path"]
+        lora_base = values["lora_base"]
         n_ctx = values["n_ctx"]
         n_parts = values["n_parts"]
         seed = values["seed"]
@@ -103,6 +114,7 @@ class LlamaCpp(LLM):
         use_mlock = values["use_mlock"]
         n_threads = values["n_threads"]
         n_batch = values["n_batch"]
+        use_mmap = values["use_mmap"]
         last_n_tokens_size = values["last_n_tokens_size"]
 
         try:
@@ -110,6 +122,8 @@ class LlamaCpp(LLM):
 
             values["client"] = Llama(
                 model_path=model_path,
+                lora_base=lora_base,
+                lora_path=lora_path,
                 n_ctx=n_ctx,
                 n_parts=n_parts,
                 seed=seed,
@@ -119,6 +133,7 @@ class LlamaCpp(LLM):
                 use_mlock=use_mlock,
                 n_threads=n_threads,
                 n_batch=n_batch,
+                use_mmap=use_mmap,
                 last_n_tokens_size=last_n_tokens_size,
             )
         except ImportError:
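For reference, a minimal usage sketch of the new parameters (not part of the patch; the file paths below are hypothetical placeholders). Per the caveat in the commit message, `use_mmap` is set to `False` whenever a LoRA is loaded:

```python
from langchain.llms import LlamaCpp

# Hypothetical local GGML files; substitute your own paths.
llm = LlamaCpp(
    model_path="./models/ggml-model-q4_0.bin",
    lora_base="./models/ggml-model-f16.bin",  # base model the adapter is applied on top of
    lora_path="./loras/my-adapter.bin",       # LoRA adapter weights; None means no LoRA
    use_mmap=False,  # disable mmap when loading a LoRA, or llama.cpp segfaults
)

print(llm("Q: What does a LoRA adapter do? A:"))
```

Design note: the patch threads each new field through `validate_environment` by hand, so any future constructor option must be both read out of `values` and forwarded to the `Llama` client explicitly.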