community[minor]: NeMo embeddings (#16206)

This PR adds support for NVIDIA NeMo embeddings (issue #16095).

---------

Co-authored-by: Praveen Nakshatrala <pnakshatrala@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
langchain-ai/langchain@5cbabbd
nvpranak 5 months ago committed by GitHub
parent 7c6009b76f
commit 91bcc9c5c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,121 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "abede47c-6a58-40c3-b7ef-10966a4fc085",
"metadata": {},
"source": [
"# NVIDIA NeMo embeddings"
]
},
{
"cell_type": "markdown",
"id": "38f3d4ce-b36a-48c6-88b0-5970c26bb146",
"metadata": {},
"source": [
"Connect to NVIDIA's embedding service using the `NeMoEmbeddings` class.\n",
"\n",
"The NeMo Retriever Embedding Microservice (NREM) brings the power of state-of-the-art text embedding to your applications, providing unmatched natural language processing and understanding capabilities. Whether you're developing semantic search, Retrieval Augmented Generation (RAG) pipelines—or any application that needs to use text embeddings—NREM has you covered. Built on the NVIDIA software platform incorporating CUDA, TensorRT, and Triton, NREM brings state of the art GPU accelerated Text Embedding model serving.\n",
"\n",
"NREM uses NVIDIA's TensorRT built on top of the Triton Inference Server for optimized inference of text embedding models."
]
},
{
"cell_type": "markdown",
"id": "f5ab6ea1-d074-4f36-ae45-50312a6a82b9",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "32deab16-530d-455c-b40c-914db048cb05",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.embeddings import NeMoEmbeddings"
]
},
{
"cell_type": "markdown",
"id": "de40023c-3391-474d-96cf-fbfb2311e9d7",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "37177018-47f4-48be-8575-83ce5c9a5447",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 16\n",
"model = \"NV-Embed-QA-003\"\n",
"api_endpoint_url = \"http://localhost:8080/v1/embeddings\""
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "08161ed2-8ba3-4226-a387-15c348f8c343",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Checking if endpoint is live: http://localhost:8080/v1/embeddings\n"
]
}
],
"source": [
"embedding_model = NeMoEmbeddings(\n",
" batch_size=batch_size, model=model, api_endpoint_url=api_endpoint_url\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c69070c3-fe2d-4ff7-be4a-73304e2c4f3e",
"metadata": {},
"outputs": [],
"source": [
"embedding_model.embed_query(\"This is a test.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d1d8852-5298-40b5-89c4-5a91ccfc95e5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -65,6 +65,7 @@ from langchain_community.embeddings.mlflow import (
from langchain_community.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
from langchain_community.embeddings.modelscope_hub import ModelScopeEmbeddings
from langchain_community.embeddings.mosaicml import MosaicMLInstructorEmbeddings
from langchain_community.embeddings.nemo import NeMoEmbeddings
from langchain_community.embeddings.nlpcloud import NLPCloudEmbeddings
from langchain_community.embeddings.oci_generative_ai import OCIGenAIEmbeddings
from langchain_community.embeddings.octoai_embeddings import OctoAIEmbeddings
@ -148,6 +149,7 @@ __all__ = [
"BookendEmbeddings",
"VolcanoEmbeddings",
"OCIGenAIEmbeddings",
"NeMoEmbeddings",
]

@ -0,0 +1,169 @@
from __future__ import annotations
import asyncio
import json
from typing import Any, Dict, List, Optional
import aiohttp
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
def is_endpoint_live(url: str, headers: Optional[dict], payload: Any) -> bool:
    """Check whether an endpoint is live by sending it a POST request.

    Args:
        url: The URL of the endpoint to check.
        headers: HTTP headers to send with the request, or None.
        payload: The request body, handed to ``requests`` as ``data``.

    Returns:
        True if the endpoint answers with status code 200. Every other
        outcome raises, so this function never returns False.

    Raises:
        Exception: If the endpoint returns a status code other than 200, or
            if the request itself fails with a
            ``requests.exceptions.RequestException`` (chained as the cause).
    """
    # Keep the try body down to the one call that can raise a RequestException.
    try:
        response = requests.request("POST", url, headers=headers, data=payload)
    except requests.exceptions.RequestException as e:
        # Chain the cause so the underlying connection error stays visible.
        raise Exception(f"Error querying the endpoint: {e}") from e

    if response.status_code != 200:
        raise Exception(
            f"Endpoint returned a non-successful status code: "
            f"{response.status_code}"
        )
    return True
class NeMoEmbeddings(BaseModel, Embeddings):
    """Embeddings client that calls an HTTP embedding endpoint.

    Each text is sent as its own JSON POST request of the form
    ``{"input": ..., "model": ..., "input_type": ...}`` and the embedding is
    read from ``response["data"][0]["embedding"]``. Queries are sent with
    ``input_type="query"`` and documents with ``input_type="passage"``.

    Constructing an instance sends one test request to the endpoint (see
    ``validate_environment``) and raises if it does not answer with HTTP 200.
    """

    # Number of texts whose requests are issued concurrently by
    # ``aembed_documents``. The synchronous path does not use it.
    batch_size: int = 16
    # Model name sent in the "model" field of every request.
    model: str = "NV-Embed-QA-003"
    # Full URL of the embeddings endpoint.
    api_endpoint_url: str = "http://localhost:8088/v1/embeddings"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the endpoint is alive using the provided values.

        Sends a minimal "Hello World" query to ``api_endpoint_url``.

        Args:
            values: The field values of the model being validated.

        Returns:
            The unchanged ``values`` dict.

        Raises:
            Exception: Propagated from ``is_endpoint_live`` when the endpoint
                cannot be reached or does not answer with status code 200.
        """
        url = values["api_endpoint_url"]
        model = values["model"]

        # A minimal test payload and the headers required by the endpoint.
        headers = {"Content-Type": "application/json"}
        payload = json.dumps(
            {"input": "Hello World", "model": model, "input_type": "query"}
        )

        is_endpoint_live(url, headers, payload)

        return values

    async def _aembedding_func(
        self, session: Any, text: str, input_type: str
    ) -> List[float]:
        """Async call out to the embedding endpoint for a single text.

        Args:
            session: Object whose ``post(...)`` is used as an async context
                manager; callers in this class pass an
                ``aiohttp.ClientSession``.
            text: The text to embed.
            input_type: Value for the request's "input_type" field
                ("query" or "passage" in this class).

        Returns:
            Embedding for the text.
        """
        headers = {"Content-Type": "application/json"}

        async with session.post(
            self.api_endpoint_url,
            json={"input": text, "model": self.model, "input_type": input_type},
            headers=headers,
        ) as response:
            response.raise_for_status()
            answer = await response.text()
            answer = json.loads(answer)
            return answer["data"][0]["embedding"]

    def _embedding_func(self, text: str, input_type: str) -> List[float]:
        """Call out to the embedding endpoint for a single text.

        Args:
            text: The text to embed.
            input_type: Value for the request's "input_type" field
                ("query" or "passage" in this class).

        Returns:
            Embedding for the text.

        Raises:
            requests.exceptions.HTTPError: If the endpoint answers with an
                HTTP error status.
        """
        payload = json.dumps(
            {"input": text, "model": self.model, "input_type": input_type}
        )
        headers = {"Content-Type": "application/json"}

        response = requests.request(
            "POST", self.api_endpoint_url, headers=headers, data=payload
        )
        # Fail on HTTP errors, as the async path does, instead of surfacing a
        # confusing JSON/KeyError from an error body.
        response.raise_for_status()
        response_json = json.loads(response.text)
        embedding = response_json["data"][0]["embedding"]

        return embedding

    def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Args:
            documents: The list of texts to embed.

        Returns:
            List of embeddings, one for each text, in input order.
        """
        return [self._embedding_func(text, input_type="passage") for text in documents]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self._embedding_func(text, input_type="query")

    async def aembed_query(self, text: str) -> List[float]:
        """Call out to NeMo's embedding endpoint async for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        async with aiohttp.ClientSession() as session:
            # Use "query" so the async path matches ``embed_query``.
            embedding = await self._aembedding_func(session, text, "query")
            return embedding

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to NeMo's embedding endpoint async for embedding search docs.

        Texts are processed in slices of ``batch_size``; the requests within a
        slice are awaited together, and slices are processed one after another.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text, in input order.
        """
        embeddings: List[List[float]] = []

        async with aiohttp.ClientSession() as session:
            for start in range(0, len(texts), self.batch_size):
                text_batch = texts[start : start + self.batch_size]

                # Build exactly one task per text in the batch.
                tasks = [
                    self._aembedding_func(session, text, "passage")
                    for text in text_batch
                ]

                # Run all tasks concurrently; gather preserves input order.
                batch_results = await asyncio.gather(*tasks)

                # Extend the embeddings list with results from this batch.
                embeddings.extend(batch_results)

        return embeddings

@ -58,6 +58,7 @@ EXPECTED_ALL = [
"BookendEmbeddings",
"VolcanoEmbeddings",
"OCIGenAIEmbeddings",
"NeMoEmbeddings",
]

Loading…
Cancel
Save