From eda69b13f3456548e60305433dc1bb6ead6440ff Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 24 Apr 2023 22:19:47 -0700 Subject: [PATCH] openai embeddings (#3488) --- langchain/embeddings/openai.py | 12 +++++++++--- .../integration_tests/embeddings/test_openai.py | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index 6326f0f6..e52b695a 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -179,14 +179,20 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) batched_embeddings += [r["embedding"] for r in response["data"]] - results: List[List[List[float]]] = [[] for i in range(len(texts))] - lens: List[List[int]] = [[] for i in range(len(texts))] + results: List[List[List[float]]] = [[] for _ in range(len(texts))] + lens: List[List[int]] = [[] for _ in range(len(texts))] for i in range(len(indices)): results[indices[i]].append(batched_embeddings[i]) lens[indices[i]].append(len(batched_embeddings[i])) for i in range(len(texts)): - average = np.average(results[i], axis=0, weights=lens[i]) + _result = results[i] + if len(_result) == 0: + average = embed_with_retry(self, input="", engine=self.deployment)[ + "data" + ][0]["embedding"] + else: + average = np.average(_result, axis=0, weights=lens[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() return embeddings diff --git a/tests/integration_tests/embeddings/test_openai.py b/tests/integration_tests/embeddings/test_openai.py index 9aa7d19c..1dba7553 100644 --- a/tests/integration_tests/embeddings/test_openai.py +++ b/tests/integration_tests/embeddings/test_openai.py @@ -1,4 +1,7 @@ """Test openai embeddings.""" +import numpy as np +import openai + from langchain.embeddings.openai import OpenAIEmbeddings @@ -29,3 +32,17 @@ def test_openai_embedding_query() -> None: embedding = OpenAIEmbeddings() output = embedding.embed_query(document) assert len(output) == 1536 + + +def test_openai_embedding_with_empty_string() -> None: + """Test openai embeddings with empty string.""" + document = ["", "abc"] + embedding = OpenAIEmbeddings() + output = embedding.embed_documents(document) + assert len(output) == 2 + assert len(output[0]) == 1536 + expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[ + "data" + ][0]["embedding"] + assert np.allclose(output[0], expected_output) + assert len(output[1]) == 1536