add LoRA loading for the LlamaCpp LLM (#3363)

First PR, so let me know if this needs anything like unit tests,
reformatting, etc. It seemed pretty straightforward to implement. The only
hitch was that mmap needs to be disabled when loading LoRAs, or else you
get a segfault.
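
For anyone trying this out, here is a minimal usage sketch of the new parameters; the model and LoRA file paths are placeholders, not files shipped with this PR:

```python
from langchain.llms import LlamaCpp

# Placeholder paths: point these at a real ggml base model and a
# LoRA adapter converted for llama.cpp.
llm = LlamaCpp(
    model_path="./models/7B/ggml-model.bin",
    lora_base="./models/7B/ggml-model.bin",
    lora_path="./loras/my-lora.bin",
    # mmap has to be off while a LoRA is applied, otherwise llama.cpp segfaults.
    use_mmap=False,
)

print(llm("Q: What does a LoRA adapter change? A:"))
```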
Author: Beau Horenberger
Commit: 2b9f1cea4e (parent 5d0674fb46)

@@ -27,6 +27,12 @@ class LlamaCpp(LLM):
     model_path: str
     """The path to the Llama model file."""

+    lora_base: Optional[str] = None
+    """The path to the Llama LoRA base model."""
+
+    lora_path: Optional[str] = None
+    """The path to the Llama LoRA. If None, no LoRA is loaded."""
+
     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
@@ -87,6 +93,9 @@ class LlamaCpp(LLM):
     last_n_tokens_size: Optional[int] = 64
     """The number of tokens to look back when applying the repeat_penalty."""

+    use_mmap: Optional[bool] = True
+    """Whether to keep the model loaded in RAM."""
+
     streaming: bool = True
     """Whether to stream the results, token by token."""
@ -94,6 +103,8 @@ class LlamaCpp(LLM):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
model_path = values["model_path"]
lora_path = values["lora_path"]
lora_base = values["lora_base"]
n_ctx = values["n_ctx"]
n_parts = values["n_parts"]
seed = values["seed"]
@@ -103,6 +114,7 @@ class LlamaCpp(LLM):
         use_mlock = values["use_mlock"]
         n_threads = values["n_threads"]
         n_batch = values["n_batch"]
+        use_mmap = values["use_mmap"]
         last_n_tokens_size = values["last_n_tokens_size"]

         try:
@@ -110,6 +122,8 @@ class LlamaCpp(LLM):

             values["client"] = Llama(
                 model_path=model_path,
+                lora_base=lora_base,
+                lora_path=lora_path,
                 n_ctx=n_ctx,
                 n_parts=n_parts,
                 seed=seed,
@@ -119,6 +133,7 @@ class LlamaCpp(LLM):
                 use_mlock=use_mlock,
                 n_threads=n_threads,
                 n_batch=n_batch,
+                use_mmap=use_mmap,
                 last_n_tokens_size=last_n_tokens_size,
             )
         except ImportError:
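
For reference, the same behavior can be checked against llama-cpp-python directly, since the wrapper just forwards these arguments to `Llama`; the paths here are again placeholders:

```python
from llama_cpp import Llama

llm = Llama(
    model_path="./models/7B/ggml-model.bin",
    lora_base="./models/7B/ggml-model.bin",
    lora_path="./loras/my-lora.bin",
    use_mmap=False,  # loading a LoRA with mmap enabled is what triggered the segfault
)
```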
