Fix .dict() for agent/chain (#11436)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/11442/head
Nuno Campos 10 months ago committed by GitHub
parent 1e59c44d36
commit 2f490be09b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -145,10 +145,13 @@ class BaseSingleActionAgent(BaseModel):
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
_type = self._agent_type
try:
_type = self._agent_type
except NotImplementedError:
_type = None
if isinstance(_type, AgentType):
_dict["_type"] = str(_type.value)
else:
elif _type is not None:
_dict["_type"] = _type
return _dict
@ -175,6 +178,8 @@ class BaseSingleActionAgent(BaseModel):
# Fetch dictionary to save
agent_dict = self.dict()
if "_type" not in agent_dict:
raise NotImplementedError(f"Agent {self} does not support saving")
if save_path.suffix == ".json":
with open(file_path, "w") as f:

@ -611,7 +611,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict(**kwargs)
_dict["_type"] = self._chain_type
try:
_dict["_type"] = self._chain_type
except NotImplementedError:
pass
return _dict
def save(self, file_path: Union[Path, str]) -> None:
@ -639,6 +642,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
# Fetch dictionary to save
chain_dict = self.dict()
if "_type" not in chain_dict:
raise NotImplementedError(f"Chain {self} does not support saving.")
if save_path.suffix == ".json":
with open(file_path, "w") as f:

@ -31,6 +31,13 @@ class SerializedNotImplemented(BaseSerialized):
repr: Optional[str]
def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
try:
return model.__fields__[key].get_default() != value
except Exception:
return True
class Serializable(BaseModel, ABC):
"""Serializable base class."""
@ -81,7 +88,7 @@ class Serializable(BaseModel, ABC):
return [
(k, v)
for k, v in super().__repr_args__()
if (k not in self.__fields__ or self.__fields__[k].get_default() != v)
if (k not in self.__fields__ or try_neq_default(v, k, self))
]
_lc_kwargs = PrivateAttr(default_factory=dict)

@ -2,6 +2,7 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
from langchain.schema.runnable.base import (
Runnable,
RunnableBinding,
RunnableGenerator,
RunnableLambda,
RunnableMap,
RunnableSequence,
@ -12,8 +13,10 @@ from langchain.schema.runnable.config import RunnableConfig, patch_config
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable
from langchain.schema.runnable.utils import ConfigurableField
__all__ = [
"ConfigurableField",
"GetLocalVar",
"patch_config",
"PutLocalVar",
@ -24,6 +27,7 @@ __all__ = [
"RunnableBinding",
"RunnableBranch",
"RunnableConfig",
"RunnableGenerator",
"RunnableLambda",
"RunnableMap",
"RunnablePassthrough",

@ -227,7 +227,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
}
if configurable:
return self.default.__class__(**{**self.default.dict(), **configurable})
return self.default.__class__(**{**self.default.__dict__, **configurable})
else:
return self.default

Loading…
Cancel
Save