normalized openai embeddings embed_query (#8604)

we weren't normalizing when embedding queries
pull/8607/head
Bagatur 1 year ago committed by GitHub
parent 31820a31e4
commit b574507c51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -454,42 +454,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return embeddings
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint."""
# handle large input text
if len(text) > self.embedding_ctx_length:
return self._get_len_safe_embeddings([text], engine=engine)[0]
else:
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return embed_with_retry(
self,
input=[text],
**self._invocation_params,
)[
"data"
][0]["embedding"]
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint."""
# handle large input text
if len(text) > self.embedding_ctx_length:
return (await self._aget_len_safe_embeddings([text], engine=engine))[0]
else:
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (
await async_embed_with_retry(
self,
input=[text],
**self._invocation_params,
)
)["data"][0]["embedding"]
def embed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
@ -533,8 +497,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
Embedding for the text.
"""
embedding = self._embedding_func(text, engine=self.deployment)
return embedding
return self.embed_documents([text])[0]
async def aembed_query(self, text: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint async for embedding query text.
@ -545,5 +508,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
Embedding for the text.
"""
embedding = await self._aembedding_func(text, engine=self.deployment)
return embedding
embeddings = await self.aembed_documents([text])
return embeddings[0]

@ -69,3 +69,13 @@ def test_openai_embedding_with_empty_string() -> None:
][0]["embedding"]
assert np.allclose(output[0], expected_output)
assert len(output[1]) == 1536
def test_embed_documents_normalized() -> None:
output = OpenAIEmbeddings().embed_documents(["foo walked to the market"])
assert np.isclose(np.linalg.norm(output[0]), 1.0)
def test_embed_query_normalized() -> None:
output = OpenAIEmbeddings().embed_query("foo walked to the market")
assert np.isclose(np.linalg.norm(output), 1.0)

Loading…
Cancel
Save