Add function call params to invocation params (#7240)

pull/7307/head
William FH 1 year ago committed by GitHub
parent 1f4a51cb9c
commit cb9ff6efb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -121,12 +121,13 @@ class AzureChatOpenAI(ChatOpenAI):
return {**self._default_params}
@property
def _invocation_params(self) -> Mapping[str, Any]:
def _client_params(self) -> Dict[str, Any]:
"""Get the config params used for the openai client."""
openai_creds = {
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
}
return {**openai_creds, **super()._invocation_params}
return {**super()._client_params, **openai_creds}
@property
def _llm_type(self) -> str:

@ -65,10 +65,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
def _get_invocation_params(
self,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> dict:
params = self.dict()
params["stop"] = stop
return params
return {**params, **kwargs}
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
if self.lc_serializable:
@ -77,7 +78,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
llm_string = dumps(self)
return llm_string + "---" + param_string
else:
params = self._get_invocation_params(stop=stop)
params = self._get_invocation_params(stop=stop, **kwargs)
params = {**params, **kwargs}
return str(sorted([(k, v) for k, v in params.items()]))
@ -92,7 +93,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self._get_invocation_params(stop=stop)
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop}
callback_manager = CallbackManager.configure(
@ -148,7 +149,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self._get_invocation_params(stop=stop)
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(

@ -374,7 +374,7 @@ class ChatOpenAI(BaseChatModel):
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = dict(self._invocation_params)
params = dict(self._client_params)
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
@ -439,8 +439,8 @@ class ChatOpenAI(BaseChatModel):
return {**{"model_name": self.model_name}, **self._default_params}
@property
def _invocation_params(self) -> Mapping[str, Any]:
"""Get the parameters used to invoke the model."""
def _client_params(self) -> Mapping[str, Any]:
"""Get the parameters used for the openai client."""
openai_creds: Dict[str, Any] = {
"api_key": self.openai_api_key,
"api_base": self.openai_api_base,
@ -453,6 +453,17 @@ class ChatOpenAI(BaseChatModel):
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501
return {**openai_creds, **self._default_params}
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Get the parameters used to invoke the model FOR THE CALLBACKS."""
return {
**super()._get_invocation_params(stop=stop, **kwargs),
**self._default_params,
"model": self.model_name,
"function": kwargs.get("functions"),
}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""

Loading…
Cancel
Save