mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
91bcc9c5c9
This PR is adding support for NVIDIA NeMo embeddings issue #16095. --------- Co-authored-by: Praveen Nakshatrala <pnakshatrala@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
170 lines
5.3 KiB
Python
170 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import aiohttp
|
|
import requests
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
|
|
|
|
|
def is_endpoint_live(url: str, headers: Optional[dict], payload: Any) -> bool:
|
|
"""
|
|
Check if an endpoint is live by sending a GET request to the specified URL.
|
|
|
|
Args:
|
|
url (str): The URL of the endpoint to check.
|
|
|
|
Returns:
|
|
bool: True if the endpoint is live (status code 200), False otherwise.
|
|
|
|
Raises:
|
|
Exception: If the endpoint returns a non-successful status code or if there is
|
|
an error querying the endpoint.
|
|
"""
|
|
try:
|
|
response = requests.request("POST", url, headers=headers, data=payload)
|
|
|
|
# Check if the status code is 200 (OK)
|
|
if response.status_code == 200:
|
|
return True
|
|
else:
|
|
# Raise an exception if the status code is not 200
|
|
raise Exception(
|
|
f"Endpoint returned a non-successful status code: "
|
|
f"{response.status_code}"
|
|
)
|
|
except requests.exceptions.RequestException as e:
|
|
# Handle any exceptions (e.g., connection errors)
|
|
raise Exception(f"Error querying the endpoint: {e}")
|
|
|
|
|
|
class NeMoEmbeddings(BaseModel, Embeddings):
|
|
batch_size: int = 16
|
|
model: str = "NV-Embed-QA-003"
|
|
api_endpoint_url: str = "http://localhost:8088/v1/embeddings"
|
|
|
|
@root_validator()
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
"""Validate that the end point is alive using the values that are provided."""
|
|
|
|
url = values["api_endpoint_url"]
|
|
model = values["model"]
|
|
|
|
# Optional: A minimal test payload and headers required by the endpoint
|
|
headers = {"Content-Type": "application/json"}
|
|
payload = json.dumps(
|
|
{"input": "Hello World", "model": model, "input_type": "query"}
|
|
)
|
|
|
|
is_endpoint_live(url, headers, payload)
|
|
|
|
return values
|
|
|
|
async def _aembedding_func(
|
|
self, session: Any, text: str, input_type: str
|
|
) -> List[float]:
|
|
"""Async call out to embedding endpoint.
|
|
|
|
Args:
|
|
text: The text to embed.
|
|
|
|
Returns:
|
|
Embeddings for the text.
|
|
"""
|
|
|
|
headers = {"Content-Type": "application/json"}
|
|
|
|
async with session.post(
|
|
self.api_endpoint_url,
|
|
json={"input": text, "model": self.model, "input_type": input_type},
|
|
headers=headers,
|
|
) as response:
|
|
response.raise_for_status()
|
|
answer = await response.text()
|
|
answer = json.loads(answer)
|
|
return answer["data"][0]["embedding"]
|
|
|
|
def _embedding_func(self, text: str, input_type: str) -> List[float]:
|
|
"""Call out to Cohere's embedding endpoint.
|
|
|
|
Args:
|
|
text: The text to embed.
|
|
|
|
Returns:
|
|
Embeddings for the text.
|
|
"""
|
|
|
|
payload = json.dumps(
|
|
{"input": text, "model": self.model, "input_type": input_type}
|
|
)
|
|
headers = {"Content-Type": "application/json"}
|
|
|
|
response = requests.request(
|
|
"POST", self.api_endpoint_url, headers=headers, data=payload
|
|
)
|
|
response_json = json.loads(response.text)
|
|
embedding = response_json["data"][0]["embedding"]
|
|
|
|
return embedding
|
|
|
|
def embed_documents(self, documents: List[str]) -> List[List[float]]:
|
|
"""Embed a list of document texts.
|
|
|
|
Args:
|
|
texts: The list of texts to embed.
|
|
|
|
Returns:
|
|
List of embeddings, one for each text.
|
|
"""
|
|
return [self._embedding_func(text, input_type="passage") for text in documents]
|
|
|
|
def embed_query(self, text: str) -> List[float]:
|
|
return self._embedding_func(text, input_type="query")
|
|
|
|
async def aembed_query(self, text: str) -> List[float]:
|
|
"""Call out to NeMo's embedding endpoint async for embedding query text.
|
|
|
|
Args:
|
|
text: The text to embed.
|
|
|
|
Returns:
|
|
Embedding for the text.
|
|
"""
|
|
|
|
async with aiohttp.ClientSession() as session:
|
|
embedding = await self._aembedding_func(session, text, "passage")
|
|
return embedding
|
|
|
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
"""Call out to NeMo's embedding endpoint async for embedding search docs.
|
|
|
|
Args:
|
|
texts: The list of texts to embed.
|
|
|
|
Returns:
|
|
List of embeddings, one for each text.
|
|
"""
|
|
embeddings = []
|
|
|
|
async with aiohttp.ClientSession() as session:
|
|
for batch in range(0, len(texts), self.batch_size):
|
|
text_batch = texts[batch : batch + self.batch_size]
|
|
|
|
for text in text_batch:
|
|
# Create tasks for all texts in the batch
|
|
tasks = [
|
|
self._aembedding_func(session, text, "passage")
|
|
for text in text_batch
|
|
]
|
|
|
|
# Run all tasks concurrently
|
|
batch_results = await asyncio.gather(*tasks)
|
|
|
|
# Extend the embeddings list with results from this batch
|
|
embeddings.extend(batch_results)
|
|
|
|
return embeddings
|