@@ -7,6 +7,10 @@ from langchain.embeddings.base import Embeddings
from langchain . llms . sagemaker_endpoint import ContentHandlerBase
class EmbeddingsContentHandler ( ContentHandlerBase [ List [ str ] , List [ List [ float ] ] ] ) :
""" Content handler for LLM class. """
class SagemakerEndpointEmbeddings ( BaseModel , Embeddings ) :
""" Wrapper around custom Sagemaker Inference Endpoints.
@@ -62,7 +66,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
See : https : / / boto3 . amazonaws . com / v1 / documentation / api / latest / guide / credentials . html
"""
content_handler : ContentHandlerBase
content_handler : EmbeddingsContentHandler
""" The content handler class that provides an input and
output transform functions to handle formats between LLM
and the endpoint .
@@ -71,21 +75,21 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
"""
Example :
. . code - block : : python
from langchain . llms . sagemaker_endpoint import ContentHandlerBase
class ContentHandler ( ContentHandlerBase ) :
from langchain . embeddings . sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler ( EmbeddingsContentHandler ) :
content_type = " application/json "
accepts = " application/json "
def transform_input ( self , prompt : str , model_kwargs : Dict ) - > bytes :
input_str = json . dumps ( { prompt : prompt , * * model_kwargs } )
def transform_input ( self , prompts : List [ str ] , model_kwargs : Dict ) - > bytes :
input_str = json . dumps ( { prompts : prompts , * * model_kwargs } )
return input_str . encode ( ' utf-8 ' )
def transform_output ( self , output : bytes ) - > str :
def transform_output ( self , output : bytes ) - > List [ List [ float ] ] :
response_json = json . loads ( output . read ( ) . decode ( " utf-8 " ) )
return response_json [ 0 ] [ " generated_text " ]
"""
return response_json [ " vectors " ]
""" # noqa: E501
model_kwargs : Optional [ Dict ] = None
""" Key word arguments to pass to the model. """
@@ -135,7 +139,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
)
return values
def _embedding_func ( self , texts : List [ str ] ) - > List [ float ] :
def _embedding_func ( self , texts : List [ str ] ) - > List [ List [ float ] ] :
""" Call out to SageMaker Inference embedding endpoint. """
# replace newlines, which can negatively affect performance.
texts = list ( map ( lambda x : x . replace ( " \n " , " " ) , texts ) )
@@ -179,7 +183,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
_chunk_size = len ( texts ) if chunk_size > len ( texts ) else chunk_size
for i in range ( 0 , len ( texts ) , _chunk_size ) :
response = self . _embedding_func ( texts [ i : i + _chunk_size ] )
results . append ( response )
results . extend ( response )
return results
def embed_query ( self , text : str ) - > List [ float ] :
@@ -191,4 +195,4 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
Returns :
Embeddings for the text .
"""
return self . _embedding_func ( [ text ] )
return self . _embedding_func ( [ text ] ) [ 0 ]