Fix add callbacks to spark_sql due to depreciation of callback_manager (#9831)

Description: Due to depreciation (regarding to line 109 in
[langchain/libs/langchain/langchain/chains/base.py](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chains/base.py)
of callback_manager i replaced several parts

Issue: None
Dependencies: 
Maintainer: @baskaryan

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Jurik-001 2023-08-30 04:23:44 +02:00 committed by GitHub
parent ffa5625134
commit a05fed9369
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 3 deletions

View File

@ -686,12 +686,15 @@ s
cls,
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> AgentExecutor:
"""Create from agent and tools."""
return cls(
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
agent=agent,
tools=tools,
callbacks=callbacks,
**kwargs,
)
@root_validator()

View File

@ -6,7 +6,7 @@ from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUF
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.chains.llm import LLMChain
from langchain.schema.language_model import BaseLanguageModel
@ -15,6 +15,7 @@ def create_spark_sql_agent(
llm: BaseLanguageModel,
toolkit: SparkSQLToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
prefix: str = SQL_PREFIX,
suffix: str = SQL_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
@ -41,6 +42,7 @@ def create_spark_sql_agent(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
callbacks=callbacks,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
@ -48,6 +50,7 @@ def create_spark_sql_agent(
agent=agent,
tools=tools,
callback_manager=callback_manager,
callbacks=callbacks,
verbose=verbose,
max_iterations=max_iterations,
max_execution_time=max_execution_time,

View File

@ -136,6 +136,12 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
if values.get("callbacks") is not None:
raise ValueError(
"Cannot specify both callback_manager and callbacks. "
"callback_manager is deprecated, callbacks is the preferred "
"parameter to pass in."
)
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,