Ensure dict() does not raise not implemented error, which should instead be raised in our custom method save()

This commit is contained in:
Nuno Campos 2023-10-18 09:44:41 +01:00
parent 392df7b2e3
commit 202acce0c9
6 changed files with 38 additions and 23 deletions

View File

@ -274,7 +274,10 @@ class BaseMultiActionAgent(BaseModel):
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
try:
_dict["_type"] = str(self._agent_type)
except NotImplementedError:
pass
return _dict
def save(self, file_path: Union[Path, str]) -> None:
@ -295,11 +298,13 @@ class BaseMultiActionAgent(BaseModel):
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
agent_dict = self.dict()
if "_type" not in agent_dict:
raise NotImplementedError(f"Agent {self} does not support saving.")
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json":
with open(file_path, "w") as f:

View File

@ -610,8 +610,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
chain.dict(exclude_unset=True)
# -> {"_type": "foo", "verbose": False, ...}
"""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict(**kwargs)
try:
_dict["_type"] = self._chain_type
@ -633,6 +631,14 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
chain.save(file_path="path/chain.yaml")
"""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
# Fetch dictionary to save
chain_dict = self.dict()
if "_type" not in chain_dict:
raise NotImplementedError(f"Chain {self} does not support saving.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
@ -642,11 +648,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
chain_dict = self.dict()
if "_type" not in chain_dict:
raise NotImplementedError(f"Chain {self} does not support saving.")
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(chain_dict, f, indent=4)

View File

@ -1,6 +1,7 @@
"""Prompt template that contains few shot examples."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langchain.prompts.base import (
@ -151,11 +152,10 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Return the prompt type key."""
return "few_shot"
def dict(self, **kwargs: Any) -> Dict:
"""Return a dictionary of the prompt."""
def save(self, file_path: Union[Path, str]) -> None:
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
return super().dict(**kwargs)
return super().save(file_path)
class FewShotChatMessagePromptTemplate(

View File

@ -1,5 +1,6 @@
"""Prompt template that contains few shot examples."""
from typing import Any, Dict, List, Optional
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector
@ -140,8 +141,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
"""Return the prompt type key."""
return "few_shot_with_templates"
def dict(self, **kwargs: Any) -> Dict:
"""Return a dictionary of the prompt."""
def save(self, file_path: Union[Path, str]) -> None:
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
return super().dict(**kwargs)
return super().save(file_path)

View File

@ -298,7 +298,10 @@ class BaseOutputParser(
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict(**kwargs)
try:
output_parser_dict["_type"] = self._type
except NotImplementedError:
pass
return output_parser_dict

View File

@ -132,7 +132,10 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt."""
prompt_dict = super().dict(**kwargs)
try:
prompt_dict["_type"] = self._prompt_type
except NotImplementedError:
pass
return prompt_dict
def save(self, file_path: Union[Path, str]) -> None:
@ -148,6 +151,12 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
"""
if self.partial_variables:
raise ValueError("Cannot save prompt with partial variables.")
# Fetch dictionary to save
prompt_dict = self.dict()
if "_type" not in prompt_dict:
raise NotImplementedError(f"Prompt {self} does not support saving.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
@ -157,9 +166,6 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
prompt_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)