forked from Archives/langchain
openai embeddings (#3488)
This commit is contained in:
parent
d3ce47414d
commit
eda69b13f3
@ -179,14 +179,20 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
batched_embeddings += [r["embedding"] for r in response["data"]]
|
||||
|
||||
results: List[List[List[float]]] = [[] for i in range(len(texts))]
|
||||
lens: List[List[int]] = [[] for i in range(len(texts))]
|
||||
results: List[List[List[float]]] = [[] for _ in range(len(texts))]
|
||||
lens: List[List[int]] = [[] for _ in range(len(texts))]
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
lens[indices[i]].append(len(batched_embeddings[i]))
|
||||
|
||||
for i in range(len(texts)):
|
||||
average = np.average(results[i], axis=0, weights=lens[i])
|
||||
_result = results[i]
|
||||
if len(_result) == 0:
|
||||
average = embed_with_retry(self, input="", engine=self.deployment)[
|
||||
"data"
|
||||
][0]["embedding"]
|
||||
else:
|
||||
average = np.average(_result, axis=0, weights=lens[i])
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
|
||||
return embeddings
|
||||
|
@ -1,4 +1,7 @@
|
||||
"""Test openai embeddings."""
|
||||
import numpy as np
|
||||
import openai
|
||||
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
|
||||
@ -29,3 +32,17 @@ def test_openai_embedding_query() -> None:
|
||||
embedding = OpenAIEmbeddings()
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) == 1536
|
||||
|
||||
|
||||
def test_openai_embedding_with_empty_string() -> None:
|
||||
"""Test openai embeddings with empty string."""
|
||||
document = ["", "abc"]
|
||||
embedding = OpenAIEmbeddings()
|
||||
output = embedding.embed_documents(document)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) == 1536
|
||||
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[
|
||||
"data"
|
||||
][0]["embedding"]
|
||||
assert np.allclose(output[0], expected_output)
|
||||
assert len(output[1]) == 1536
|
||||
|
Loading…
Reference in New Issue
Block a user