Fix .dict() for agent/chain (#11436)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/11442/head
Nuno Campos 11 months ago committed by GitHub
parent 1e59c44d36
commit 2f490be09b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -145,10 +145,13 @@ class BaseSingleActionAgent(BaseModel):
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
_dict = super().dict() _dict = super().dict()
_type = self._agent_type try:
_type = self._agent_type
except NotImplementedError:
_type = None
if isinstance(_type, AgentType): if isinstance(_type, AgentType):
_dict["_type"] = str(_type.value) _dict["_type"] = str(_type.value)
else: elif _type is not None:
_dict["_type"] = _type _dict["_type"] = _type
return _dict return _dict
@ -175,6 +178,8 @@ class BaseSingleActionAgent(BaseModel):
# Fetch dictionary to save # Fetch dictionary to save
agent_dict = self.dict() agent_dict = self.dict()
if "_type" not in agent_dict:
raise NotImplementedError(f"Agent {self} does not support saving")
if save_path.suffix == ".json": if save_path.suffix == ".json":
with open(file_path, "w") as f: with open(file_path, "w") as f:

@ -611,7 +611,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
if self.memory is not None: if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.") raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict(**kwargs) _dict = super().dict(**kwargs)
_dict["_type"] = self._chain_type try:
_dict["_type"] = self._chain_type
except NotImplementedError:
pass
return _dict return _dict
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
@ -639,6 +642,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
# Fetch dictionary to save # Fetch dictionary to save
chain_dict = self.dict() chain_dict = self.dict()
if "_type" not in chain_dict:
raise NotImplementedError(f"Chain {self} does not support saving.")
if save_path.suffix == ".json": if save_path.suffix == ".json":
with open(file_path, "w") as f: with open(file_path, "w") as f:

@ -31,6 +31,13 @@ class SerializedNotImplemented(BaseSerialized):
repr: Optional[str] repr: Optional[str]
def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
try:
return model.__fields__[key].get_default() != value
except Exception:
return True
class Serializable(BaseModel, ABC): class Serializable(BaseModel, ABC):
"""Serializable base class.""" """Serializable base class."""
@ -81,7 +88,7 @@ class Serializable(BaseModel, ABC):
return [ return [
(k, v) (k, v)
for k, v in super().__repr_args__() for k, v in super().__repr_args__()
if (k not in self.__fields__ or self.__fields__[k].get_default() != v) if (k not in self.__fields__ or try_neq_default(v, k, self))
] ]
_lc_kwargs = PrivateAttr(default_factory=dict) _lc_kwargs = PrivateAttr(default_factory=dict)

@ -2,6 +2,7 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
from langchain.schema.runnable.base import ( from langchain.schema.runnable.base import (
Runnable, Runnable,
RunnableBinding, RunnableBinding,
RunnableGenerator,
RunnableLambda, RunnableLambda,
RunnableMap, RunnableMap,
RunnableSequence, RunnableSequence,
@ -12,8 +13,10 @@ from langchain.schema.runnable.config import RunnableConfig, patch_config
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable from langchain.schema.runnable.router import RouterInput, RouterRunnable
from langchain.schema.runnable.utils import ConfigurableField
__all__ = [ __all__ = [
"ConfigurableField",
"GetLocalVar", "GetLocalVar",
"patch_config", "patch_config",
"PutLocalVar", "PutLocalVar",
@ -24,6 +27,7 @@ __all__ = [
"RunnableBinding", "RunnableBinding",
"RunnableBranch", "RunnableBranch",
"RunnableConfig", "RunnableConfig",
"RunnableGenerator",
"RunnableLambda", "RunnableLambda",
"RunnableMap", "RunnableMap",
"RunnablePassthrough", "RunnablePassthrough",

@ -227,7 +227,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
} }
if configurable: if configurable:
return self.default.__class__(**{**self.default.dict(), **configurable}) return self.default.__class__(**{**self.default.__dict__, **configurable})
else: else:
return self.default return self.default

Loading…
Cancel
Save