mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
normalized openai embeddings embed_query (#8604)
we weren't normalizing when embedding queries
This commit is contained in:
parent
31820a31e4
commit
b574507c51
@ -454,42 +454,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return embeddings
|
||||
|
||||
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
|
||||
"""Call out to OpenAI's embedding endpoint."""
|
||||
# handle large input text
|
||||
if len(text) > self.embedding_ctx_length:
|
||||
return self._get_len_safe_embeddings([text], engine=engine)[0]
|
||||
else:
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
# replace newlines, which can negatively affect performance.
|
||||
text = text.replace("\n", " ")
|
||||
return embed_with_retry(
|
||||
self,
|
||||
input=[text],
|
||||
**self._invocation_params,
|
||||
)[
|
||||
"data"
|
||||
][0]["embedding"]
|
||||
|
||||
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
|
||||
"""Call out to OpenAI's embedding endpoint."""
|
||||
# handle large input text
|
||||
if len(text) > self.embedding_ctx_length:
|
||||
return (await self._aget_len_safe_embeddings([text], engine=engine))[0]
|
||||
else:
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
# replace newlines, which can negatively affect performance.
|
||||
text = text.replace("\n", " ")
|
||||
return (
|
||||
await async_embed_with_retry(
|
||||
self,
|
||||
input=[text],
|
||||
**self._invocation_params,
|
||||
)
|
||||
)["data"][0]["embedding"]
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
) -> List[List[float]]:
|
||||
@ -533,8 +497,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
embedding = self._embedding_func(text, engine=self.deployment)
|
||||
return embedding
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Call out to OpenAI's embedding endpoint async for embedding query text.
|
||||
@ -545,5 +508,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
embedding = await self._aembedding_func(text, engine=self.deployment)
|
||||
return embedding
|
||||
embeddings = await self.aembed_documents([text])
|
||||
return embeddings[0]
|
||||
|
@ -69,3 +69,13 @@ def test_openai_embedding_with_empty_string() -> None:
|
||||
][0]["embedding"]
|
||||
assert np.allclose(output[0], expected_output)
|
||||
assert len(output[1]) == 1536
|
||||
|
||||
|
||||
def test_embed_documents_normalized() -> None:
|
||||
output = OpenAIEmbeddings().embed_documents(["foo walked to the market"])
|
||||
assert np.isclose(np.linalg.norm(output[0]), 1.0)
|
||||
|
||||
|
||||
def test_embed_query_normalized() -> None:
|
||||
output = OpenAIEmbeddings().embed_query("foo walked to the market")
|
||||
assert np.isclose(np.linalg.norm(output), 1.0)
|
||||
|
Loading…
Reference in New Issue
Block a user