From b574507c51c0a2183ed4cde41efa5b2e8c0d98f7 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 1 Aug 2023 17:12:10 -0700 Subject: [PATCH] normalized openai embeddings embed_query (#8604) we weren't normalizing when embedding queries --- libs/langchain/langchain/embeddings/openai.py | 43 ++----------------- .../embeddings/test_openai.py | 10 +++++ 2 files changed, 13 insertions(+), 40 deletions(-) diff --git a/libs/langchain/langchain/embeddings/openai.py b/libs/langchain/langchain/embeddings/openai.py index 2234975f0a..383c8f4649 100644 --- a/libs/langchain/langchain/embeddings/openai.py +++ b/libs/langchain/langchain/embeddings/openai.py @@ -454,42 +454,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings): return embeddings - def _embedding_func(self, text: str, *, engine: str) -> List[float]: - """Call out to OpenAI's embedding endpoint.""" - # handle large input text - if len(text) > self.embedding_ctx_length: - return self._get_len_safe_embeddings([text], engine=engine)[0] - else: - if self.model.endswith("001"): - # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 - # replace newlines, which can negatively affect performance. - text = text.replace("\n", " ") - return embed_with_retry( - self, - input=[text], - **self._invocation_params, - )[ - "data" - ][0]["embedding"] - - async def _aembedding_func(self, text: str, *, engine: str) -> List[float]: - """Call out to OpenAI's embedding endpoint.""" - # handle large input text - if len(text) > self.embedding_ctx_length: - return (await self._aget_len_safe_embeddings([text], engine=engine))[0] - else: - if self.model.endswith("001"): - # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 - # replace newlines, which can negatively affect performance. - text = text.replace("\n", " ") - return ( - await async_embed_with_retry( - self, - input=[text], - **self._invocation_params, - ) - )["data"][0]["embedding"] - def embed_documents( self, texts: List[str], chunk_size: Optional[int] = 0 ) -> List[List[float]]: @@ -533,8 +497,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): Returns: Embedding for the text. """ - embedding = self._embedding_func(text, engine=self.deployment) - return embedding + return self.embed_documents([text])[0] async def aembed_query(self, text: str) -> List[float]: """Call out to OpenAI's embedding endpoint async for embedding query text. @@ -545,5 +508,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings): Returns: Embedding for the text. """ - embedding = await self._aembedding_func(text, engine=self.deployment) - return embedding + embeddings = await self.aembed_documents([text]) + return embeddings[0] diff --git a/libs/langchain/tests/integration_tests/embeddings/test_openai.py b/libs/langchain/tests/integration_tests/embeddings/test_openai.py index 6033d3f114..2d117a1fc3 100644 --- a/libs/langchain/tests/integration_tests/embeddings/test_openai.py +++ b/libs/langchain/tests/integration_tests/embeddings/test_openai.py @@ -69,3 +69,13 @@ def test_openai_embedding_with_empty_string() -> None: ][0]["embedding"] assert np.allclose(output[0], expected_output) assert len(output[1]) == 1536 + + +def test_embed_documents_normalized() -> None: + output = OpenAIEmbeddings().embed_documents(["foo walked to the market"]) + assert np.isclose(np.linalg.norm(output[0]), 1.0) + + +def test_embed_query_normalized() -> None: + output = OpenAIEmbeddings().embed_query("foo walked to the market") + assert np.isclose(np.linalg.norm(output), 1.0)