FEAT Bedrock Cohere embedding support (#13366)

- **Description:** Adds Cohere embedding support to the `BedrockEmbeddings` class; a short usage sketch is included before the diff below.
  - **Issue:** N/A
  - **Dependencies:** None
  - **Tag maintainer:** @3coins 
  - **Twitter handle:** celmore25

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Author: Clay Elmore
Date: 2023-11-15 10:19:12 -08:00 (committed by GitHub)
Commit: 8823e3831f (parent 9f543634e2)

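For context, here is a minimal usage sketch of the new Cohere path. The model ID, region, and query strings are illustrative assumptions, not part of this change; any Bedrock model ID starting with `cohere.` routes through the new branch added below.

```python
# Illustrative sketch only: model_id and region_name are assumed values.
from langchain.embeddings import BedrockEmbeddings

embeddings = BedrockEmbeddings(
    model_id="cohere.embed-english-v3",  # "cohere.*" model IDs take the new branch
    region_name="us-east-1",
)

# input_type defaults to "search_document" unless overridden via model_kwargs
query_vector = embeddings.embed_query("What is Amazon Bedrock?")
doc_vectors = embeddings.embed_documents(["Bedrock hosts Cohere embed models."])
```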

@@ -112,20 +112,36 @@ class BedrockEmbeddings(BaseModel, Embeddings):
         """Call out to Bedrock embedding endpoint."""
         # replace newlines, which can negatively affect performance.
         text = text.replace(os.linesep, " ")
-        _model_kwargs = self.model_kwargs or {}
-        input_body = {**_model_kwargs, "inputText": text}
+        # format input body for provider
+        provider = self.model_id.split(".")[0]
+        _model_kwargs = self.model_kwargs or {}
+        input_body = {**_model_kwargs}
+        if provider == "cohere":
+            if "input_type" not in input_body.keys():
+                input_body["input_type"] = "search_document"
+            input_body["texts"] = [text]
+        else:
+            # includes common provider == "amazon"
+            input_body["inputText"] = text
         body = json.dumps(input_body)
         try:
+            # invoke bedrock API
             response = self.client.invoke_model(
                 body=body,
                 modelId=self.model_id,
                 accept="application/json",
                 contentType="application/json",
             )
+            # format output based on provider
             response_body = json.loads(response.get("body").read())
-            return response_body.get("embedding")
+            if provider == "cohere":
+                return response_body.get("embeddings")[0]
+            else:
+                # includes common provider == "amazon"
+                return response_body.get("embedding")
         except Exception as e:
             raise ValueError(f"Error raised by inference endpoint: {e}")
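
To see the provider branching from the diff in isolation, here is a standalone sketch that mirrors the request-body logic; the helper name and the example model IDs are illustrative, not part of the library:

```python
import json
from typing import Optional


def build_request_body(model_id: str, text: str, model_kwargs: Optional[dict] = None) -> str:
    """Mirror of the branching above: Cohere models expect a `texts` list plus an
    `input_type`, while other providers (e.g. Amazon Titan) expect a single `inputText`."""
    provider = model_id.split(".")[0]
    body = dict(model_kwargs or {})
    if provider == "cohere":
        body.setdefault("input_type", "search_document")
        body["texts"] = [text]
    else:
        body["inputText"] = text
    return json.dumps(body)


# Cohere-style payload: {"input_type": "search_document", "texts": ["hello world"]}
print(build_request_body("cohere.embed-english-v3", "hello world"))
# Titan-style payload: {"inputText": "hello world"}
print(build_request_body("amazon.titan-embed-text-v1", "hello world"))
```

On the response side, the Cohere models return a batched `embeddings` list (hence the `[0]` for the single input), while the Amazon provider returns a single `embedding` field, as handled in the diff above.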