|
|
@ -7,6 +7,7 @@ from typing import Any, Dict, Mapping
|
|
|
|
from pydantic import root_validator
|
|
|
|
from pydantic import root_validator
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.chat_models.openai import ChatOpenAI
|
|
|
|
from langchain.chat_models.openai import ChatOpenAI
|
|
|
|
|
|
|
|
from langchain.schema import ChatResult
|
|
|
|
from langchain.utils import get_from_dict_or_env
|
|
|
|
from langchain.utils import get_from_dict_or_env
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
@ -119,3 +120,12 @@ class AzureChatOpenAI(ChatOpenAI):
|
|
|
|
@property
|
|
|
|
@property
|
|
|
|
def _llm_type(self) -> str:
|
|
|
|
def _llm_type(self) -> str:
|
|
|
|
return "azure-openai-chat"
|
|
|
|
return "azure-openai-chat"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
|
|
|
|
|
|
|
for res in response["choices"]:
|
|
|
|
|
|
|
|
if res.get("finish_reason", None) == "content_filter":
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"Azure has not provided the response due to a content"
|
|
|
|
|
|
|
|
" filter being triggered"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return super()._create_chat_result(response)
|
|
|
|