diff --git a/langchain/chat_models/azure_openai.py b/langchain/chat_models/azure_openai.py index 8b370251..06711c66 100644 --- a/langchain/chat_models/azure_openai.py +++ b/langchain/chat_models/azure_openai.py @@ -53,33 +53,33 @@ class AzureChatOpenAI(ChatOpenAI): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = get_from_dict_or_env( + values["openai_api_key"] = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY", ) - openai_api_base = get_from_dict_or_env( + values["openai_api_base"] = get_from_dict_or_env( values, "openai_api_base", "OPENAI_API_BASE", ) - openai_api_version = get_from_dict_or_env( + values["openai_api_version"] = get_from_dict_or_env( values, "openai_api_version", "OPENAI_API_VERSION", ) - openai_api_type = get_from_dict_or_env( + values["openai_api_type"] = get_from_dict_or_env( values, "openai_api_type", "OPENAI_API_TYPE", ) - openai_organization = get_from_dict_or_env( + values["openai_organization"] = get_from_dict_or_env( values, "openai_organization", "OPENAI_ORGANIZATION", default="", ) - openai_proxy = get_from_dict_or_env( + values["openai_proxy"] = get_from_dict_or_env( values, "openai_proxy", "OPENAI_PROXY", @@ -88,14 +88,6 @@ class AzureChatOpenAI(ChatOpenAI): try: import openai - openai.api_type = openai_api_type - openai.api_base = openai_api_base - openai.api_version = openai_api_version - openai.api_key = openai_api_key - if openai_organization: - openai.organization = openai_organization - if openai_proxy: - openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501 except ImportError: raise ImportError( "Could not import openai python package. 
" @@ -128,6 +120,14 @@ class AzureChatOpenAI(ChatOpenAI): """Get the identifying parameters.""" return {**self._default_params} + @property + def _invocation_params(self) -> Mapping[str, Any]: + openai_creds = { + "api_type": self.openai_api_type, + "api_version": self.openai_api_version, + } + return {**openai_creds, **super()._invocation_params} + @property def _llm_type(self) -> str: return "azure-openai-chat" diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index b30e4820..7fd2daa3 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -196,22 +196,22 @@ class ChatOpenAI(BaseChatModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = get_from_dict_or_env( + values["openai_api_key"] = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" ) - openai_organization = get_from_dict_or_env( + values["openai_organization"] = get_from_dict_or_env( values, "openai_organization", "OPENAI_ORGANIZATION", default="", ) - openai_api_base = get_from_dict_or_env( + values["openai_api_base"] = get_from_dict_or_env( values, "openai_api_base", "OPENAI_API_BASE", default="", ) - openai_proxy = get_from_dict_or_env( + values["openai_proxy"] = get_from_dict_or_env( values, "openai_proxy", "OPENAI_PROXY", @@ -225,13 +225,6 @@ class ChatOpenAI(BaseChatModel): "Could not import openai python package. " "Please install it with `pip install openai`." 
) - openai.api_key = openai_api_key - if openai_organization: - openai.organization = openai_organization - if openai_api_base: - openai.api_base = openai_api_base - if openai_proxy: - openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501 try: values["client"] = openai.ChatCompletion except AttributeError: @@ -333,7 +326,7 @@ class ChatOpenAI(BaseChatModel): def _create_message_dicts( self, messages: List[BaseMessage], stop: Optional[List[str]] ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: - params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params} + params = dict(self._invocation_params) if stop is not None: if "stop" in params: raise ValueError("`stop` found in both the input and default params.") @@ -384,6 +377,21 @@ class ChatOpenAI(BaseChatModel): """Get the identifying parameters.""" return {**{"model_name": self.model_name}, **self._default_params} + @property + def _invocation_params(self) -> Mapping[str, Any]: + """Get the parameters used to invoke the model.""" + openai_creds: Dict[str, Any] = { + "api_key": self.openai_api_key, + "api_base": self.openai_api_base, + "organization": self.openai_organization, + "model": self.model_name, + } + if self.openai_proxy: + openai_creds["proxy"] = { + "http": self.openai_proxy, "https": self.openai_proxy + } + return {**openai_creds, **self._default_params} + @property def _llm_type(self) -> str: """Return type of chat model.""" diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index 1fa9916d..c10a5526 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -136,38 +136,38 @@ class OpenAIEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = get_from_dict_or_env( + values["openai_api_key"] = get_from_dict_or_env( values, "openai_api_key",
"OPENAI_API_KEY" ) - openai_api_base = get_from_dict_or_env( + values["openai_api_base"] = get_from_dict_or_env( values, "openai_api_base", "OPENAI_API_BASE", default="", ) - openai_api_type = get_from_dict_or_env( + values["openai_api_type"] = get_from_dict_or_env( values, "openai_api_type", "OPENAI_API_TYPE", default="", ) - openai_proxy = get_from_dict_or_env( + values["openai_proxy"] = get_from_dict_or_env( values, "openai_proxy", "OPENAI_PROXY", default="", ) - if openai_api_type in ("azure", "azure_ad", "azuread"): + if values["openai_api_type"] in ("azure", "azure_ad", "azuread"): default_api_version = "2022-12-01" else: default_api_version = "" - openai_api_version = get_from_dict_or_env( + values["openai_api_version"] = get_from_dict_or_env( values, "openai_api_version", "OPENAI_API_VERSION", default=default_api_version, ) - openai_organization = get_from_dict_or_env( + values["openai_organization"] = get_from_dict_or_env( values, "openai_organization", "OPENAI_ORGANIZATION", @@ -176,17 +176,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings): try: import openai - openai.api_key = openai_api_key - if openai_organization: - openai.organization = openai_organization - if openai_api_base: - openai.api_base = openai_api_base - if openai_api_type: - openai.api_version = openai_api_version - if openai_api_type: - openai.api_type = openai_api_type - if openai_proxy: - openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501 values["client"] = openai.Embedding except ImportError: raise ImportError( @@ -195,6 +184,25 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) return values + @property + def _invocation_params(self) -> Dict: + openai_args = { + "engine": self.deployment, + "request_timeout": self.request_timeout, + "headers": self.headers, + "api_key": self.openai_api_key, + "organization": self.openai_organization, + "api_base": self.openai_api_base, + "api_type": self.openai_api_type, + "api_version": 
self.openai_api_version, + } + if self.openai_proxy: + openai_args["proxy"] = { + "http": self.openai_proxy, + "https": self.openai_proxy, + } + return openai_args + # please refer to # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb def _get_len_safe_embeddings( @@ -233,9 +241,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): response = embed_with_retry( self, input=tokens[i : i + _chunk_size], - engine=self.deployment, - request_timeout=self.request_timeout, - headers=self.headers, + **self._invocation_params, ) batched_embeddings += [r["embedding"] for r in response["data"]] diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index bcb3ac9f..f56c4e77 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -211,22 +211,22 @@ class BaseOpenAI(BaseLLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = get_from_dict_or_env( + values["openai_api_key"] = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" ) - openai_api_base = get_from_dict_or_env( + values["openai_api_base"] = get_from_dict_or_env( values, "openai_api_base", "OPENAI_API_BASE", default="", ) - openai_proxy = get_from_dict_or_env( + values["openai_proxy"] = get_from_dict_or_env( values, "openai_proxy", "OPENAI_PROXY", default="", ) - openai_organization = get_from_dict_or_env( + values["openai_organization"] = get_from_dict_or_env( values, "openai_organization", "OPENAI_ORGANIZATION", @@ -235,13 +235,6 @@ class BaseOpenAI(BaseLLM): try: import openai - openai.api_key = openai_api_key - if openai_api_base: - openai.api_base = openai_api_base - if openai_organization: - openai.organization = openai_organization - if openai_proxy: - openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501 values["client"] = openai.Completion except ImportError: raise ImportError( @@ 
-452,7 +445,17 @@ class BaseOpenAI(BaseLLM): @property def _invocation_params(self) -> Dict[str, Any]: """Get the parameters used to invoke the model.""" - return self._default_params + openai_creds: Dict[str, Any] = { + "api_key": self.openai_api_key, + "api_base": self.openai_api_base, + "organization": self.openai_organization, + } + if self.openai_proxy: + openai_creds["proxy"] = { + "http": self.openai_proxy, + "https": self.openai_proxy, + } + return {**openai_creds, **self._default_params} @property def _identifying_params(self) -> Mapping[str, Any]: @@ -596,6 +599,22 @@ class AzureOpenAI(BaseOpenAI): deployment_name: str = "" """Deployment name to use.""" + openai_api_type: str = "azure" + openai_api_version: str = "" + + @root_validator() + def validate_azure_settings(cls, values: Dict) -> Dict: + values["openai_api_version"] = get_from_dict_or_env( + values, + "openai_api_version", + "OPENAI_API_VERSION", + ) + values["openai_api_type"] = get_from_dict_or_env( + values, + "openai_api_type", + "OPENAI_API_TYPE", + ) + return values @property def _identifying_params(self) -> Mapping[str, Any]: @@ -606,7 +625,12 @@ class AzureOpenAI(BaseOpenAI): @property def _invocation_params(self) -> Dict[str, Any]: - return {**{"engine": self.deployment_name}, **super()._invocation_params} + openai_params = { + "engine": self.deployment_name, + "api_type": self.openai_api_type, + "api_version": self.openai_api_version, + } + return {**openai_params, **super()._invocation_params} @property def _llm_type(self) -> str: