Fetch up-to-date attributes for env-pulled kwargs during serialisation of OpenAI classes (#11499)

pull/11458/head
David Duong 1 year ago committed by GitHub
parent c3d2b01adf
commit 484947c492
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -141,6 +141,13 @@ class AzureChatOpenAI(ChatOpenAI):
def _llm_type(self) -> str:
return "azure-openai-chat"
@property
def lc_attributes(self) -> Dict[str, Any]:
return {
"openai_api_type": self.openai_api_type,
"openai_api_version": self.openai_api_version,
}
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
for res in response["choices"]:
if res.get("finish_reason", None) == "content_filter":

@ -141,6 +141,21 @@ class ChatOpenAI(BaseChatModel):
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.openai_organization != "":
attributes["openai_organization"] = self.openai_organization
if self.openai_api_base != "":
attributes["openai_api_base"] = self.openai_api_base
if self.openai_proxy != "":
attributes["openai_proxy"] = self.openai_proxy
return attributes
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""

@ -138,6 +138,20 @@ class BaseOpenAI(BaseLLM):
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.openai_api_base != "":
attributes["openai_api_base"] = self.openai_api_base
if self.openai_organization != "":
attributes["openai_organization"] = self.openai_organization
if self.openai_proxy != "":
attributes["openai_proxy"] = self.openai_proxy
return attributes
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@ -692,6 +706,13 @@ class AzureOpenAI(BaseOpenAI):
"""Return type of llm."""
return "azure"
@property
def lc_attributes(self) -> Dict[str, Any]:
return {
"openai_api_type": self.openai_api_type,
"openai_api_version": self.openai_api_version,
}
class OpenAIChat(BaseLLM):
"""OpenAI Chat large language models.

@ -1,16 +1,21 @@
import json
import os
from typing import Any, Mapping, cast
from unittest import mock
import pytest
from langchain.chat_models.azure_openai import AzureChatOpenAI
os.environ["OPENAI_API_KEY"] = "test"
os.environ["OPENAI_API_BASE"] = "https://oai.azure.com/"
os.environ["OPENAI_API_VERSION"] = "2023-05-01"
@mock.patch.dict(
os.environ,
{
"OPENAI_API_KEY": "test",
"OPENAI_API_BASE": "https://oai.azure.com/",
"OPENAI_API_VERSION": "2023-05-01",
},
)
@pytest.mark.requires("openai")
@pytest.mark.parametrize(
"model_name", ["gpt-4", "gpt-4-32k", "gpt-35-turbo", "gpt-35-turbo-16k"]

Loading…
Cancel
Save