mirror of https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00

Add function call params to invocation params (#7240)

parent 1f4a51cb9c
commit cb9ff6efb8
langchain/chat_models/azure_openai.py
@@ -121,12 +121,13 @@ class AzureChatOpenAI(ChatOpenAI):
         return {**self._default_params}
 
     @property
-    def _invocation_params(self) -> Mapping[str, Any]:
+    def _client_params(self) -> Dict[str, Any]:
+        """Get the config params used for the openai client."""
         openai_creds = {
             "api_type": self.openai_api_type,
             "api_version": self.openai_api_version,
         }
-        return {**openai_creds, **super()._invocation_params}
+        return {**super()._client_params, **openai_creds}
 
     @property
     def _llm_type(self) -> str:
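Worth noting in this hunk: the merge order of the return value is flipped as part of the rename. With `{**super()._client_params, **openai_creds}`, the Azure-specific `api_type` and `api_version` now take precedence over any same-named keys from the base class. A minimal standalone sketch of that dict-merge precedence (values are illustrative, not from the diff):

# Minimal sketch (standalone, illustrative values): later entries in a
# dict merge win, which is why the return was flipped so the Azure
# creds override whatever super()._client_params supplies.
base_params = {"api_type": "open_ai", "api_version": None}       # stand-in for super()._client_params
azure_creds = {"api_type": "azure", "api_version": "2023-05-15"}

merged = {**base_params, **azure_creds}
assert merged["api_type"] == "azure"   # Azure-specific value wins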
langchain/chat_models/base.py
@@ -65,10 +65,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
     def _get_invocation_params(
         self,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> dict:
         params = self.dict()
         params["stop"] = stop
-        return params
+        return {**params, **kwargs}
 
     def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
         if self.lc_serializable:
@@ -77,7 +78,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
             llm_string = dumps(self)
             return llm_string + "---" + param_string
         else:
-            params = self._get_invocation_params(stop=stop)
+            params = self._get_invocation_params(stop=stop, **kwargs)
             params = {**params, **kwargs}
             return str(sorted([(k, v) for k, v in params.items()]))
 
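The effect of threading `**kwargs` through here is that call-time arguments such as `functions` now reach both the callback params and the cache key built by `_get_llm_string`. A rough standalone sketch of the new behavior (a re-creation with illustrative values, not the class method itself):

# Standalone re-creation of the new BaseChatModel._get_invocation_params
# behavior: call-time kwargs are merged into the serialized model params.
def get_invocation_params(model_params: dict, stop=None, **kwargs) -> dict:
    params = dict(model_params)   # stands in for self.dict()
    params["stop"] = stop
    return {**params, **kwargs}

params = get_invocation_params(
    {"model_name": "gpt-3.5-turbo-0613"},    # illustrative model config
    functions=[{"name": "get_weather"}],     # illustrative function schema
)
# _get_llm_string builds its cache key from these params, so two calls
# that differ only in `functions` now produce different keys.
cache_key = str(sorted(params.items()))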
@@ -92,7 +93,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
-        params = self._get_invocation_params(stop=stop)
+        params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
 
         callback_manager = CallbackManager.configure(
@@ -148,7 +149,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
-        params = self._get_invocation_params(stop=stop)
+        params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
 
         callback_manager = AsyncCallbackManager.configure(
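Both the sync and async entry points now pass the kwargs-enriched params to the callback manager, so handlers can observe function-call arguments. A hedged sketch of a handler that would read them, assuming the manager forwards the params as an `invocation_params` keyword (an assumption about the surrounding callback plumbing, not something shown in this diff):

from typing import Any, Dict, List

class FunctionCallLogger:
    """Illustrative handler: the hook name and the `invocation_params`
    keyword are assumptions about LangChain's callback plumbing,
    not part of this diff."""

    def on_chat_model_start(
        self, serialized: Dict[str, Any], messages: List[Any], **kwargs: Any
    ) -> None:
        params = kwargs.get("invocation_params", {})
        if params.get("functions"):
            print("functions passed to the model:", params["functions"])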
langchain/chat_models/openai.py
@@ -374,7 +374,7 @@ class ChatOpenAI(BaseChatModel):
     def _create_message_dicts(
         self, messages: List[BaseMessage], stop: Optional[List[str]]
     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
-        params = dict(self._invocation_params)
+        params = dict(self._client_params)
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")
@@ -439,8 +439,8 @@ class ChatOpenAI(BaseChatModel):
         return {**{"model_name": self.model_name}, **self._default_params}
 
     @property
-    def _invocation_params(self) -> Mapping[str, Any]:
-        """Get the parameters used to invoke the model."""
+    def _client_params(self) -> Mapping[str, Any]:
+        """Get the parameters used for the openai client."""
         openai_creds: Dict[str, Any] = {
             "api_key": self.openai_api_key,
             "api_base": self.openai_api_base,
@@ -453,6 +453,17 @@ class ChatOpenAI(BaseChatModel):
             openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore[assignment]  # noqa: E501
         return {**openai_creds, **self._default_params}
 
+    def _get_invocation_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
+        return {
+            **super()._get_invocation_params(stop=stop, **kwargs),
+            **self._default_params,
+            "model": self.model_name,
+            "function": kwargs.get("functions"),
+        }
+
     @property
     def _llm_type(self) -> str:
         """Return type of chat model."""
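Put together, a call that supplies OpenAI function definitions should now surface them to callbacks via the new `ChatOpenAI._get_invocation_params` override. A hypothetical usage sketch against the post-PR API (model name and function schema are illustrative); note that the override reports the `functions` kwarg under the singular key `"function"`:

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI(model_name="gpt-3.5-turbo-0613")
functions = [{
    "name": "get_weather",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
    },
}]

# `functions` flows through generate() into _get_invocation_params(),
# so callback handlers receive it alongside the default params.
result = chat.generate(
    [[HumanMessage(content="What's the weather in Paris?")]],
    functions=functions,
)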