langchain/libs/partners/nvidia-ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
Vadim Kudlay 5f9ac6986e
nvidia-ai-endpoints[patch]: model arguments (e.g. temperature) on construction bug (#17290)
- **Issue:** Long-standing issue with model argument support:
  - Arguments that are not specially handled, such as temperature,
    don't work when passed through the constructor.
  - Such arguments DO work quite well with `bind`, but they are not
    checked against the field requirements.
  - Since the initial push, server-side error messages have gotten
    better, and v0.0.2 raises better exceptions. So maybe it's better
    to let the server side handle such issues?
- **Description:**
  - Removed ChatNVIDIA's argument fields in favor of
    `model_kwargs`/`model_kws` arguments, which aggregate constructor
    kwargs (constructor pathway) and merge them with call kwargs (bind
    pathway).
  - Moved a few functions from `_NVIDIAClient` to `ChatNVIDIA` to
    streamline construction for future integrations.
  - Minor/Optional: Old services didn't support `stop`, so client-side
    stopping was implemented. Now both are done.
- **Any Breaking Changes:** Minor breaking changes if you strongly rely
  on `chat_model.temperature`, etc. These values are now captured by
  `chat_model.model_kwargs`.

The PR passes the tests, the example notebooks, and example testing.
Still going to chat with some people, so leaving it as a draft for now.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-02-09 13:46:02 -08:00


```python
"""Embeddings Components Derived from NVEModel/Embeddings"""

from typing import List, Literal, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field

from langchain_nvidia_ai_endpoints._common import _NVIDIAClient


class NVIDIAEmbeddings(_NVIDIAClient, Embeddings):
    """NVIDIA's AI Foundation Retriever Question-Answering Asymmetric Model."""

    max_length: int = Field(2048, ge=1, le=2048)
    max_batch_size: int = Field(default=50)
    # Default to None so embed_query/embed_documents fall back to "query"/"passage".
    # A non-None value forces that type on both pathways; a "passage" default
    # would make embed_query send passage-type requests.
    model_type: Optional[Literal["passage", "query"]] = Field(
        None, description="The type of text to be embedded."
    )

    def _embed(
        self, texts: List[str], model_type: Literal["passage", "query"]
    ) -> List[List[float]]:
        """Embed a batch of texts as either passage or query type."""
        response = self.client.get_req(
            model_name=self.model,
            payload={
                "input": texts,
                "model": model_type,
                "encoding_format": "float",
            },
        )
        response.raise_for_status()
        result = response.json()
        data = result["data"]
        if not isinstance(data, list):
            raise ValueError(f"Expected a list of embeddings. Got: {data}")
        # Each item carries the position of its input text; sort on it so the
        # output order matches `texts`.
        embedding_list = [(res["embedding"], res["index"]) for res in data]
        return [x[0] for x in sorted(embedding_list, key=lambda x: x[1])]

    def embed_query(self, text: str) -> List[float]:
        """Input pathway for query embeddings."""
        return self._embed([text], model_type=self.model_type or "query")[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Input pathway for document embeddings."""
        # From https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/nvolve-40k/documentation
        # Each input must not exceed 2048 characters, and inputs longer than 512
        # model tokens will be truncated. The input array must not exceed 50
        # strings.
        all_embeddings = []
        for i in range(0, len(texts), self.max_batch_size):
            batch = texts[i : i + self.max_batch_size]
            # Slicing is a no-op for strings already within max_length.
            truncated = [text[: self.max_length] for text in batch]
            all_embeddings.extend(
                self._embed(truncated, model_type=self.model_type or "passage")
            )
        return all_embeddings
```
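
The tail of `_embed` reorders the response by each item's `index` field, so the result does not rely on the service returning items in input order. Below is a standalone sketch of that step; `order_embeddings` is a hypothetical name, and the response body is made up to match the shape `_embed` reads (`data`, `embedding`, `index`).

```python
from typing import Any, Dict, List


def order_embeddings(result: Dict[str, Any]) -> List[List[float]]:
    """Return the vectors in `result["data"]` sorted by their `index` field.

    Same logic as the end of `_embed`: validate that `data` is a list, then
    sort on the index of the input text each embedding belongs to.
    """
    data = result["data"]
    if not isinstance(data, list):
        raise ValueError(f"Expected a list of embeddings. Got: {data}")
    return [item["embedding"] for item in sorted(data, key=lambda item: item["index"])]


# A made-up response body in the shape `_embed` expects (values are illustrative).
fake_result = {
    "data": [
        {"index": 1, "embedding": [0.3, 0.4]},
        {"index": 0, "embedding": [0.1, 0.2]},
    ]
}
print(order_embeddings(fake_result))  # [[0.1, 0.2], [0.3, 0.4]]
```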
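
The `self.model_type or "query"` / `self.model_type or "passage"` fallbacks only take effect when `model_type` is `None`; with a non-`None` default such as `"passage"`, `embed_query` would send passage-type requests. The hypothetical helper below isolates that expression to show the difference.

```python
from typing import Literal, Optional

InputType = Literal["passage", "query"]


def resolve_input_type(override: Optional[InputType], fallback: InputType) -> InputType:
    """Mirror `self.model_type or "query"` / `self.model_type or "passage"`.

    A non-None `model_type` forces that type on every call; None lets each
    entry point use its natural type.
    """
    return override or fallback


# With model_type=None, queries and documents are embedded asymmetrically.
print(resolve_input_type(None, "query"), resolve_input_type(None, "passage"))  # query passage
# With model_type="passage", the query pathway also sends "passage".
print(resolve_input_type("passage", "query"))  # passage
```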
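
The batching loop in `embed_documents` can be sketched as a standalone generator. The function name `batch_and_truncate` is hypothetical, and its defaults simply copy the class fields (`max_batch_size=50`, `max_length=2048`). Note that this truncation counts characters, not model tokens; per the code comment, the service itself truncates inputs beyond 512 tokens.

```python
from typing import Iterator, List


def batch_and_truncate(
    texts: List[str], max_batch_size: int = 50, max_length: int = 2048
) -> Iterator[List[str]]:
    """Yield request-sized batches, clipping each text to `max_length` characters.

    Mirrors the loop in `embed_documents`: slicing is a no-op for short strings,
    so no explicit length check is needed.
    """
    for start in range(0, len(texts), max_batch_size):
        yield [text[:max_length] for text in texts[start : start + max_batch_size]]


docs = ["a" * 3000] + [f"doc {i}" for i in range(119)]
batches = list(batch_and_truncate(docs))
print([len(b) for b in batches])  # [50, 50, 20]
print(len(batches[0][0]))  # 2048
```

Because each batch is embedded and appended in order, the concatenated result lines up with the original `texts` list.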