Update LlamaCpp parameters (#2411)
Add `n_batch` and `last_n_tokens_size` parameters to the LlamaCpp class. These parameters (especially `n_batch`) significantly affect performance. There's also a `verbose` flag on the `Llama` class that prints system timings, but I wasn't sure where to add it, as it conflicts with (should perhaps be pulled from?) the LLM base class.
This commit is contained in:
parent b026a62bc4
commit e519a81a05
@@ -49,6 +49,10 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
     """Number of threads to use. If None, the number
     of threads is automatically determined."""
 
+    n_batch: Optional[int] = Field(8, alias="n_batch")
+    """Number of tokens to process in parallel.
+    Should be a number between 1 and n_ctx."""
+
     class Config:
         """Configuration for this pydantic object."""
 
@@ -66,6 +70,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
         vocab_only = values["vocab_only"]
         use_mlock = values["use_mlock"]
         n_threads = values["n_threads"]
+        n_batch = values["n_batch"]
 
         try:
             from llama_cpp import Llama
@@ -80,6 +85,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
                 vocab_only=vocab_only,
                 use_mlock=use_mlock,
                 n_threads=n_threads,
+                n_batch=n_batch,
                 embedding=True,
             )
         except ImportError:
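As an illustration of the embeddings change above, a minimal usage sketch; the model path and the `n_batch` value are placeholders, not taken from this commit, and it assumes `llama-cpp-python` is installed:

```python
from langchain.embeddings import LlamaCppEmbeddings

# Sketch only: the model path is a placeholder, and n_batch=256 is an
# arbitrary example value (it must lie between 1 and n_ctx).
embeddings = LlamaCppEmbeddings(
    model_path="./models/ggml-model-q4_0.bin",
    n_batch=256,
)
vector = embeddings.embed_query("How does n_batch affect throughput?")
```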
@@ -53,6 +53,10 @@ class LlamaCpp(LLM, BaseModel):
     """Number of threads to use.
     If None, the number of threads is automatically determined."""
 
+    n_batch: Optional[int] = Field(8, alias="n_batch")
+    """Number of tokens to process in parallel.
+    Should be a number between 1 and n_ctx."""
+
     suffix: Optional[str] = Field(None)
     """A suffix to append to the generated text. If None, no suffix is appended."""
 
@@ -80,6 +84,9 @@ class LlamaCpp(LLM, BaseModel):
     top_k: Optional[int] = 40
     """The top-k value to use for sampling."""
 
+    last_n_tokens_size: Optional[int] = 64
+    """The number of tokens to look back when applying the repeat_penalty."""
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that llama-cpp-python library is installed."""
@@ -92,6 +99,8 @@ class LlamaCpp(LLM, BaseModel):
         vocab_only = values["vocab_only"]
         use_mlock = values["use_mlock"]
         n_threads = values["n_threads"]
+        n_batch = values["n_batch"]
+        last_n_tokens_size = values["last_n_tokens_size"]
 
         try:
             from llama_cpp import Llama
@@ -106,6 +115,8 @@ class LlamaCpp(LLM, BaseModel):
                 vocab_only=vocab_only,
                 use_mlock=use_mlock,
                 n_threads=n_threads,
+                n_batch=n_batch,
+                last_n_tokens_size=last_n_tokens_size,
             )
         except ImportError:
             raise ModuleNotFoundError(