Update LlamaCpp parameters (#2411)

Add `n_batch` and `last_n_tokens_size` parameters to the LlamaCpp class.
These parameters (epecially `n_batch`) significantly effect performance.
There's also a `verbose` flag that prints system timings on the `Llama`
class but I wasn't sure where to add this as it conflicts with (should
be pulled from?) the LLM base class.
This commit is contained in:
Andrei 2023-04-04 22:52:33 -04:00 committed by GitHub
parent b026a62bc4
commit e519a81a05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 0 deletions

View File

@ -49,6 +49,10 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
"""Number of threads to use. If None, the number
of threads is automatically determined."""
n_batch: Optional[int] = Field(8, alias="n_batch")
"""Number of tokens to process in parallel.
Should be a number between 1 and n_ctx."""
class Config:
"""Configuration for this pydantic object."""
@ -66,6 +70,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
vocab_only = values["vocab_only"]
use_mlock = values["use_mlock"]
n_threads = values["n_threads"]
n_batch = values["n_batch"]
try:
from llama_cpp import Llama
@ -80,6 +85,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
vocab_only=vocab_only,
use_mlock=use_mlock,
n_threads=n_threads,
n_batch=n_batch,
embedding=True,
)
except ImportError:

View File

@ -53,6 +53,10 @@ class LlamaCpp(LLM, BaseModel):
"""Number of threads to use.
If None, the number of threads is automatically determined."""
n_batch: Optional[int] = Field(8, alias="n_batch")
"""Number of tokens to process in parallel.
Should be a number between 1 and n_ctx."""
suffix: Optional[str] = Field(None)
"""A suffix to append to the generated text. If None, no suffix is appended."""
@ -80,6 +84,9 @@ class LlamaCpp(LLM, BaseModel):
top_k: Optional[int] = 40
"""The top-k value to use for sampling."""
last_n_tokens_size: Optional[int] = 64
"""The number of tokens to look back when applying the repeat_penalty."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
@ -92,6 +99,8 @@ class LlamaCpp(LLM, BaseModel):
vocab_only = values["vocab_only"]
use_mlock = values["use_mlock"]
n_threads = values["n_threads"]
n_batch = values["n_batch"]
last_n_tokens_size = values["last_n_tokens_size"]
try:
from llama_cpp import Llama
@ -106,6 +115,8 @@ class LlamaCpp(LLM, BaseModel):
vocab_only=vocab_only,
use_mlock=use_mlock,
n_threads=n_threads,
n_batch=n_batch,
last_n_tokens_size=last_n_tokens_size,
)
except ImportError:
raise ModuleNotFoundError(