mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
refactor BedrockEmbeddings class (#7266)
#### Description refactor BedrockEmbeddings class to clean code as below: 1. inline content type and accept 2. rewrite input_body as a dictionary literal 3. no need to declare embeddings variable, so remove it
This commit is contained in:
parent
c7cf11b8ab
commit
8c371e12eb
@ -107,27 +107,21 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
|||||||
text = text.replace(os.linesep, " ")
|
text = text.replace(os.linesep, " ")
|
||||||
_model_kwargs = self.model_kwargs or {}
|
_model_kwargs = self.model_kwargs or {}
|
||||||
|
|
||||||
input_body = {**_model_kwargs}
|
input_body = {**_model_kwargs, "inputText": text}
|
||||||
input_body["inputText"] = text
|
|
||||||
body = json.dumps(input_body)
|
body = json.dumps(input_body)
|
||||||
content_type = "application/json"
|
|
||||||
accepts = "application/json"
|
|
||||||
|
|
||||||
embeddings = []
|
|
||||||
try:
|
try:
|
||||||
response = self.client.invoke_model(
|
response = self.client.invoke_model(
|
||||||
body=body,
|
body=body,
|
||||||
modelId=self.model_id,
|
modelId=self.model_id,
|
||||||
accept=accepts,
|
accept="application/json",
|
||||||
contentType=content_type,
|
contentType="application/json",
|
||||||
)
|
)
|
||||||
response_body = json.loads(response.get("body").read())
|
response_body = json.loads(response.get("body").read())
|
||||||
embeddings = response_body.get("embedding")
|
return response_body.get("embedding")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def embed_documents(
|
def embed_documents(
|
||||||
self, texts: List[str], chunk_size: int = 1
|
self, texts: List[str], chunk_size: int = 1
|
||||||
) -> List[List[float]]:
|
) -> List[List[float]]:
|
||||||
|
Loading…
Reference in New Issue
Block a user