mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
improve llamacpp embeddings (#12972)
- **Description:** Improve llamacpp embedding class by adding the `device` parameter so it can be passed to the model and used with `gpu`, `cpu` or Apple metal (`mps`). Improve performance by making use of the bulk client api to compute embeddings in batches. - **Dependencies:** none - **Tag maintainer:** @hwchase17 --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
f882824eac
commit
654da27255
@@ -57,6 +57,9 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
     verbose: bool = Field(True, alias="verbose")
     """Print verbose output to stderr."""

+    device: Optional[str] = Field(None, alias="device")
+    """Device type to use and pass to the model"""
+
     class Config:
         extra = "forbid"

@@ -75,6 +78,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
             "n_threads",
             "n_batch",
             "verbose",
+            "device",
         ]
         model_params = {k: values[k] for k in model_param_names}
         # For backwards compatibility, only include if non-null.
@@ -108,8 +112,8 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = [self.client.embed(text) for text in texts]
-        return [list(map(float, e)) for e in embeddings]
+        embeddings = self.client.create_embedding(texts)
+        return [list(map(float, e["embedding"])) for e in embeddings["data"]]

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using the Llama model.
Loading…
Reference in New Issue
Block a user