mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
FEAT Bedrock cohere embedding support (#13366)
- **Description:** adding cohere embedding support to bedrock embedding class - **Issue:** N/A - **Dependencies:** None - **Tag maintainer:** @3coins - **Twitter handle:** celmore25 --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
9f543634e2
commit
8823e3831f
@ -112,20 +112,36 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
|||||||
"""Call out to Bedrock embedding endpoint."""
|
"""Call out to Bedrock embedding endpoint."""
|
||||||
# replace newlines, which can negatively affect performance.
|
# replace newlines, which can negatively affect performance.
|
||||||
text = text.replace(os.linesep, " ")
|
text = text.replace(os.linesep, " ")
|
||||||
_model_kwargs = self.model_kwargs or {}
|
|
||||||
|
|
||||||
input_body = {**_model_kwargs, "inputText": text}
|
# format input body for provider
|
||||||
|
provider = self.model_id.split(".")[0]
|
||||||
|
_model_kwargs = self.model_kwargs or {}
|
||||||
|
input_body = {**_model_kwargs}
|
||||||
|
if provider == "cohere":
|
||||||
|
if "input_type" not in input_body.keys():
|
||||||
|
input_body["input_type"] = "search_document"
|
||||||
|
input_body["texts"] = [text]
|
||||||
|
else:
|
||||||
|
# includes common provider == "amazon"
|
||||||
|
input_body["inputText"] = text
|
||||||
body = json.dumps(input_body)
|
body = json.dumps(input_body)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# invoke bedrock API
|
||||||
response = self.client.invoke_model(
|
response = self.client.invoke_model(
|
||||||
body=body,
|
body=body,
|
||||||
modelId=self.model_id,
|
modelId=self.model_id,
|
||||||
accept="application/json",
|
accept="application/json",
|
||||||
contentType="application/json",
|
contentType="application/json",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# format output based on provider
|
||||||
response_body = json.loads(response.get("body").read())
|
response_body = json.loads(response.get("body").read())
|
||||||
return response_body.get("embedding")
|
if provider == "cohere":
|
||||||
|
return response_body.get("embeddings")[0]
|
||||||
|
else:
|
||||||
|
# includes common provider == "amazon"
|
||||||
|
return response_body.get("embedding")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user