From 59d054308c850da1a61fc9621385182c7459120d Mon Sep 17 00:00:00 2001 From: Joshua Snyder Date: Wed, 12 Apr 2023 18:12:20 +0200 Subject: [PATCH] Add type inference for output parsers (#2769) Currently, the output type of a number of OutputParser's `parse` methods is `Any` when it can in fact be inferred. This PR makes BaseOutputParser use a generic type and fixes the output types of the following parsers: - `PydanticOutputParser` - `OutputFixingParser` - `RetryOutputParser` - `RetryWithErrorOutputParser` The output of the `StructuredOutputParser` is corrected from `BaseModel` to `Any` since there are no type guarantees provided by the parser. Fixes issue #2715 --- langchain/output_parsers/fix.py | 14 ++++++---- langchain/output_parsers/pydantic.py | 14 ++++++---- langchain/output_parsers/retry.py | 28 ++++++++++--------- langchain/output_parsers/structured.py | 4 +-- langchain/schema.py | 8 ++++-- .../output_parsers/test_pydantic_parser.py | 8 ++++-- 6 files changed, 44 insertions(+), 32 deletions(-) diff --git a/langchain/output_parsers/fix.py b/langchain/output_parsers/fix.py index 2654948ad2..b695586ead 100644 --- a/langchain/output_parsers/fix.py +++ b/langchain/output_parsers/fix.py @@ -1,30 +1,32 @@ from __future__ import annotations -from typing import Any +from typing import TypeVar from langchain.chains.llm import LLMChain from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT from langchain.prompts.base import BasePromptTemplate from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException +T = TypeVar("T") -class OutputFixingParser(BaseOutputParser): + +class OutputFixingParser(BaseOutputParser[T]): """Wraps a parser and tries to fix parsing errors.""" - parser: BaseOutputParser + parser: BaseOutputParser[T] retry_chain: LLMChain @classmethod def from_llm( cls, llm: BaseLanguageModel, - parser: BaseOutputParser, + parser: BaseOutputParser[T], prompt: BasePromptTemplate = NAIVE_FIX_PROMPT, - ) -> OutputFixingParser: + ) -> OutputFixingParser[T]: chain = LLMChain(llm=llm, prompt=prompt) return cls(parser=parser, retry_chain=chain) - def parse(self, completion: str) -> Any: + def parse(self, completion: str) -> T: try: parsed_completion = self.parser.parse(completion) except OutputParserException as e: diff --git a/langchain/output_parsers/pydantic.py b/langchain/output_parsers/pydantic.py index 9f818cd916..7a4050d058 100644 --- a/langchain/output_parsers/pydantic.py +++ b/langchain/output_parsers/pydantic.py @@ -1,17 +1,19 @@ import json import re -from typing import Any +from typing import Type, TypeVar from pydantic import BaseModel, ValidationError from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS from langchain.schema import BaseOutputParser, OutputParserException +T = TypeVar("T", bound=BaseModel) -class PydanticOutputParser(BaseOutputParser): - pydantic_object: Any - def parse(self, text: str) -> BaseModel: +class PydanticOutputParser(BaseOutputParser[T]): + pydantic_object: Type[T] + + def parse(self, text: str) -> T: try: # Greedy search for 1st json candidate. match = re.search( @@ -38,6 +40,6 @@ class PydanticOutputParser(BaseOutputParser): if "type" in reduced_schema: del reduced_schema["type"] # Ensure json in context is well-formed with double quotes. - schema = json.dumps(reduced_schema) + schema_str = json.dumps(reduced_schema) - return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema) + return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str) diff --git a/langchain/output_parsers/retry.py b/langchain/output_parsers/retry.py index 7c6760ea6f..6ef08cf683 100644 --- a/langchain/output_parsers/retry.py +++ b/langchain/output_parsers/retry.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import TypeVar from langchain.chains.llm import LLMChain from langchain.prompts.base import BasePromptTemplate @@ -34,28 +34,30 @@ NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template( NAIVE_COMPLETION_RETRY_WITH_ERROR ) +T = TypeVar("T") -class RetryOutputParser(BaseOutputParser): + +class RetryOutputParser(BaseOutputParser[T]): """Wraps a parser and tries to fix parsing errors. Does this by passing the original prompt and the completion to another LLM, and telling it the completion did not satisfy criteria in the prompt. """ - parser: BaseOutputParser + parser: BaseOutputParser[T] retry_chain: LLMChain @classmethod def from_llm( cls, llm: BaseLanguageModel, - parser: BaseOutputParser, + parser: BaseOutputParser[T], prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT, - ) -> RetryOutputParser: + ) -> RetryOutputParser[T]: chain = LLMChain(llm=llm, prompt=prompt) return cls(parser=parser, retry_chain=chain) - def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any: + def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: try: parsed_completion = self.parser.parse(completion) except OutputParserException: @@ -66,7 +68,7 @@ class RetryOutputParser(BaseOutputParser): return parsed_completion - def parse(self, completion: str) -> Any: + def parse(self, completion: str) -> T: raise NotImplementedError( "This OutputParser can only be called by the `parse_with_prompt` method." ) @@ -75,7 +77,7 @@ class RetryOutputParser(BaseOutputParser): return self.parser.get_format_instructions() -class RetryWithErrorOutputParser(BaseOutputParser): +class RetryWithErrorOutputParser(BaseOutputParser[T]): """Wraps a parser and tries to fix parsing errors. Does this by passing the original prompt, the completion, AND the error @@ -85,20 +87,20 @@ class RetryWithErrorOutputParser(BaseOutputParser): LLM, which in theory should give it more information on how to fix it. """ - parser: BaseOutputParser + parser: BaseOutputParser[T] retry_chain: LLMChain @classmethod def from_llm( cls, llm: BaseLanguageModel, - parser: BaseOutputParser, + parser: BaseOutputParser[T], prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT, - ) -> RetryWithErrorOutputParser: + ) -> RetryWithErrorOutputParser[T]: chain = LLMChain(llm=llm, prompt=prompt) return cls(parser=parser, retry_chain=chain) - def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any: + def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: try: parsed_completion = self.parser.parse(completion) except OutputParserException as e: @@ -109,7 +111,7 @@ class RetryWithErrorOutputParser(BaseOutputParser): return parsed_completion - def parse(self, completion: str) -> Any: + def parse(self, completion: str) -> T: raise NotImplementedError( "This OutputParser can only be called by the `parse_with_prompt` method." ) diff --git a/langchain/output_parsers/structured.py b/langchain/output_parsers/structured.py index c77f865a6f..566e0885d0 100644 --- a/langchain/output_parsers/structured.py +++ b/langchain/output_parsers/structured.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import List +from typing import Any, List from pydantic import BaseModel @@ -37,7 +37,7 @@ class StructuredOutputParser(BaseOutputParser): ) return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str) - def parse(self, text: str) -> BaseModel: + def parse(self, text: str) -> Any: json_string = text.split("```json")[1].strip().strip("```").strip() try: json_obj = json.loads(json_string) diff --git a/langchain/schema.py b/langchain/schema.py index 2acb427391..8678b1a89f 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -2,7 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, List, NamedTuple, Optional +from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar from pydantic import BaseModel, Extra, Field, root_validator @@ -327,15 +327,17 @@ class BaseRetriever(ABC): Memory = BaseMemory +T = TypeVar("T") -class BaseOutputParser(BaseModel, ABC): + +class BaseOutputParser(BaseModel, ABC, Generic[T]): """Class to parse the output of an LLM call. Output parsers help structure language model responses. """ @abstractmethod - def parse(self, text: str) -> Any: + def parse(self, text: str) -> T: """Parse the output of an LLM call. A method which takes in a string (assumed output of language model ) diff --git a/tests/unit_tests/output_parsers/test_pydantic_parser.py b/tests/unit_tests/output_parsers/test_pydantic_parser.py index 85bc338764..5acf88f5e4 100644 --- a/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -46,7 +46,9 @@ DEF_EXPECTED_RESULT = TestModel( def test_pydantic_output_parser() -> None: """Test PydanticOutputParser.""" - pydantic_parser = PydanticOutputParser(pydantic_object=TestModel) + pydantic_parser: PydanticOutputParser[TestModel] = PydanticOutputParser( + pydantic_object=TestModel + ) result = pydantic_parser.parse(DEF_RESULT) print("parse_result:", result) @@ -56,7 +58,9 @@ def test_pydantic_output_parser() -> None: def test_pydantic_output_parser_fail() -> None: """Test PydanticOutputParser where completion result fails schema validation.""" - pydantic_parser = PydanticOutputParser(pydantic_object=TestModel) + pydantic_parser: PydanticOutputParser[TestModel] = PydanticOutputParser( + pydantic_object=TestModel + ) try: pydantic_parser.parse(DEF_RESULT_FAIL)