From 41a433fa33b271580dcac1a5c0b8669dc9aaa0ec Mon Sep 17 00:00:00 2001
From: Yujie Qian
Date: Thu, 16 Nov 2023 16:35:36 -0800
Subject: [PATCH] IMPROVEMENT: add input_type to VoyageEmbeddings (#13488)

- **Description:** add input_type to VoyageEmbeddings
---
 .../langchain/embeddings/voyageai.py | 29 +++++++++++++++----
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/libs/langchain/langchain/embeddings/voyageai.py b/libs/langchain/langchain/embeddings/voyageai.py
index 9154d30714..e02f3de6c1 100644
--- a/libs/langchain/langchain/embeddings/voyageai.py
+++ b/libs/langchain/langchain/embeddings/voyageai.py
@@ -101,17 +101,21 @@ class VoyageEmbeddings(BaseModel, Embeddings):
         )
         return values
 
-    def _invocation_params(self, input: List[str]) -> Dict:
+    def _invocation_params(
+        self, input: List[str], input_type: Optional[str] = None
+    ) -> Dict:
         api_key = cast(SecretStr, self.voyage_api_key).get_secret_value()
         params = {
             "url": self.voyage_api_base,
             "headers": {"Authorization": f"Bearer {api_key}"},
-            "json": {"model": self.model, "input": input},
+            "json": {"model": self.model, "input": input, "input_type": input_type},
             "timeout": self.request_timeout,
         }
         return params
 
-    def _get_embeddings(self, texts: List[str], batch_size: int) -> List[List[float]]:
+    def _get_embeddings(
+        self, texts: List[str], batch_size: int, input_type: Optional[str] = None
+    ) -> List[List[float]]:
         embeddings: List[List[float]] = []
 
         if self.show_progress_bar:
@@ -127,9 +131,18 @@ class VoyageEmbeddings(BaseModel, Embeddings):
         else:
             _iter = range(0, len(texts), batch_size)
 
+        if input_type and input_type not in ["query", "document"]:
+            raise ValueError(
+                f"input_type {input_type} is invalid. Options: None, 'query', "
+                "'document'."
+            )
+
         for i in _iter:
             response = embed_with_retry(
-                self, **self._invocation_params(input=texts[i : i + batch_size])
+                self,
+                **self._invocation_params(
+                    input=texts[i : i + batch_size], input_type=input_type
+                ),
             )
             embeddings.extend(r["embedding"] for r in response["data"])
 
@@ -144,7 +157,9 @@ class VoyageEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        return self._get_embeddings(texts, batch_size=self.batch_size)
+        return self._get_embeddings(
+            texts, batch_size=self.batch_size, input_type="document"
+        )
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to Voyage Embedding endpoint for embedding query text.
@@ -155,4 +170,6 @@ class VoyageEmbeddings(BaseModel, Embeddings):
         Returns:
             Embedding for the text.
         """
-        return self.embed_documents([text])[0]
+        return self._get_embeddings(
+            [text], batch_size=self.batch_size, input_type="query"
+        )[0]