mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Ensure dict() does not raise not implemented error, which should instead be raised in our custom method save()
This commit is contained in:
parent
392df7b2e3
commit
202acce0c9
@ -274,7 +274,10 @@ class BaseMultiActionAgent(BaseModel):
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of agent."""
|
||||
_dict = super().dict()
|
||||
try:
|
||||
_dict["_type"] = str(self._agent_type)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
@ -295,11 +298,13 @@ class BaseMultiActionAgent(BaseModel):
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
agent_dict = self.dict()
|
||||
if "_type" not in agent_dict:
|
||||
raise NotImplementedError(f"Agent {self} does not support saving.")
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
|
@ -610,8 +610,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
chain.dict(exclude_unset=True)
|
||||
# -> {"_type": "foo", "verbose": False, ...}
|
||||
"""
|
||||
if self.memory is not None:
|
||||
raise ValueError("Saving of memory is not yet supported.")
|
||||
_dict = super().dict(**kwargs)
|
||||
try:
|
||||
_dict["_type"] = self._chain_type
|
||||
@ -633,6 +631,14 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
|
||||
chain.save(file_path="path/chain.yaml")
|
||||
"""
|
||||
if self.memory is not None:
|
||||
raise ValueError("Saving of memory is not yet supported.")
|
||||
|
||||
# Fetch dictionary to save
|
||||
chain_dict = self.dict()
|
||||
if "_type" not in chain_dict:
|
||||
raise NotImplementedError(f"Chain {self} does not support saving.")
|
||||
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
@ -642,11 +648,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
chain_dict = self.dict()
|
||||
if "_type" not in chain_dict:
|
||||
raise NotImplementedError(f"Chain {self} does not support saving.")
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(chain_dict, f, indent=4)
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Prompt template that contains few shot examples."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.prompts.base import (
|
||||
@ -151,11 +152,10 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
"""Return the prompt type key."""
|
||||
return "few_shot"
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return a dictionary of the prompt."""
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
if self.example_selector:
|
||||
raise ValueError("Saving an example selector is not currently supported")
|
||||
return super().dict(**kwargs)
|
||||
return super().save(file_path)
|
||||
|
||||
|
||||
class FewShotChatMessagePromptTemplate(
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Prompt template that contains few shot examples."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate
|
||||
from langchain.prompts.example_selector.base import BaseExampleSelector
|
||||
@ -140,8 +141,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
"""Return the prompt type key."""
|
||||
return "few_shot_with_templates"
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return a dictionary of the prompt."""
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
if self.example_selector:
|
||||
raise ValueError("Saving an example selector is not currently supported")
|
||||
return super().dict(**kwargs)
|
||||
return super().save(file_path)
|
||||
|
@ -298,7 +298,10 @@ class BaseOutputParser(
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict(**kwargs)
|
||||
try:
|
||||
output_parser_dict["_type"] = self._type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return output_parser_dict
|
||||
|
||||
|
||||
|
@ -132,7 +132,10 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of prompt."""
|
||||
prompt_dict = super().dict(**kwargs)
|
||||
try:
|
||||
prompt_dict["_type"] = self._prompt_type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return prompt_dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
@ -148,6 +151,12 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
"""
|
||||
if self.partial_variables:
|
||||
raise ValueError("Cannot save prompt with partial variables.")
|
||||
|
||||
# Fetch dictionary to save
|
||||
prompt_dict = self.dict()
|
||||
if "_type" not in prompt_dict:
|
||||
raise NotImplementedError(f"Prompt {self} does not support saving.")
|
||||
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
@ -157,9 +166,6 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
prompt_dict = self.dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(prompt_dict, f, indent=4)
|
||||
|
Loading…
Reference in New Issue
Block a user