@ -7,6 +7,10 @@ from langchain.embeddings.base import Embeddings
from langchain . llms . sagemaker_endpoint import ContentHandlerBase
from langchain . llms . sagemaker_endpoint import ContentHandlerBase
class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]):
    """Content handler for embeddings endpoints.

    Specializes ``ContentHandlerBase`` with ``List[str]`` and
    ``List[List[float]]`` as its type parameters: a batch of texts on the
    input side and one vector of floats per text on the output side.
    """
class SagemakerEndpointEmbeddings ( BaseModel , Embeddings ) :
class SagemakerEndpointEmbeddings ( BaseModel , Embeddings ) :
""" Wrapper around custom Sagemaker Inference Endpoints.
""" Wrapper around custom Sagemaker Inference Endpoints.
@ -62,7 +66,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
See : https : / / boto3 . amazonaws . com / v1 / documentation / api / latest / guide / credentials . html
See : https : / / boto3 . amazonaws . com / v1 / documentation / api / latest / guide / credentials . html
"""
"""
content_handler : ContentHandlerBase
content_handler : Embeddings ContentHandler
""" The content handler class that provides an input and
""" The content handler class that provides an input and
output transform functions to handle formats between LLM
output transform functions to handle formats between LLM
and the endpoint .
and the endpoint .
@ -71,21 +75,21 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
"""
"""
Example :
Example :
. . code - block : : python
. . code - block : : python
from langchain . llms . sagemaker_endpoint import ContentHandlerBase
class ContentHandler ( ContentHandlerBase ) :
from langchain . embeddings . sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler ( EmbeddingsContentHandler ) :
content_type = " application/json "
content_type = " application/json "
accepts = " application/json "
accepts = " application/json "
def transform_input ( self , prompt : str , model_kwargs : Dict ) - > bytes :
def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
input_str = json . dumps ( { prompt : prompt , * * model_kwargs } )
input_str = json.dumps({"prompts": prompts, **model_kwargs})
return input_str . encode ( ' utf-8 ' )
return input_str . encode ( ' utf-8 ' )
def transform_output ( self , output : bytes ) - > str :
def transform_output ( self , output : bytes ) - > List [ List [ float ] ] :
response_json = json . loads ( output . read ( ) . decode ( " utf-8 " ) )
response_json = json . loads ( output . read ( ) . decode ( " utf-8 " ) )
return response_json [ 0 ] [ " generated_text " ]
return response_json [ " vectors " ]
"""
""" # noqa: E501
model_kwargs : Optional [ Dict ] = None
model_kwargs : Optional [ Dict ] = None
""" Key word arguments to pass to the model. """
""" Key word arguments to pass to the model. """
@ -135,7 +139,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
)
)
return values
return values
def _embedding_func ( self , texts : List [ str ] ) - > List [ float ] :
def _embedding_func ( self , texts : List [ str ] ) - > List [ List [ float ] ] :
""" Call out to SageMaker Inference embedding endpoint. """
""" Call out to SageMaker Inference embedding endpoint. """
# replace newlines, which can negatively affect performance.
# replace newlines, which can negatively affect performance.
texts = list ( map ( lambda x : x . replace ( " \n " , " " ) , texts ) )
texts = list ( map ( lambda x : x . replace ( " \n " , " " ) , texts ) )
@ -179,7 +183,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
_chunk_size = len ( texts ) if chunk_size > len ( texts ) else chunk_size
_chunk_size = len ( texts ) if chunk_size > len ( texts ) else chunk_size
for i in range ( 0 , len ( texts ) , _chunk_size ) :
for i in range ( 0 , len ( texts ) , _chunk_size ) :
response = self . _embedding_func ( texts [ i : i + _chunk_size ] )
response = self . _embedding_func ( texts [ i : i + _chunk_size ] )
results . app end( response )
results . ext end( response )
return results
return results
def embed_query ( self , text : str ) - > List [ float ] :
def embed_query ( self , text : str ) - > List [ float ] :
@ -191,4 +195,4 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
Returns :
Returns :
Embeddings for the text .
Embeddings for the text .
"""
"""
return self . _embedding_func ( [ text ] )
return self . _embedding_func ( [ text ] ) [ 0 ]