add LoRA loading for the LlamaCpp LLM (#3363)

First PR, so let me know if this needs anything like unit tests, reformatting, etc. It seemed pretty straightforward to implement. The only hitch is that mmap needs to be disabled when loading a LoRA, or llama.cpp segfaults.
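
For context, here is a minimal usage sketch of the new parameters (the file paths are hypothetical placeholders, not real models); note `use_mmap=False`, per the segfault caveat above:

```python
from langchain.llms import LlamaCpp

# Paths below are placeholders -- point them at your own GGML model and LoRA files.
llm = LlamaCpp(
    model_path="./models/ggml-model-q4_0.bin",  # base GGML model
    lora_path="./loras/my-adapter.bin",         # LoRA adapter applied at load time
    use_mmap=False,  # mmap must be disabled when a LoRA is loaded, or llama.cpp segfaults
)
print(llm("Q: What does a LoRA adapter do? A:"))
```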

```diff
@@ -27,6 +27,12 @@ class LlamaCpp(LLM):
     model_path: str
     """The path to the Llama model file."""
 
+    lora_base: Optional[str] = None
+    """The path to the Llama LoRA base model."""
+
+    lora_path: Optional[str] = None
+    """The path to the Llama LoRA. If None, no LoRA is loaded."""
+
     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
@@ -87,6 +93,9 @@ class LlamaCpp(LLM):
     last_n_tokens_size: Optional[int] = 64
     """The number of tokens to look back when applying the repeat_penalty."""
 
+    use_mmap: Optional[bool] = True
+    """Whether to memory-map the model file. Must be False when loading a LoRA."""
+
     streaming: bool = True
     """Whether to stream the results, token by token."""
@@ -94,6 +103,8 @@ class LlamaCpp(LLM):
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that llama-cpp-python library is installed."""
         model_path = values["model_path"]
+        lora_path = values["lora_path"]
+        lora_base = values["lora_base"]
         n_ctx = values["n_ctx"]
         n_parts = values["n_parts"]
         seed = values["seed"]
@@ -103,6 +114,7 @@ class LlamaCpp(LLM):
         use_mlock = values["use_mlock"]
         n_threads = values["n_threads"]
         n_batch = values["n_batch"]
+        use_mmap = values["use_mmap"]
         last_n_tokens_size = values["last_n_tokens_size"]
 
         try:
@@ -110,6 +122,8 @@ class LlamaCpp(LLM):
             values["client"] = Llama(
                 model_path=model_path,
+                lora_base=lora_base,
+                lora_path=lora_path,
                 n_ctx=n_ctx,
                 n_parts=n_parts,
                 seed=seed,
@@ -119,6 +133,7 @@ class LlamaCpp(LLM):
                 use_mlock=use_mlock,
                 n_threads=n_threads,
                 n_batch=n_batch,
+                use_mmap=use_mmap,
                 last_n_tokens_size=last_n_tokens_size,
             )
         except ImportError:
```
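
For anyone verifying this outside LangChain: the new fields are passed straight through to llama-cpp-python's `Llama` constructor, so the equivalent direct call looks roughly like the sketch below (paths are again placeholders). `lora_base` is optional; it points llama.cpp at an unquantized base for the LoRA-modified layers, which is recommended when the main model is quantized.

```python
from llama_cpp import Llama

client = Llama(
    model_path="./models/ggml-model-q4_0.bin",
    lora_base="./models/ggml-model-f16.bin",  # optional f16 base for the LoRA-modified layers
    lora_path="./loras/my-adapter.bin",
    use_mmap=False,  # same caveat: mmap must be off when loading a LoRA
)
```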