# langchain/libs/community/langchain_community/embeddings/titan_takeoff.py

from enum import Enum
from typing import Any, List, Optional, Set, Union

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel


class TakeoffEmbeddingException(Exception):
    """Exceptions raised while interfacing with the Takeoff embedding wrapper."""


class MissingConsumerGroup(TakeoffEmbeddingException):
    """Exception raised when no consumer group is provided on initialization of
    TitanTakeoffEmbed or in an embed request."""


class Device(str, Enum):
    """The device to use for inference, cuda or cpu"""

    cuda = "cuda"
    cpu = "cpu"


class ReaderConfig(BaseModel):
    """Configuration for a reader to be deployed in Takeoff."""

    class Config:
        protected_namespaces = ()

    model_name: str
    """The name of the model to use"""

    device: Device = Device.cuda
    """The device to use for inference, cuda or cpu"""

    consumer_group: str = "primary"
    """The consumer group to place the reader into"""


class TitanTakeoffEmbed(Embeddings):
    """Titan Takeoff Embed is a wrapper to interface with the Takeoff Inference
    API for embedding models.

    You can use this wrapper to send embedding requests and to deploy embedding
    readers with Takeoff.

    Examples:
        This is an example of how to deploy an embedding model and send requests.

        .. code-block:: python

            # Import the TitanTakeoffEmbed class from community package
            import time
            from langchain_community.embeddings import TitanTakeoffEmbed

            # Specify the embedding reader you'd like to deploy
            reader_1 = {
                "model_name": "avsolatorio/GIST-large-Embedding-v0",
                "device": "cpu",
                "consumer_group": "embed",
            }
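
            # Alternatively, readers can be passed as ReaderConfig instances
            # rather than plain dicts (a sketch; ReaderConfig is defined in this
            # module, and the __init__ below accepts either form):
            #
            # from langchain_community.embeddings.titan_takeoff import ReaderConfig
            # reader_1 = ReaderConfig(
            #     model_name="avsolatorio/GIST-large-Embedding-v0",
            #     device="cpu",
            #     consumer_group="embed",
            # )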

            # For every reader you pass into the models arg, Takeoff will spin up
            # a reader according to the specs you provide. If you don't specify
            # the arg, no models are spun up and it is assumed you have already
            # done this separately.
            embed = TitanTakeoffEmbed(models=[reader_1])

            # Wait for the reader to be deployed; the time needed depends on the
            # model size and your internet speed
            time.sleep(60)

            # Returns the embedded query, i.e. a List[float], sent to the `embed`
            # consumer group where we just spun up the embedding reader
            print(embed.embed_query(
                "Where can I see football?", consumer_group="embed"
            ))

            # Returns a list of embeddings, i.e. a List[List[float]], sent to the
            # `embed` consumer group where we just spun up the embedding reader
            print(embed.embed_documents(
                ["Document1", "Document2"],
                consumer_group="embed"
            ))
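
            # Since only one embedding consumer group was configured above, it
            # can be inferred, so `consumer_group` may also be omitted (a sketch
            # based on the inference logic in _embed below):
            print(embed.embed_query("Where can I see football?"))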
"""

    base_url: str = "http://localhost"
    """The base URL of the Titan Takeoff (Pro) server. Default = "http://localhost"."""

    port: int = 3000
    """The port of the Titan Takeoff (Pro) server. Default = 3000."""

    mgmt_port: int = 3001
    """The management port of the Titan Takeoff (Pro) server. Default = 3001."""

    client: Any = None
    """Takeoff Client Python SDK used to interact with the Takeoff API"""

    embed_consumer_groups: Set[str] = set()
    """The consumer groups in Takeoff which contain embedding models"""

    def __init__(
        self,
        base_url: str = "http://localhost",
        port: int = 3000,
        mgmt_port: int = 3001,
        models: List[ReaderConfig] = [],
    ):
        """Initialize the Titan Takeoff embedding wrapper.

        Args:
            base_url (str, optional): The base URL where the Takeoff Inference
                Server is listening. Defaults to "http://localhost".
            port (int, optional): What port the Takeoff Inference API is
                listening on. Defaults to 3000.
            mgmt_port (int, optional): What port the Takeoff Management API is
                listening on. Defaults to 3001.
            models (List[ReaderConfig], optional): Any readers you'd like to
                spin up. Defaults to [].

        Raises:
            ImportError: If you haven't installed takeoff-client, you will get an
                ImportError. To remedy, run `pip install 'takeoff-client==0.4.0'`.
        """
        self.base_url = base_url
        self.port = port
        self.mgmt_port = mgmt_port
        # Use a per-instance set so consumer groups are not shared across
        # instances through the mutable class attribute.
        self.embed_consumer_groups = set()
        try:
            from takeoff_client import TakeoffClient
        except ImportError:
            raise ImportError(
                "takeoff-client is required for TitanTakeoff. "
                "Please install it with `pip install 'takeoff-client==0.4.0'`."
            )
        self.client = TakeoffClient(
            self.base_url, port=self.port, mgmt_port=self.mgmt_port
        )
        for model in models:
            self.client.create_reader(model)
            if isinstance(model, dict):
                self.embed_consumer_groups.add(model.get("consumer_group"))
            else:
                self.embed_consumer_groups.add(model.consumer_group)
        super(TitanTakeoffEmbed, self).__init__()

    def _embed(
        self, input: Union[List[str], str], consumer_group: Optional[str]
    ) -> dict:
        """Embed text.

        Args:
            input (Union[List[str], str]): Prompt/document or list of
                prompts/documents to embed.
            consumer_group (Optional[str]): Which consumer group to send the
                embedding request to. If not specified and there is only one
                consumer group specified during initialization, it will be used.
                If there are multiple consumer groups specified during
                initialization, you must specify which one to use.

        Raises:
            MissingConsumerGroup: The consumer group cannot be inferred from the
                initialization and must be specified with the request.

        Returns:
            Dict[str, Any]: Result of the query, {"result": List[List[float]]} or
            {"result": List[float]}
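
        Example:
            A minimal sketch of the returned payload shape, assumed from the
            return annotation above:

            .. code-block:: python

                out = embed._embed("hello", consumer_group="embed")
                # out == {"result": [0.1, 0.2, ...]} for a single string input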
"""
        if not consumer_group:
            if len(self.embed_consumer_groups) == 1:
                consumer_group = list(self.embed_consumer_groups)[0]
            elif len(self.embed_consumer_groups) > 1:
                raise MissingConsumerGroup(
                    "TakeoffEmbedding was initialized with multiple embedding "
                    "reader groups, you must specify which one to use."
                )
            else:
                raise MissingConsumerGroup(
                    "You must specify what consumer group you want to send the "
                    "embedding request to, as TitanTakeoffEmbed was not "
                    "initialized with an embedding reader."
                )
        return self.client.embed(input, consumer_group)

    def embed_documents(
        self, texts: List[str], consumer_group: Optional[str] = None
    ) -> List[List[float]]:
        """Embed documents.

        Args:
            texts (List[str]): List of prompts/documents to embed.
            consumer_group (Optional[str], optional): Consumer group, containing
                the embedding model, to send the request to. Defaults to None.

        Returns:
            List[List[float]]: List of embeddings
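
        Example:
            A minimal sketch, assuming a reader was deployed into the `embed`
            consumer group as in the class-level example:

            .. code-block:: python

                vectors = embed.embed_documents(
                    ["Document1", "Document2"], consumer_group="embed"
                )
                assert len(vectors) == 2  # one embedding per input document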
"""
return self._embed(texts, consumer_group)["result"]

    def embed_query(
        self, text: str, consumer_group: Optional[str] = None
    ) -> List[float]:
        """Embed query.

        Args:
            text (str): Prompt/document to embed.
            consumer_group (Optional[str], optional): Consumer group, containing
                the embedding model, to send the request to. Defaults to None.

        Returns:
            List[float]: Embedding
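
        Example:
            A minimal sketch, assuming the same `embed` consumer group as in the
            class-level example:

            .. code-block:: python

                vector = embed.embed_query(
                    "Where can I see football?", consumer_group="embed"
                )
                assert isinstance(vector[0], float)  # a single flat embedding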
"""
return self._embed(text, consumer_group)["result"]