From 812e5f43f541ed8a20f6105f44ddb7e82d86abf2 Mon Sep 17 00:00:00 2001
From: Sunish Sheth
Date: Thu, 11 May 2023 01:27:58 -0700
Subject: [PATCH] Add _type for all parsers (#4189)

Used for serialization. Also add test that recurses through our subclasses
to check they have them implemented

Would fix https://github.com/hwchase17/langchain/issues/3217

Blocking: https://github.com/mlflow/mlflow/pull/8297

---------

Signed-off-by: Sunish Sheth
Co-authored-by: Dev 2049
---
 langchain/agents/chat/output_parser.py        |  4 ++
 .../agents/conversational/output_parser.py    |  4 ++
 .../conversational_chat/output_parser.py      |  4 ++
 langchain/agents/mrkl/output_parser.py        |  4 ++
 langchain/agents/react/output_parser.py       |  4 ++
 .../self_ask_with_search/output_parser.py     |  4 ++
 .../agents/structured_chat/output_parser.py   |  8 ++++
 .../chains/api/openapi/requests_chain.py      |  4 ++
 .../chains/api/openapi/response_chain.py      |  4 ++
 langchain/chains/llm_bash/prompt.py           |  4 ++
 langchain/output_parsers/fix.py               |  2 +-
 langchain/output_parsers/retry.py             |  6 ++-
 langchain/schema.py                           | 15 +++---
 .../output_parsers/test_base_output_parser.py | 47 +++++++++++++++++++
 14 files changed, 106 insertions(+), 8 deletions(-)
 create mode 100644 tests/unit_tests/output_parsers/test_base_output_parser.py

diff --git a/langchain/agents/chat/output_parser.py b/langchain/agents/chat/output_parser.py
index 71a8edd7..9f143d07 100644
--- a/langchain/agents/chat/output_parser.py
+++ b/langchain/agents/chat/output_parser.py
@@ -24,3 +24,7 @@ class ChatOutputParser(AgentOutputParser):
 
         except Exception:
             raise OutputParserException(f"Could not parse LLM output: {text}")
+
+    @property
+    def _type(self) -> str:
+        return "chat"
diff --git a/langchain/agents/conversational/output_parser.py b/langchain/agents/conversational/output_parser.py
index f11eb540..84c4fec5 100644
--- a/langchain/agents/conversational/output_parser.py
+++ b/langchain/agents/conversational/output_parser.py
@@ -24,3 +24,7 @@ class ConvoOutputParser(AgentOutputParser):
         action = match.group(1)
         action_input = match.group(2)
         return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text)
+
+    @property
+    def _type(self) -> str:
+        return "conversational"
diff --git a/langchain/agents/conversational_chat/output_parser.py b/langchain/agents/conversational_chat/output_parser.py
index 3b2e7b52..99880fac 100644
--- a/langchain/agents/conversational_chat/output_parser.py
+++ b/langchain/agents/conversational_chat/output_parser.py
@@ -31,3 +31,7 @@ class ConvoOutputParser(AgentOutputParser):
             return AgentFinish({"output": action_input}, text)
         else:
             return AgentAction(action, action_input, text)
+
+    @property
+    def _type(self) -> str:
+        return "conversational_chat"
diff --git a/langchain/agents/mrkl/output_parser.py b/langchain/agents/mrkl/output_parser.py
index 6f809eb5..0b77c828 100644
--- a/langchain/agents/mrkl/output_parser.py
+++ b/langchain/agents/mrkl/output_parser.py
@@ -27,3 +27,7 @@ class MRKLOutputParser(AgentOutputParser):
         action = match.group(1).strip()
         action_input = match.group(2)
         return AgentAction(action, action_input.strip(" ").strip('"'), text)
+
+    @property
+    def _type(self) -> str:
+        return "mrkl"
diff --git a/langchain/agents/react/output_parser.py b/langchain/agents/react/output_parser.py
index f63cd3e2..9904d366 100644
--- a/langchain/agents/react/output_parser.py
+++ b/langchain/agents/react/output_parser.py
@@ -24,3 +24,7 @@ class ReActOutputParser(AgentOutputParser):
             return AgentFinish({"output": action_input}, text)
         else:
             return AgentAction(action, action_input, text)
+
+    @property
+    def _type(self) -> str:
+        return "react"
diff --git a/langchain/agents/self_ask_with_search/output_parser.py b/langchain/agents/self_ask_with_search/output_parser.py
index 009c6161..a091adee 100644
--- a/langchain/agents/self_ask_with_search/output_parser.py
+++ b/langchain/agents/self_ask_with_search/output_parser.py
@@ -20,3 +20,7 @@ class SelfAskOutputParser(AgentOutputParser):
         if " " == after_colon[0]:
             after_colon = after_colon[1:]
         return AgentAction("Intermediate Answer", after_colon, text)
+
+    @property
+    def _type(self) -> str:
+        return "self_ask"
diff --git a/langchain/agents/structured_chat/output_parser.py b/langchain/agents/structured_chat/output_parser.py
index 4f9240ad..d53ae58c 100644
--- a/langchain/agents/structured_chat/output_parser.py
+++ b/langchain/agents/structured_chat/output_parser.py
@@ -40,6 +40,10 @@ class StructuredChatOutputParser(AgentOutputParser):
         except Exception as e:
             raise OutputParserException(f"Could not parse LLM output: {text}") from e
 
+    @property
+    def _type(self) -> str:
+        return "structured_chat"
+
 
 class StructuredChatOutputParserWithRetries(AgentOutputParser):
     base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
@@ -76,3 +80,7 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
             return cls(base_parser=base_parser)
         else:
             return cls()
+
+    @property
+    def _type(self) -> str:
+        return "structured_chat_with_retries"
diff --git a/langchain/chains/api/openapi/requests_chain.py b/langchain/chains/api/openapi/requests_chain.py
index e26ce296..223a630f 100644
--- a/langchain/chains/api/openapi/requests_chain.py
+++ b/langchain/chains/api/openapi/requests_chain.py
@@ -31,6 +31,10 @@ class APIRequesterOutputParser(BaseOutputParser):
             return f"MESSAGE: {message_match.group(1).strip()}"
         return "ERROR making request"
 
+    @property
+    def _type(self) -> str:
+        return "api_requester"
+
 
 class APIRequesterChain(LLMChain):
     """Get the request parser."""
diff --git a/langchain/chains/api/openapi/response_chain.py b/langchain/chains/api/openapi/response_chain.py
index 325797d3..21b4af3b 100644
--- a/langchain/chains/api/openapi/response_chain.py
+++ b/langchain/chains/api/openapi/response_chain.py
@@ -31,6 +31,10 @@ class APIResponderOutputParser(BaseOutputParser):
         else:
             raise ValueError(f"No response found in output: {llm_output}.")
 
+    @property
+    def _type(self) -> str:
+        return "api_responder"
+
 
 class APIResponderChain(LLMChain):
     """Get the response parser."""
diff --git a/langchain/chains/llm_bash/prompt.py b/langchain/chains/llm_bash/prompt.py
index 363b5505..72951d2f 100644
--- a/langchain/chains/llm_bash/prompt.py
+++ b/langchain/chains/llm_bash/prompt.py
@@ -52,6 +52,10 @@ class BashOutputParser(BaseOutputParser):
 
         return code_blocks
 
+    @property
+    def _type(self) -> str:
+        return "bash"
+
 
 PROMPT = PromptTemplate(
     input_variables=["question"],
diff --git a/langchain/output_parsers/fix.py b/langchain/output_parsers/fix.py
index a46b2e4e..166d570f 100644
--- a/langchain/output_parsers/fix.py
+++ b/langchain/output_parsers/fix.py
@@ -45,4 +45,4 @@ class OutputFixingParser(BaseOutputParser[T]):
 
     @property
     def _type(self) -> str:
-        return self.parser._type
+        return "output_fixing"
diff --git a/langchain/output_parsers/retry.py b/langchain/output_parsers/retry.py
index 080d1a49..bbbe82d9 100644
--- a/langchain/output_parsers/retry.py
+++ b/langchain/output_parsers/retry.py
@@ -78,7 +78,7 @@ class RetryOutputParser(BaseOutputParser[T]):
 
     @property
     def _type(self) -> str:
-        return self.parser._type
+        return "retry"
 
 
 class RetryWithErrorOutputParser(BaseOutputParser[T]):
@@ -122,3 +122,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
 
     def get_format_instructions(self) -> str:
         return self.parser.get_format_instructions()
+
+    @property
+    def _type(self) -> str:
+        return "retry_with_error"
diff --git a/langchain/schema.py b/langchain/schema.py
index ac248c9d..21552b9b 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -227,25 +227,25 @@ class BaseChatMessageHistory(ABC):
             class FileChatMessageHistory(BaseChatMessageHistory):
                 storage_path: str
                 session_id: str
-    
+
                 @property
                 def messages(self):
                     with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
                         messages = json.loads(f.read())
-                    return messages_from_dict(messages)
-    
+                    return messages_from_dict(messages)
+
                 def add_user_message(self, message: str):
                     message_ = HumanMessage(content=message)
                     messages = self.messages.append(_message_to_dict(_message))
                     with open(os.path.join(storage_path, session_id), 'w') as f:
                         json.dump(f, messages)
-    
+
                 def add_ai_message(self, message: str):
                     message_ = AIMessage(content=message)
                     messages = self.messages.append(_message_to_dict(_message))
                     with open(os.path.join(storage_path, session_id), 'w') as f:
                         json.dump(f, messages)
-    
+
                 def clear(self):
                     with open(os.path.join(storage_path, session_id), 'w') as f:
                         f.write("[]")
@@ -348,7 +348,10 @@ class BaseOutputParser(BaseModel, ABC, Generic[T]):
     @property
     def _type(self) -> str:
         """Return the type key."""
-        raise NotImplementedError
+        raise NotImplementedError(
+            f"_type property is not implemented in class {self.__class__.__name__}."
+            " This is required for serialization."
+        )
 
     def dict(self, **kwargs: Any) -> Dict:
         """Return dictionary representation of output parser."""
diff --git a/tests/unit_tests/output_parsers/test_base_output_parser.py b/tests/unit_tests/output_parsers/test_base_output_parser.py
new file mode 100644
index 00000000..9cbbd910
--- /dev/null
+++ b/tests/unit_tests/output_parsers/test_base_output_parser.py
@@ -0,0 +1,47 @@
+"""Test the BaseOutputParser class and its sub-classes."""
+from abc import ABC
+from typing import List, Optional, Set, Type
+
+import pytest
+
+from langchain.schema import BaseOutputParser
+
+
+def non_abstract_subclasses(
+    cls: Type[ABC], to_skip: Optional[Set] = None
+) -> List[Type]:
+    """Recursively find all non-abstract subclasses of a class."""
+    _to_skip = to_skip or set()
+    subclasses = []
+    for subclass in cls.__subclasses__():
+        if not getattr(subclass, "__abstractmethods__", None):
+            if subclass.__name__ not in _to_skip:
+                subclasses.append(subclass)
+        subclasses.extend(non_abstract_subclasses(subclass, to_skip=_to_skip))
+    return subclasses
+
+
+_PARSERS_TO_SKIP = {"FakeOutputParser", "BaseOutputParser"}
+_NON_ABSTRACT_PARSERS = non_abstract_subclasses(
+    BaseOutputParser, to_skip=_PARSERS_TO_SKIP
+)
+
+
+@pytest.mark.parametrize("cls", _NON_ABSTRACT_PARSERS)
+def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None:
+    try:
+        cls._type
+    except NotImplementedError:
+        pytest.fail(f"_type property is not implemented in class {cls.__name__}")
+
+
+def test_all_subclasses_implement_unique_type() -> None:
+    types = []
+    for cls in _NON_ABSTRACT_PARSERS:
+        try:
+            types.append(cls._type)
+        except NotImplementedError:
+            # This is handled in the previous test
+            pass
+    dups = set([t for t in types if types.count(t) > 1])
+    assert not dups, f"Duplicate types: {dups}"
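
Usage note (not part of the patch): the sketch below illustrates how the new
`_type` keys are expected to surface once this change lands. It assumes the
mid-2023 `langchain` package layout used by the files above; the behavior of
`BaseOutputParser.dict()` shown in the comments is an assumption inferred from
the "Used for serialization" line in the commit message, not something visible
in the diff.

    # Hypothetical usage sketch, not part of the commit.
    from langchain.agents.mrkl.output_parser import MRKLOutputParser

    parser = MRKLOutputParser()
    assert parser._type == "mrkl"  # type key added by this patch

    # dict() on BaseOutputParser is expected to embed the type key so that a
    # serializer (e.g. the MLflow integration referenced above) can tell which
    # parser class to rebuild -- assumed behavior, not shown in this diff.
    serialized = parser.dict()
    print(serialized.get("_type"))  # expected: "mrkl"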