Add _type for all parsers (#4189)

Used for serialization. Also add test that recurses through
our subclasses to check they have them implemented

Would fix https://github.com/hwchase17/langchain/issues/3217
Blocking: https://github.com/mlflow/mlflow/pull/8297

---------

Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
parallel_dir_loader
Sunish Sheth 1 year ago committed by GitHub
parent b21d7c138c
commit 812e5f43f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,3 +24,7 @@ class ChatOutputParser(AgentOutputParser):
except Exception:
raise OutputParserException(f"Could not parse LLM output: {text}")
@property
def _type(self) -> str:
return "chat"

@ -24,3 +24,7 @@ class ConvoOutputParser(AgentOutputParser):
action = match.group(1)
action_input = match.group(2)
return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text)
@property
def _type(self) -> str:
return "conversational"

@ -31,3 +31,7 @@ class ConvoOutputParser(AgentOutputParser):
return AgentFinish({"output": action_input}, text)
else:
return AgentAction(action, action_input, text)
@property
def _type(self) -> str:
return "conversational_chat"

@ -27,3 +27,7 @@ class MRKLOutputParser(AgentOutputParser):
action = match.group(1).strip()
action_input = match.group(2)
return AgentAction(action, action_input.strip(" ").strip('"'), text)
@property
def _type(self) -> str:
return "mrkl"

@ -24,3 +24,7 @@ class ReActOutputParser(AgentOutputParser):
return AgentFinish({"output": action_input}, text)
else:
return AgentAction(action, action_input, text)
@property
def _type(self) -> str:
return "react"

@ -20,3 +20,7 @@ class SelfAskOutputParser(AgentOutputParser):
if " " == after_colon[0]:
after_colon = after_colon[1:]
return AgentAction("Intermediate Answer", after_colon, text)
@property
def _type(self) -> str:
return "self_ask"

@ -40,6 +40,10 @@ class StructuredChatOutputParser(AgentOutputParser):
except Exception as e:
raise OutputParserException(f"Could not parse LLM output: {text}") from e
@property
def _type(self) -> str:
return "structured_chat"
class StructuredChatOutputParserWithRetries(AgentOutputParser):
base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
@ -76,3 +80,7 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
return cls(base_parser=base_parser)
else:
return cls()
@property
def _type(self) -> str:
return "structured_chat_with_retries"

@ -31,6 +31,10 @@ class APIRequesterOutputParser(BaseOutputParser):
return f"MESSAGE: {message_match.group(1).strip()}"
return "ERROR making request"
@property
def _type(self) -> str:
return "api_requester"
class APIRequesterChain(LLMChain):
"""Get the request parser."""

@ -31,6 +31,10 @@ class APIResponderOutputParser(BaseOutputParser):
else:
raise ValueError(f"No response found in output: {llm_output}.")
@property
def _type(self) -> str:
return "api_responder"
class APIResponderChain(LLMChain):
"""Get the response parser."""

@ -52,6 +52,10 @@ class BashOutputParser(BaseOutputParser):
return code_blocks
@property
def _type(self) -> str:
return "bash"
PROMPT = PromptTemplate(
input_variables=["question"],

@ -45,4 +45,4 @@ class OutputFixingParser(BaseOutputParser[T]):
@property
def _type(self) -> str:
return self.parser._type
return "output_fixing"

@ -78,7 +78,7 @@ class RetryOutputParser(BaseOutputParser[T]):
@property
def _type(self) -> str:
return self.parser._type
return "retry"
class RetryWithErrorOutputParser(BaseOutputParser[T]):
@ -122,3 +122,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()
@property
def _type(self) -> str:
return "retry_with_error"

@ -227,25 +227,25 @@ class BaseChatMessageHistory(ABC):
class FileChatMessageHistory(BaseChatMessageHistory):
storage_path: str
session_id: str
@property
def messages(self):
with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
messages = json.loads(f.read())
return messages_from_dict(messages)
return messages_from_dict(messages)
def add_user_message(self, message: str):
message_ = HumanMessage(content=message)
messages = self.messages.append(_message_to_dict(_message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(f, messages)
def add_ai_message(self, message: str):
message_ = AIMessage(content=message)
messages = self.messages.append(_message_to_dict(_message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(f, messages)
def clear(self):
with open(os.path.join(storage_path, session_id), 'w') as f:
f.write("[]")
@ -348,7 +348,10 @@ class BaseOutputParser(BaseModel, ABC, Generic[T]):
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError
raise NotImplementedError(
f"_type property is not implemented in class {self.__class__.__name__}."
" This is required for serialization."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""

@ -0,0 +1,47 @@
"""Test the BaseOutputParser class and its sub-classes."""
from abc import ABC
from typing import List, Optional, Set, Type
import pytest
from langchain.schema import BaseOutputParser
def non_abstract_subclasses(
cls: Type[ABC], to_skip: Optional[Set] = None
) -> List[Type]:
"""Recursively find all non-abstract subclasses of a class."""
_to_skip = to_skip or set()
subclasses = []
for subclass in cls.__subclasses__():
if not getattr(subclass, "__abstractmethods__", None):
if subclass.__name__ not in _to_skip:
subclasses.append(subclass)
subclasses.extend(non_abstract_subclasses(subclass, to_skip=_to_skip))
return subclasses
_PARSERS_TO_SKIP = {"FakeOutputParser", "BaseOutputParser"}
_NON_ABSTRACT_PARSERS = non_abstract_subclasses(
BaseOutputParser, to_skip=_PARSERS_TO_SKIP
)
@pytest.mark.parametrize("cls", _NON_ABSTRACT_PARSERS)
def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None:
try:
cls._type
except NotImplementedError:
pytest.fail(f"_type property is not implemented in class {cls.__name__}")
def test_all_subclasses_implement_unique_type() -> None:
types = []
for cls in _NON_ABSTRACT_PARSERS:
try:
types.append(cls._type)
except NotImplementedError:
# This is handled in the previous test
pass
dups = set([t for t in types if types.count(t) > 1])
assert not dups, f"Duplicate types: {dups}"
Loading…
Cancel
Save