langchain/libs/community/langchain_community/embeddings/nemo.py
nvpranak 91bcc9c5c9
community[minor]: Nemo embeddings(#16206)
This PR is adding support for NVIDIA NeMo embeddings issue #16095.

---------

Co-authored-by: Praveen Nakshatrala <pnakshatrala@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2024-02-14 13:25:42 -08:00

170 lines
5.3 KiB
Python

from __future__ import annotations
import asyncio
import json
from typing import Any, Dict, List, Optional
import aiohttp
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
def is_endpoint_live(url: str, headers: Optional[dict], payload: Any) -> bool:
"""
Check if an endpoint is live by sending a GET request to the specified URL.
Args:
url (str): The URL of the endpoint to check.
Returns:
bool: True if the endpoint is live (status code 200), False otherwise.
Raises:
Exception: If the endpoint returns a non-successful status code or if there is
an error querying the endpoint.
"""
try:
response = requests.request("POST", url, headers=headers, data=payload)
# Check if the status code is 200 (OK)
if response.status_code == 200:
return True
else:
# Raise an exception if the status code is not 200
raise Exception(
f"Endpoint returned a non-successful status code: "
f"{response.status_code}"
)
except requests.exceptions.RequestException as e:
# Handle any exceptions (e.g., connection errors)
raise Exception(f"Error querying the endpoint: {e}")
class NeMoEmbeddings(BaseModel, Embeddings):
batch_size: int = 16
model: str = "NV-Embed-QA-003"
api_endpoint_url: str = "http://localhost:8088/v1/embeddings"
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the end point is alive using the values that are provided."""
url = values["api_endpoint_url"]
model = values["model"]
# Optional: A minimal test payload and headers required by the endpoint
headers = {"Content-Type": "application/json"}
payload = json.dumps(
{"input": "Hello World", "model": model, "input_type": "query"}
)
is_endpoint_live(url, headers, payload)
return values
async def _aembedding_func(
self, session: Any, text: str, input_type: str
) -> List[float]:
"""Async call out to embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
headers = {"Content-Type": "application/json"}
async with session.post(
self.api_endpoint_url,
json={"input": text, "model": self.model, "input_type": input_type},
headers=headers,
) as response:
response.raise_for_status()
answer = await response.text()
answer = json.loads(answer)
return answer["data"][0]["embedding"]
def _embedding_func(self, text: str, input_type: str) -> List[float]:
"""Call out to Cohere's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
payload = json.dumps(
{"input": text, "model": self.model, "input_type": input_type}
)
headers = {"Content-Type": "application/json"}
response = requests.request(
"POST", self.api_endpoint_url, headers=headers, data=payload
)
response_json = json.loads(response.text)
embedding = response_json["data"][0]["embedding"]
return embedding
def embed_documents(self, documents: List[str]) -> List[List[float]]:
"""Embed a list of document texts.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
return [self._embedding_func(text, input_type="passage") for text in documents]
def embed_query(self, text: str) -> List[float]:
return self._embedding_func(text, input_type="query")
async def aembed_query(self, text: str) -> List[float]:
"""Call out to NeMo's embedding endpoint async for embedding query text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
async with aiohttp.ClientSession() as session:
embedding = await self._aembedding_func(session, text, "passage")
return embedding
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to NeMo's embedding endpoint async for embedding search docs.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
async with aiohttp.ClientSession() as session:
for batch in range(0, len(texts), self.batch_size):
text_batch = texts[batch : batch + self.batch_size]
for text in text_batch:
# Create tasks for all texts in the batch
tasks = [
self._aembedding_func(session, text, "passage")
for text in text_batch
]
# Run all tasks concurrently
batch_results = await asyncio.gather(*tasks)
# Extend the embeddings list with results from this batch
embeddings.extend(batch_results)
return embeddings