From f7f118e021ac9cb0ab36b9a03ae149d5b0cfe4d9 Mon Sep 17 00:00:00 2001 From: Shishin Mo Date: Sat, 8 Apr 2023 14:02:02 +0900 Subject: [PATCH] use openai_organization as argument (#2566) Added support for passing the openai_organization as an argument, as it was only supported by the environment variable but openai_api_key was supported by both environment variables and arguments. `ChatOpenAI(temperature=0, model_name="gpt-4", openai_api_key="sk-****", openai_organization="org-****")` --- langchain/chains/moderation.py | 1 + langchain/chat_models/azure_openai.py | 1 + langchain/chat_models/openai.py | 1 + langchain/embeddings/openai.py | 1 + langchain/llms/openai.py | 1 + 5 files changed, 5 insertions(+) diff --git a/langchain/chains/moderation.py b/langchain/chains/moderation.py index 32b1b471..f7221c04 100644 --- a/langchain/chains/moderation.py +++ b/langchain/chains/moderation.py @@ -31,6 +31,7 @@ class OpenAIModerationChain(Chain): input_key: str = "input" #: :meta private: output_key: str = "output" #: :meta private: openai_api_key: Optional[str] = None + openai_organization: Optional[str] = None @root_validator() def validate_environment(cls, values: Dict) -> Dict: diff --git a/langchain/chat_models/azure_openai.py b/langchain/chat_models/azure_openai.py index 0b052d5f..510d28ec 100644 --- a/langchain/chat_models/azure_openai.py +++ b/langchain/chat_models/azure_openai.py @@ -44,6 +44,7 @@ class AzureChatOpenAI(ChatOpenAI): openai_api_base: str = "" openai_api_version: str = "" openai_api_key: str = "" + openai_organization: str = "" @root_validator() def validate_environment(cls, values: Dict) -> Dict: diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index 6d510d4b..38e1045e 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -115,6 +115,7 @@ class ChatOpenAI(BaseChatModel): model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" openai_api_key: Optional[str] = None + openai_organization: Optional[str] = None request_timeout: int = 60 """Timeout in seconds for the OpenAPI request.""" max_retries: int = 6 diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index c66fd226..09daf764 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -98,6 +98,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): query_model_name: str = "text-embedding-ada-002" embedding_ctx_length: int = 8191 openai_api_key: Optional[str] = None + openai_organization: Optional[str] = None chunk_size: int = 1000 """Maximum number of texts to embed in each batch""" max_retries: int = 6 diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 02be32ec..d34b3787 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -151,6 +151,7 @@ class BaseOpenAI(BaseLLM): model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" openai_api_key: Optional[str] = None + openai_organization: Optional[str] = None batch_size: int = 20 """Batch size to use when passing multiple documents to generate.""" request_timeout: Optional[Union[float, Tuple[float, float]]] = None