diff --git a/langchain/chat_models/azure_openai.py b/langchain/chat_models/azure_openai.py index 06711c6681..a38bbd7f69 100644 --- a/langchain/chat_models/azure_openai.py +++ b/langchain/chat_models/azure_openai.py @@ -121,12 +121,13 @@ class AzureChatOpenAI(ChatOpenAI): return {**self._default_params} @property - def _invocation_params(self) -> Mapping[str, Any]: + def _client_params(self) -> Dict[str, Any]: + """Get the config params used for the openai client.""" openai_creds = { "api_type": self.openai_api_type, "api_version": self.openai_api_version, } - return {**openai_creds, **super()._invocation_params} + return {**super()._client_params, **openai_creds} @property def _llm_type(self) -> str: diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index f32ec3c77e..a3480137b5 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -65,10 +65,11 @@ class BaseChatModel(BaseLanguageModel, ABC): def _get_invocation_params( self, stop: Optional[List[str]] = None, + **kwargs: Any, ) -> dict: params = self.dict() params["stop"] = stop - return params + return {**params, **kwargs} def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str: if self.lc_serializable: @@ -77,7 +78,7 @@ class BaseChatModel(BaseLanguageModel, ABC): llm_string = dumps(self) return llm_string + "---" + param_string else: - params = self._get_invocation_params(stop=stop) + params = self._get_invocation_params(stop=stop, **kwargs) params = {**params, **kwargs} return str(sorted([(k, v) for k, v in params.items()])) @@ -92,7 +93,7 @@ class BaseChatModel(BaseLanguageModel, ABC): **kwargs: Any, ) -> LLMResult: """Top Level call""" - params = self._get_invocation_params(stop=stop) + params = self._get_invocation_params(stop=stop, **kwargs) options = {"stop": stop} callback_manager = CallbackManager.configure( @@ -148,7 +149,7 @@ class BaseChatModel(BaseLanguageModel, ABC): **kwargs: Any, ) -> LLMResult: """Top Level call""" - params = self._get_invocation_params(stop=stop) + params = self._get_invocation_params(stop=stop, **kwargs) options = {"stop": stop} callback_manager = AsyncCallbackManager.configure( diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index af1c6d3e1e..eb61651438 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -374,7 +374,7 @@ class ChatOpenAI(BaseChatModel): def _create_message_dicts( self, messages: List[BaseMessage], stop: Optional[List[str]] ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: - params = dict(self._invocation_params) + params = dict(self._client_params) if stop is not None: if "stop" in params: raise ValueError("`stop` found in both the input and default params.") @@ -439,8 +439,8 @@ class ChatOpenAI(BaseChatModel): return {**{"model_name": self.model_name}, **self._default_params} @property - def _invocation_params(self) -> Mapping[str, Any]: - """Get the parameters used to invoke the model.""" + def _client_params(self) -> Mapping[str, Any]: + """Get the parameters used for the openai client.""" openai_creds: Dict[str, Any] = { "api_key": self.openai_api_key, "api_base": self.openai_api_base, @@ -453,6 +453,17 @@ class ChatOpenAI(BaseChatModel): openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501 return {**openai_creds, **self._default_params} + def _get_invocation_params( + self, stop: Optional[List[str]] = None, **kwargs: Any + ) -> Dict[str, Any]: + """Get the parameters used to invoke the model FOR THE CALLBACKS.""" + return { + **super()._get_invocation_params(stop=stop, **kwargs), + **self._default_params, + "model": self.model_name, + "function": kwargs.get("functions"), + } + @property def _llm_type(self) -> str: """Return type of chat model."""