From 2f490be09b13ec223e03a9e9eaf2e53e65d77c85 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 5 Oct 2023 15:51:21 +0100 Subject: [PATCH] Fix .dict() for agent/chain (#11436) --- libs/langchain/langchain/agents/agent.py | 9 +++++++-- libs/langchain/langchain/chains/base.py | 7 ++++++- libs/langchain/langchain/load/serializable.py | 9 ++++++++- libs/langchain/langchain/schema/runnable/__init__.py | 4 ++++ libs/langchain/langchain/schema/runnable/configurable.py | 2 +- 5 files changed, 26 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 662ea5d365..648f177c26 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -145,10 +145,13 @@ class BaseSingleActionAgent(BaseModel): def dict(self, **kwargs: Any) -> Dict: """Return dictionary representation of agent.""" _dict = super().dict() - _type = self._agent_type + try: + _type = self._agent_type + except NotImplementedError: + _type = None if isinstance(_type, AgentType): _dict["_type"] = str(_type.value) - else: + elif _type is not None: _dict["_type"] = _type return _dict @@ -175,6 +178,8 @@ class BaseSingleActionAgent(BaseModel): # Fetch dictionary to save agent_dict = self.dict() + if "_type" not in agent_dict: + raise NotImplementedError(f"Agent {self} does not support saving") if save_path.suffix == ".json": with open(file_path, "w") as f: diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index bea11d135c..ac812b45e9 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -611,7 +611,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): if self.memory is not None: raise ValueError("Saving of memory is not yet supported.") _dict = super().dict(**kwargs) - _dict["_type"] = self._chain_type + try: + _dict["_type"] = self._chain_type + except NotImplementedError: + pass return _dict def save(self, file_path: Union[Path, str]) -> None: @@ -639,6 +642,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): # Fetch dictionary to save chain_dict = self.dict() + if "_type" not in chain_dict: + raise NotImplementedError(f"Chain {self} does not support saving.") if save_path.suffix == ".json": with open(file_path, "w") as f: diff --git a/libs/langchain/langchain/load/serializable.py b/libs/langchain/langchain/load/serializable.py index 053b732d62..368f2f0d32 100644 --- a/libs/langchain/langchain/load/serializable.py +++ b/libs/langchain/langchain/load/serializable.py @@ -31,6 +31,13 @@ class SerializedNotImplemented(BaseSerialized): repr: Optional[str] +def try_neq_default(value: Any, key: str, model: BaseModel) -> bool: + try: + return model.__fields__[key].get_default() != value + except Exception: + return True + + class Serializable(BaseModel, ABC): """Serializable base class.""" @@ -81,7 +88,7 @@ class Serializable(BaseModel, ABC): return [ (k, v) for k, v in super().__repr_args__() - if (k not in self.__fields__ or self.__fields__[k].get_default() != v) + if (k not in self.__fields__ or try_neq_default(v, k, self)) ] _lc_kwargs = PrivateAttr(default_factory=dict) diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index e9277319d9..430b1b16b6 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -2,6 +2,7 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar from langchain.schema.runnable.base import ( Runnable, RunnableBinding, + RunnableGenerator, RunnableLambda, RunnableMap, RunnableSequence, @@ -12,8 +13,10 @@ from langchain.schema.runnable.config import RunnableConfig, patch_config from langchain.schema.runnable.fallbacks import RunnableWithFallbacks from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable +from langchain.schema.runnable.utils import ConfigurableField __all__ = [ + "ConfigurableField", "GetLocalVar", "patch_config", "PutLocalVar", @@ -24,6 +27,7 @@ __all__ = [ "RunnableBinding", "RunnableBranch", "RunnableConfig", + "RunnableGenerator", "RunnableLambda", "RunnableMap", "RunnablePassthrough", diff --git a/libs/langchain/langchain/schema/runnable/configurable.py b/libs/langchain/langchain/schema/runnable/configurable.py index e58af4a82b..0f7b07cdda 100644 --- a/libs/langchain/langchain/schema/runnable/configurable.py +++ b/libs/langchain/langchain/schema/runnable/configurable.py @@ -227,7 +227,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): } if configurable: - return self.default.__class__(**{**self.default.dict(), **configurable}) + return self.default.__class__(**{**self.default.__dict__, **configurable}) else: return self.default