diff --git a/langchain/embeddings/bedrock.py b/langchain/embeddings/bedrock.py index 56089ed659..35a7eab5cc 100644 --- a/langchain/embeddings/bedrock.py +++ b/langchain/embeddings/bedrock.py @@ -107,27 +107,21 @@ class BedrockEmbeddings(BaseModel, Embeddings): text = text.replace(os.linesep, " ") _model_kwargs = self.model_kwargs or {} - input_body = {**_model_kwargs} - input_body["inputText"] = text + input_body = {**_model_kwargs, "inputText": text} body = json.dumps(input_body) - content_type = "application/json" - accepts = "application/json" - embeddings = [] try: response = self.client.invoke_model( body=body, modelId=self.model_id, - accept=accepts, - contentType=content_type, + accept="application/json", + contentType="application/json", ) response_body = json.loads(response.get("body").read()) - embeddings = response_body.get("embedding") + return response_body.get("embedding") except Exception as e: raise ValueError(f"Error raised by inference endpoint: {e}") - return embeddings - def embed_documents( self, texts: List[str], chunk_size: int = 1 ) -> List[List[float]]: