Ensure dict() does not raise not implemented error, which should instead be raised in our custom method save() (#11970)

.dict() is a Pydantic method that cannot raise exceptions, as it is used
eg. in `__eq__`

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/11984/head
Nuno Campos 11 months ago committed by GitHub
commit 9bc7e1851a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -274,7 +274,10 @@ class BaseMultiActionAgent(BaseModel):
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
_dict["_type"] = str(self._agent_type)
try:
_dict["_type"] = str(self._agent_type)
except NotImplementedError:
pass
return _dict
def save(self, file_path: Union[Path, str]) -> None:
@ -295,11 +298,13 @@ class BaseMultiActionAgent(BaseModel):
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
agent_dict = self.dict()
if "_type" not in agent_dict:
raise NotImplementedError(f"Agent {self} does not support saving.")
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json":
with open(file_path, "w") as f:

@ -610,8 +610,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
chain.dict(exclude_unset=True)
# -> {"_type": "foo", "verbose": False, ...}
"""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict(**kwargs)
try:
_dict["_type"] = self._chain_type
@ -633,6 +631,14 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
chain.save(file_path="path/chain.yaml")
"""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
# Fetch dictionary to save
chain_dict = self.dict()
if "_type" not in chain_dict:
raise NotImplementedError(f"Chain {self} does not support saving.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
@ -642,11 +648,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
chain_dict = self.dict()
if "_type" not in chain_dict:
raise NotImplementedError(f"Chain {self} does not support saving.")
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(chain_dict, f, indent=4)

@ -1,6 +1,7 @@
"""Prompt template that contains few shot examples."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langchain.prompts.base import (
@ -151,11 +152,10 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Return the prompt type key."""
return "few_shot"
def dict(self, **kwargs: Any) -> Dict:
"""Return a dictionary of the prompt."""
def save(self, file_path: Union[Path, str]) -> None:
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
return super().dict(**kwargs)
return super().save(file_path)
class FewShotChatMessagePromptTemplate(

@ -1,5 +1,6 @@
"""Prompt template that contains few shot examples."""
from typing import Any, Dict, List, Optional
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector
@ -140,8 +141,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
"""Return the prompt type key."""
return "few_shot_with_templates"
def dict(self, **kwargs: Any) -> Dict:
"""Return a dictionary of the prompt."""
def save(self, file_path: Union[Path, str]) -> None:
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
return super().dict(**kwargs)
return super().save(file_path)

@ -298,7 +298,10 @@ class BaseOutputParser(
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict(**kwargs)
output_parser_dict["_type"] = self._type
try:
output_parser_dict["_type"] = self._type
except NotImplementedError:
pass
return output_parser_dict

@ -132,7 +132,10 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt."""
prompt_dict = super().dict(**kwargs)
prompt_dict["_type"] = self._prompt_type
try:
prompt_dict["_type"] = self._prompt_type
except NotImplementedError:
pass
return prompt_dict
def save(self, file_path: Union[Path, str]) -> None:
@ -148,6 +151,12 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
"""
if self.partial_variables:
raise ValueError("Cannot save prompt with partial variables.")
# Fetch dictionary to save
prompt_dict = self.dict()
if "_type" not in prompt_dict:
raise NotImplementedError(f"Prompt {self} does not support saving.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
@ -157,9 +166,6 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
prompt_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)

Loading…
Cancel
Save