diff --git a/langchain/embeddings/llamacpp.py b/langchain/embeddings/llamacpp.py
index 8b8c6c54..44c887a8 100644
--- a/langchain/embeddings/llamacpp.py
+++ b/langchain/embeddings/llamacpp.py
@@ -49,6 +49,10 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
     """Number of threads to use.
     If None, the number of threads is automatically determined."""
 
+    n_batch: Optional[int] = Field(8, alias="n_batch")
+    """Number of tokens to process in parallel.
+    Should be a number between 1 and n_ctx."""
+
     class Config:
         """Configuration for this pydantic object."""
 
@@ -66,6 +70,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
         vocab_only = values["vocab_only"]
         use_mlock = values["use_mlock"]
         n_threads = values["n_threads"]
+        n_batch = values["n_batch"]
 
         try:
             from llama_cpp import Llama
@@ -80,6 +85,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
                 vocab_only=vocab_only,
                 use_mlock=use_mlock,
                 n_threads=n_threads,
+                n_batch=n_batch,
                 embedding=True,
             )
         except ImportError:
diff --git a/langchain/llms/llamacpp.py b/langchain/llms/llamacpp.py
index 536755fe..878078f6 100644
--- a/langchain/llms/llamacpp.py
+++ b/langchain/llms/llamacpp.py
@@ -53,6 +53,10 @@ class LlamaCpp(LLM, BaseModel):
     """Number of threads to use.
     If None, the number of threads is automatically determined."""
 
+    n_batch: Optional[int] = Field(8, alias="n_batch")
+    """Number of tokens to process in parallel.
+    Should be a number between 1 and n_ctx."""
+
     suffix: Optional[str] = Field(None)
     """A suffix to append to the generated text. If None, no suffix is appended."""
 
@@ -80,6 +84,9 @@ class LlamaCpp(LLM, BaseModel):
     top_k: Optional[int] = 40
     """The top-k value to use for sampling."""
 
+    last_n_tokens_size: Optional[int] = 64
+    """The number of tokens to look back when applying the repeat_penalty."""
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that llama-cpp-python library is installed."""
@@ -92,6 +99,8 @@ class LlamaCpp(LLM, BaseModel):
         vocab_only = values["vocab_only"]
         use_mlock = values["use_mlock"]
         n_threads = values["n_threads"]
+        n_batch = values["n_batch"]
+        last_n_tokens_size = values["last_n_tokens_size"]
 
         try:
             from llama_cpp import Llama
@@ -106,6 +115,8 @@ class LlamaCpp(LLM, BaseModel):
                 vocab_only=vocab_only,
                 use_mlock=use_mlock,
                 n_threads=n_threads,
+                n_batch=n_batch,
+                last_n_tokens_size=last_n_tokens_size,
             )
         except ImportError:
             raise ModuleNotFoundError(
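
Usage note (not part of the patch): a minimal sketch of how the new parameters could be passed through after this change, assuming llama-cpp-python is installed and a local GGML model exists at the placeholder path shown.

    from langchain.llms.llamacpp import LlamaCpp

    # Hypothetical model path; substitute your own local GGML file.
    llm = LlamaCpp(
        model_path="./models/ggml-model-q4_0.bin",
        n_batch=256,            # tokens processed in parallel; keep between 1 and n_ctx
        last_n_tokens_size=64,  # window of recent tokens consulted by repeat_penalty
    )
    print(llm("Q: What is llama.cpp? A:"))

Both values are simply forwarded to the Llama constructor by validate_environment, so anything llama-cpp-python accepts for n_batch and last_n_tokens_size is valid here as well.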