From 202acce0c998ef80c09a296ca0a21c1c5479ac25 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 18 Oct 2023 09:44:41 +0100 Subject: [PATCH] Ensure dict() does not raise not implemented error, which should instead be raised in our custom method save() --- libs/langchain/langchain/agents/agent.py | 13 +++++++++---- libs/langchain/langchain/chains/base.py | 15 ++++++++------- libs/langchain/langchain/prompts/few_shot.py | 6 +++--- .../langchain/prompts/few_shot_with_templates.py | 8 ++++---- libs/langchain/langchain/schema/output_parser.py | 5 ++++- .../langchain/langchain/schema/prompt_template.py | 14 ++++++++++---- 6 files changed, 38 insertions(+), 23 deletions(-) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 648f177c26..86f45ed244 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -274,7 +274,10 @@ class BaseMultiActionAgent(BaseModel): def dict(self, **kwargs: Any) -> Dict: """Return dictionary representation of agent.""" _dict = super().dict() - _dict["_type"] = str(self._agent_type) + try: + _dict["_type"] = str(self._agent_type) + except NotImplementedError: + pass return _dict def save(self, file_path: Union[Path, str]) -> None: @@ -295,11 +298,13 @@ class BaseMultiActionAgent(BaseModel): else: save_path = file_path - directory_path = save_path.parent - directory_path.mkdir(parents=True, exist_ok=True) - # Fetch dictionary to save agent_dict = self.dict() + if "_type" not in agent_dict: + raise NotImplementedError(f"Agent {self} does not support saving.") + + directory_path = save_path.parent + directory_path.mkdir(parents=True, exist_ok=True) if save_path.suffix == ".json": with open(file_path, "w") as f: diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 0da360c7e5..83a83346ab 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -610,8 +610,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): chain.dict(exclude_unset=True) # -> {"_type": "foo", "verbose": False, ...} """ - if self.memory is not None: - raise ValueError("Saving of memory is not yet supported.") _dict = super().dict(**kwargs) try: _dict["_type"] = self._chain_type @@ -633,6 +631,14 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): chain.save(file_path="path/chain.yaml") """ + if self.memory is not None: + raise ValueError("Saving of memory is not yet supported.") + + # Fetch dictionary to save + chain_dict = self.dict() + if "_type" not in chain_dict: + raise NotImplementedError(f"Chain {self} does not support saving.") + # Convert file to Path object. if isinstance(file_path, str): save_path = Path(file_path) @@ -642,11 +648,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): directory_path = save_path.parent directory_path.mkdir(parents=True, exist_ok=True) - # Fetch dictionary to save - chain_dict = self.dict() - if "_type" not in chain_dict: - raise NotImplementedError(f"Chain {self} does not support saving.") - if save_path.suffix == ".json": with open(file_path, "w") as f: json.dump(chain_dict, f, indent=4) diff --git a/libs/langchain/langchain/prompts/few_shot.py b/libs/langchain/langchain/prompts/few_shot.py index bf9c4421b0..e8fa1b2447 100644 --- a/libs/langchain/langchain/prompts/few_shot.py +++ b/libs/langchain/langchain/prompts/few_shot.py @@ -1,6 +1,7 @@ """Prompt template that contains few shot examples.""" from __future__ import annotations +from pathlib import Path from typing import Any, Dict, List, Optional, Union from langchain.prompts.base import ( @@ -151,11 +152,10 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): """Return the prompt type key.""" return "few_shot" - def dict(self, **kwargs: Any) -> Dict: - """Return a dictionary of the prompt.""" + def save(self, file_path: Union[Path, str]) -> None: if self.example_selector: raise ValueError("Saving an example selector is not currently supported") - return super().dict(**kwargs) + return super().save(file_path) class FewShotChatMessagePromptTemplate( diff --git a/libs/langchain/langchain/prompts/few_shot_with_templates.py b/libs/langchain/langchain/prompts/few_shot_with_templates.py index a66f1a37ee..1e34a0f5a8 100644 --- a/libs/langchain/langchain/prompts/few_shot_with_templates.py +++ b/libs/langchain/langchain/prompts/few_shot_with_templates.py @@ -1,5 +1,6 @@ """Prompt template that contains few shot examples.""" -from typing import Any, Dict, List, Optional +from pathlib import Path +from typing import Any, Dict, List, Optional, Union from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate from langchain.prompts.example_selector.base import BaseExampleSelector @@ -140,8 +141,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): """Return the prompt type key.""" return "few_shot_with_templates" - def dict(self, **kwargs: Any) -> Dict: - """Return a dictionary of the prompt.""" + def save(self, file_path: Union[Path, str]) -> None: if self.example_selector: raise ValueError("Saving an example selector is not currently supported") - return super().dict(**kwargs) + return super().save(file_path) diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index eed1491b8f..e0eb5b00c2 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -298,7 +298,10 @@ class BaseOutputParser( def dict(self, **kwargs: Any) -> Dict: """Return dictionary representation of output parser.""" output_parser_dict = super().dict(**kwargs) - output_parser_dict["_type"] = self._type + try: + output_parser_dict["_type"] = self._type + except NotImplementedError: + pass return output_parser_dict diff --git a/libs/langchain/langchain/schema/prompt_template.py b/libs/langchain/langchain/schema/prompt_template.py index 9ac8164bb6..6f976beb70 100644 --- a/libs/langchain/langchain/schema/prompt_template.py +++ b/libs/langchain/langchain/schema/prompt_template.py @@ -132,7 +132,10 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): def dict(self, **kwargs: Any) -> Dict: """Return dictionary representation of prompt.""" prompt_dict = super().dict(**kwargs) - prompt_dict["_type"] = self._prompt_type + try: + prompt_dict["_type"] = self._prompt_type + except NotImplementedError: + pass return prompt_dict def save(self, file_path: Union[Path, str]) -> None: @@ -148,6 +151,12 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): """ if self.partial_variables: raise ValueError("Cannot save prompt with partial variables.") + + # Fetch dictionary to save + prompt_dict = self.dict() + if "_type" not in prompt_dict: + raise NotImplementedError(f"Prompt {self} does not support saving.") + # Convert file to Path object. if isinstance(file_path, str): save_path = Path(file_path) @@ -157,9 +166,6 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): directory_path = save_path.parent directory_path.mkdir(parents=True, exist_ok=True) - # Fetch dictionary to save - prompt_dict = self.dict() - if save_path.suffix == ".json": with open(file_path, "w") as f: json.dump(prompt_dict, f, indent=4)