|
|
|
@ -1,6 +1,6 @@
|
|
|
|
|
"""Wrapper around OpenAI APIs."""
|
|
|
|
|
import sys
|
|
|
|
|
from typing import Any, Dict, Generator, List, Mapping, Optional
|
|
|
|
|
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Extra, Field, root_validator
|
|
|
|
|
|
|
|
|
@ -49,6 +49,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
|
|
|
|
|
openai_api_key: Optional[str] = None
|
|
|
|
|
batch_size: int = 20
|
|
|
|
|
"""Batch size to use when passing multiple documents to generate."""
|
|
|
|
|
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
|
|
|
|
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
|
"""Configuration for this pydantic object."""
|
|
|
|
@ -98,6 +100,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
|
|
|
|
|
"presence_penalty": self.presence_penalty,
|
|
|
|
|
"n": self.n,
|
|
|
|
|
"best_of": self.best_of,
|
|
|
|
|
"request_timeout": self.request_timeout,
|
|
|
|
|
}
|
|
|
|
|
return {**normal_params, **self.model_kwargs}
|
|
|
|
|
|
|
|
|
|