mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Fix add callbacks to spark_sql due to depreciation of callback_manager (#9831)
Description: Due to depreciation (regarding to line 109 in [langchain/libs/langchain/langchain/chains/base.py](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chains/base.py) of callback_manager i replaced several parts Issue: None Dependencies: Maintainer: @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
ffa5625134
commit
a05fed9369
@ -686,12 +686,15 @@ s
|
||||
cls,
|
||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Create from agent and tools."""
|
||||
return cls(
|
||||
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@root_validator()
|
||||
|
@ -6,7 +6,7 @@ from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUF
|
||||
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
@ -15,6 +15,7 @@ def create_spark_sql_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: SparkSQLToolkit,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
callbacks: Callbacks = None,
|
||||
prefix: str = SQL_PREFIX,
|
||||
suffix: str = SQL_SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
@ -41,6 +42,7 @@ def create_spark_sql_agent(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
@ -48,6 +50,7 @@ def create_spark_sql_agent(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
verbose=verbose,
|
||||
max_iterations=max_iterations,
|
||||
max_execution_time=max_execution_time,
|
||||
|
@ -136,6 +136,12 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
if values.get("callbacks") is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both callback_manager and callbacks. "
|
||||
"callback_manager is deprecated, callbacks is the preferred "
|
||||
"parameter to pass in."
|
||||
)
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
|
Loading…
Reference in New Issue
Block a user