Harrison/headers for openai (#4648)

Co-authored-by: aakash.shah <aakash.shah@quintiles.com>
This commit is contained in:
Harrison Chase 2023-05-13 21:46:20 -07:00 committed by GitHub
parent c09bb00959
commit 87d8d221fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -122,6 +122,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""Maximum number of retries to make when generating.""" """Maximum number of retries to make when generating."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = None request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout in seconds for the OpenAPI request.""" """Timeout in seconds for the OpenAPI request."""
headers: Any = None
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -210,6 +211,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
input=tokens[i : i + _chunk_size], input=tokens[i : i + _chunk_size],
engine=self.deployment, engine=self.deployment,
request_timeout=self.request_timeout, request_timeout=self.request_timeout,
headers=self.headers,
) )
batched_embeddings += [r["embedding"] for r in response["data"]] batched_embeddings += [r["embedding"] for r in response["data"]]
@ -227,6 +229,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
input="", input="",
engine=self.deployment, engine=self.deployment,
request_timeout=self.request_timeout, request_timeout=self.request_timeout,
headers=self.headers,
)["data"][0]["embedding"] )["data"][0]["embedding"]
else: else:
average = np.average( average = np.average(
@ -254,7 +257,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# replace newlines, which can negatively affect performance. # replace newlines, which can negatively affect performance.
text = text.replace("\n", " ") text = text.replace("\n", " ")
return embed_with_retry( return embed_with_retry(
self, input=[text], engine=engine, request_timeout=self.request_timeout self,
input=[text],
engine=engine,
request_timeout=self.request_timeout,
headers=self.headers,
)["data"][0]["embedding"] )["data"][0]["embedding"]
def embed_documents( def embed_documents(