mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add function call params to invocation params (#7240)
This commit is contained in:
parent
1f4a51cb9c
commit
cb9ff6efb8
@ -121,12 +121,13 @@ class AzureChatOpenAI(ChatOpenAI):
|
|||||||
return {**self._default_params}
|
return {**self._default_params}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _invocation_params(self) -> Mapping[str, Any]:
|
def _client_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the config params used for the openai client."""
|
||||||
openai_creds = {
|
openai_creds = {
|
||||||
"api_type": self.openai_api_type,
|
"api_type": self.openai_api_type,
|
||||||
"api_version": self.openai_api_version,
|
"api_version": self.openai_api_version,
|
||||||
}
|
}
|
||||||
return {**openai_creds, **super()._invocation_params}
|
return {**super()._client_params, **openai_creds}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
|
@ -65,10 +65,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
|||||||
def _get_invocation_params(
|
def _get_invocation_params(
|
||||||
self,
|
self,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
params = self.dict()
|
params = self.dict()
|
||||||
params["stop"] = stop
|
params["stop"] = stop
|
||||||
return params
|
return {**params, **kwargs}
|
||||||
|
|
||||||
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
|
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
|
||||||
if self.lc_serializable:
|
if self.lc_serializable:
|
||||||
@ -77,7 +78,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
|||||||
llm_string = dumps(self)
|
llm_string = dumps(self)
|
||||||
return llm_string + "---" + param_string
|
return llm_string + "---" + param_string
|
||||||
else:
|
else:
|
||||||
params = self._get_invocation_params(stop=stop)
|
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||||
params = {**params, **kwargs}
|
params = {**params, **kwargs}
|
||||||
return str(sorted([(k, v) for k, v in params.items()]))
|
return str(sorted([(k, v) for k, v in params.items()]))
|
||||||
|
|
||||||
@ -92,7 +93,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""Top Level call"""
|
"""Top Level call"""
|
||||||
params = self._get_invocation_params(stop=stop)
|
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||||
options = {"stop": stop}
|
options = {"stop": stop}
|
||||||
|
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = CallbackManager.configure(
|
||||||
@ -148,7 +149,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""Top Level call"""
|
"""Top Level call"""
|
||||||
params = self._get_invocation_params(stop=stop)
|
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||||
options = {"stop": stop}
|
options = {"stop": stop}
|
||||||
|
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = AsyncCallbackManager.configure(
|
||||||
|
@ -374,7 +374,7 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
def _create_message_dicts(
|
def _create_message_dicts(
|
||||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||||
params = dict(self._invocation_params)
|
params = dict(self._client_params)
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
if "stop" in params:
|
if "stop" in params:
|
||||||
raise ValueError("`stop` found in both the input and default params.")
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
@ -439,8 +439,8 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
return {**{"model_name": self.model_name}, **self._default_params}
|
return {**{"model_name": self.model_name}, **self._default_params}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _invocation_params(self) -> Mapping[str, Any]:
|
def _client_params(self) -> Mapping[str, Any]:
|
||||||
"""Get the parameters used to invoke the model."""
|
"""Get the parameters used for the openai client."""
|
||||||
openai_creds: Dict[str, Any] = {
|
openai_creds: Dict[str, Any] = {
|
||||||
"api_key": self.openai_api_key,
|
"api_key": self.openai_api_key,
|
||||||
"api_base": self.openai_api_base,
|
"api_base": self.openai_api_base,
|
||||||
@ -453,6 +453,17 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501
|
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501
|
||||||
return {**openai_creds, **self._default_params}
|
return {**openai_creds, **self._default_params}
|
||||||
|
|
||||||
|
def _get_invocation_params(
|
||||||
|
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Get the parameters used to invoke the model FOR THE CALLBACKS."""
|
||||||
|
return {
|
||||||
|
**super()._get_invocation_params(stop=stop, **kwargs),
|
||||||
|
**self._default_params,
|
||||||
|
"model": self.model_name,
|
||||||
|
"function": kwargs.get("functions"),
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
"""Return type of chat model."""
|
"""Return type of chat model."""
|
||||||
|
Loading…
Reference in New Issue
Block a user