mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
add LoRA loading for the LlamaCpp LLM (#3363)
First PR, let me know if this needs anything like unit tests, reformatting, etc. Seemed pretty straightforward to implement. Only hitch was that mmap needs to be disabled when loading LoRAs or else you segfault.
This commit is contained in:
parent
5d0674fb46
commit
2b9f1cea4e
@ -27,6 +27,12 @@ class LlamaCpp(LLM):
|
||||
model_path: str
|
||||
"""The path to the Llama model file."""
|
||||
|
||||
lora_base: Optional[str] = None
|
||||
"""The path to the Llama LoRA base model."""
|
||||
|
||||
lora_path: Optional[str] = None
|
||||
"""The path to the Llama LoRA. If None, no LoRa is loaded."""
|
||||
|
||||
n_ctx: int = Field(512, alias="n_ctx")
|
||||
"""Token context window."""
|
||||
|
||||
@ -87,6 +93,9 @@ class LlamaCpp(LLM):
|
||||
last_n_tokens_size: Optional[int] = 64
|
||||
"""The number of tokens to look back when applying the repeat_penalty."""
|
||||
|
||||
use_mmap: Optional[bool] = True
|
||||
"""Whether to keep the model loaded in RAM"""
|
||||
|
||||
streaming: bool = True
|
||||
"""Whether to stream the results, token by token."""
|
||||
|
||||
@ -94,6 +103,8 @@ class LlamaCpp(LLM):
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that llama-cpp-python library is installed."""
|
||||
model_path = values["model_path"]
|
||||
lora_path = values["lora_path"]
|
||||
lora_base = values["lora_base"]
|
||||
n_ctx = values["n_ctx"]
|
||||
n_parts = values["n_parts"]
|
||||
seed = values["seed"]
|
||||
@ -103,6 +114,7 @@ class LlamaCpp(LLM):
|
||||
use_mlock = values["use_mlock"]
|
||||
n_threads = values["n_threads"]
|
||||
n_batch = values["n_batch"]
|
||||
use_mmap = values["use_mmap"]
|
||||
last_n_tokens_size = values["last_n_tokens_size"]
|
||||
|
||||
try:
|
||||
@ -110,6 +122,8 @@ class LlamaCpp(LLM):
|
||||
|
||||
values["client"] = Llama(
|
||||
model_path=model_path,
|
||||
lora_base=lora_base,
|
||||
lora_path=lora_path,
|
||||
n_ctx=n_ctx,
|
||||
n_parts=n_parts,
|
||||
seed=seed,
|
||||
@ -119,6 +133,7 @@ class LlamaCpp(LLM):
|
||||
use_mlock=use_mlock,
|
||||
n_threads=n_threads,
|
||||
n_batch=n_batch,
|
||||
use_mmap=use_mmap,
|
||||
last_n_tokens_size=last_n_tokens_size,
|
||||
)
|
||||
except ImportError:
|
||||
|
Loading…
Reference in New Issue
Block a user