openai embeddings (#3488)

This commit is contained in:
Harrison Chase 2023-04-24 22:19:47 -07:00 committed by GitHub
parent d3ce47414d
commit eda69b13f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 3 deletions

View File

@ -179,14 +179,20 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
)
batched_embeddings += [r["embedding"] for r in response["data"]]
results: List[List[List[float]]] = [[] for i in range(len(texts))]
lens: List[List[int]] = [[] for i in range(len(texts))]
results: List[List[List[float]]] = [[] for _ in range(len(texts))]
lens: List[List[int]] = [[] for _ in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
lens[indices[i]].append(len(batched_embeddings[i]))
for i in range(len(texts)):
average = np.average(results[i], axis=0, weights=lens[i])
_result = results[i]
if len(_result) == 0:
average = embed_with_retry(self, input="", engine=self.deployment)[
"data"
][0]["embedding"]
else:
average = np.average(_result, axis=0, weights=lens[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
return embeddings

View File

@ -1,4 +1,7 @@
"""Test openai embeddings."""
import numpy as np
import openai
from langchain.embeddings.openai import OpenAIEmbeddings
@ -29,3 +32,17 @@ def test_openai_embedding_query() -> None:
embedding = OpenAIEmbeddings()
output = embedding.embed_query(document)
assert len(output) == 1536
def test_openai_embedding_with_empty_string() -> None:
"""Test openai embeddings with empty string."""
document = ["", "abc"]
embedding = OpenAIEmbeddings()
output = embedding.embed_documents(document)
assert len(output) == 2
assert len(output[0]) == 1536
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[
"data"
][0]["embedding"]
assert np.allclose(output[0], expected_output)
assert len(output[1]) == 1536