mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
normalized openai embeddings embed_query (#8604)
We weren't normalizing embeddings when embedding queries, so query vectors were not unit-length even though document vectors were.
This commit is contained in:
parent
31820a31e4
commit
b574507c51
@ -454,42 +454,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
    """Call out to OpenAI's embedding endpoint for a single text.

    Inputs longer than the configured context length are delegated to
    the chunked, length-safe path; everything else is embedded with a
    single retried API call.
    """
    # handle large input text: route through the length-safe helper
    if len(text) > self.embedding_ctx_length:
        return self._get_len_safe_embeddings([text], engine=engine)[0]
    if self.model.endswith("001"):
        # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
        # replace newlines, which can negatively affect performance.
        text = text.replace("\n", " ")
    response = embed_with_retry(
        self,
        input=[text],
        **self._invocation_params,
    )
    return response["data"][0]["embedding"]
|
||||||
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
    """Call out to OpenAI's embedding endpoint (async) for a single text.

    Async counterpart of ``_embedding_func``: oversized inputs go through
    the chunked, length-safe path, short inputs through one retried call.
    """
    # handle large input text: route through the length-safe helper
    if len(text) > self.embedding_ctx_length:
        safe = await self._aget_len_safe_embeddings([text], engine=engine)
        return safe[0]
    if self.model.endswith("001"):
        # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
        # replace newlines, which can negatively affect performance.
        text = text.replace("\n", " ")
    response = await async_embed_with_retry(
        self,
        input=[text],
        **self._invocation_params,
    )
    return response["data"][0]["embedding"]
|
|
||||||
def embed_documents(
|
def embed_documents(
|
||||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||||
) -> List[List[float]]:
|
) -> List[List[float]]:
|
||||||
@ -533,8 +497,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
Embedding for the text.
|
Embedding for the text.
|
||||||
"""
|
"""
|
||||||
embedding = self._embedding_func(text, engine=self.deployment)
|
return self.embed_documents([text])[0]
|
||||||
return embedding
|
|
||||||
|
|
||||||
async def aembed_query(self, text: str) -> List[float]:
    """Call out to OpenAI's embedding endpoint async for embedding query text.

    Args:
        text: The text to embed.

    Returns:
        Embedding for the text.
    """
    # Delegate to aembed_documents so that query embeddings receive the
    # same post-processing (normalization) as document embeddings —
    # embedding via _aembedding_func directly skipped that step.
    embeddings = await self.aembed_documents([text])
    return embeddings[0]
||||||
|
@ -69,3 +69,13 @@ def test_openai_embedding_with_empty_string() -> None:
|
|||||||
][0]["embedding"]
|
][0]["embedding"]
|
||||||
assert np.allclose(output[0], expected_output)
|
assert np.allclose(output[0], expected_output)
|
||||||
assert len(output[1]) == 1536
|
assert len(output[1]) == 1536
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_documents_normalized() -> None:
    """Document embeddings should be unit-length (L2 norm of 1)."""
    vectors = OpenAIEmbeddings().embed_documents(["foo walked to the market"])
    norm = np.linalg.norm(vectors[0])
    assert np.isclose(norm, 1.0)
||||||
|
|
||||||
|
def test_embed_query_normalized() -> None:
    """Query embeddings should be unit-length (L2 norm of 1)."""
    vector = OpenAIEmbeddings().embed_query("foo walked to the market")
    norm = np.linalg.norm(vector)
    assert np.isclose(norm, 1.0)
||||||
|
Loading…
Reference in New Issue
Block a user