mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
openai embeddings (#3488)
This commit is contained in:
parent
d3ce47414d
commit
eda69b13f3
@ -179,14 +179,20 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
)
|
)
|
||||||
batched_embeddings += [r["embedding"] for r in response["data"]]
|
batched_embeddings += [r["embedding"] for r in response["data"]]
|
||||||
|
|
||||||
results: List[List[List[float]]] = [[] for i in range(len(texts))]
|
results: List[List[List[float]]] = [[] for _ in range(len(texts))]
|
||||||
lens: List[List[int]] = [[] for i in range(len(texts))]
|
lens: List[List[int]] = [[] for _ in range(len(texts))]
|
||||||
for i in range(len(indices)):
|
for i in range(len(indices)):
|
||||||
results[indices[i]].append(batched_embeddings[i])
|
results[indices[i]].append(batched_embeddings[i])
|
||||||
lens[indices[i]].append(len(batched_embeddings[i]))
|
lens[indices[i]].append(len(batched_embeddings[i]))
|
||||||
|
|
||||||
for i in range(len(texts)):
|
for i in range(len(texts)):
|
||||||
average = np.average(results[i], axis=0, weights=lens[i])
|
_result = results[i]
|
||||||
|
if len(_result) == 0:
|
||||||
|
average = embed_with_retry(self, input="", engine=self.deployment)[
|
||||||
|
"data"
|
||||||
|
][0]["embedding"]
|
||||||
|
else:
|
||||||
|
average = np.average(_result, axis=0, weights=lens[i])
|
||||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
"""Test openai embeddings."""
|
"""Test openai embeddings."""
|
||||||
|
import numpy as np
|
||||||
|
import openai
|
||||||
|
|
||||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
|
||||||
@ -29,3 +32,17 @@ def test_openai_embedding_query() -> None:
|
|||||||
embedding = OpenAIEmbeddings()
|
embedding = OpenAIEmbeddings()
|
||||||
output = embedding.embed_query(document)
|
output = embedding.embed_query(document)
|
||||||
assert len(output) == 1536
|
assert len(output) == 1536
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_embedding_with_empty_string() -> None:
    """Embed a batch containing an empty string and compare with the raw API.

    The vector returned for the empty document must match the embedding that
    the OpenAI endpoint returns when called directly with ``""``.
    """
    texts = ["", "abc"]
    embedder = OpenAIEmbeddings()
    vectors = embedder.embed_documents(texts)

    # One vector per input document, each of the size this test expects.
    assert len(vectors) == 2
    assert len(vectors[0]) == 1536

    # Query the API directly to obtain the reference vector for "".
    response = openai.Embedding.create(input="", model="text-embedding-ada-002")
    reference = response["data"][0]["embedding"]
    assert np.allclose(vectors[0], reference)

    assert len(vectors[1]) == 1536
|
||||||
|
Loading…
Reference in New Issue
Block a user