diff --git a/langchain/output_parsers/fix.py b/langchain/output_parsers/fix.py index 2654948a..b695586e 100644 --- a/langchain/output_parsers/fix.py +++ b/langchain/output_parsers/fix.py @@ -1,30 +1,32 @@ from __future__ import annotations -from typing import Any +from typing import TypeVar from langchain.chains.llm import LLMChain from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT from langchain.prompts.base import BasePromptTemplate from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException +T = TypeVar("T") -class OutputFixingParser(BaseOutputParser): + +class OutputFixingParser(BaseOutputParser[T]): """Wraps a parser and tries to fix parsing errors.""" - parser: BaseOutputParser + parser: BaseOutputParser[T] retry_chain: LLMChain @classmethod def from_llm( cls, llm: BaseLanguageModel, - parser: BaseOutputParser, + parser: BaseOutputParser[T], prompt: BasePromptTemplate = NAIVE_FIX_PROMPT, - ) -> OutputFixingParser: + ) -> OutputFixingParser[T]: chain = LLMChain(llm=llm, prompt=prompt) return cls(parser=parser, retry_chain=chain) - def parse(self, completion: str) -> Any: + def parse(self, completion: str) -> T: try: parsed_completion = self.parser.parse(completion) except OutputParserException as e: diff --git a/langchain/output_parsers/pydantic.py b/langchain/output_parsers/pydantic.py index 9f818cd9..7a4050d0 100644 --- a/langchain/output_parsers/pydantic.py +++ b/langchain/output_parsers/pydantic.py @@ -1,17 +1,19 @@ import json import re -from typing import Any +from typing import Type, TypeVar from pydantic import BaseModel, ValidationError from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS from langchain.schema import BaseOutputParser, OutputParserException +T = TypeVar("T", bound=BaseModel) -class PydanticOutputParser(BaseOutputParser): - pydantic_object: Any - def parse(self, text: str) -> BaseModel: +class PydanticOutputParser(BaseOutputParser[T]): + pydantic_object: Type[T] + + def parse(self, text: str) -> T: try: # Greedy search for 1st json candidate. match = re.search( @@ -38,6 +40,6 @@ class PydanticOutputParser(BaseOutputParser): if "type" in reduced_schema: del reduced_schema["type"] # Ensure json in context is well-formed with double quotes. - schema = json.dumps(reduced_schema) + schema_str = json.dumps(reduced_schema) - return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema) + return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str) diff --git a/langchain/output_parsers/retry.py b/langchain/output_parsers/retry.py index 7c6760ea..6ef08cf6 100644 --- a/langchain/output_parsers/retry.py +++ b/langchain/output_parsers/retry.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import TypeVar from langchain.chains.llm import LLMChain from langchain.prompts.base import BasePromptTemplate @@ -34,28 +34,30 @@ NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template( NAIVE_COMPLETION_RETRY_WITH_ERROR ) +T = TypeVar("T") -class RetryOutputParser(BaseOutputParser): + +class RetryOutputParser(BaseOutputParser[T]): """Wraps a parser and tries to fix parsing errors. Does this by passing the original prompt and the completion to another LLM, and telling it the completion did not satisfy criteria in the prompt. """ - parser: BaseOutputParser + parser: BaseOutputParser[T] retry_chain: LLMChain @classmethod def from_llm( cls, llm: BaseLanguageModel, - parser: BaseOutputParser, + parser: BaseOutputParser[T], prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT, - ) -> RetryOutputParser: + ) -> RetryOutputParser[T]: chain = LLMChain(llm=llm, prompt=prompt) return cls(parser=parser, retry_chain=chain) - def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any: + def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: try: parsed_completion = self.parser.parse(completion) except OutputParserException: @@ -66,7 +68,7 @@ class RetryOutputParser(BaseOutputParser): return parsed_completion - def parse(self, completion: str) -> Any: + def parse(self, completion: str) -> T: raise NotImplementedError( "This OutputParser can only be called by the `parse_with_prompt` method." ) @@ -75,7 +77,7 @@ class RetryOutputParser(BaseOutputParser): return self.parser.get_format_instructions() -class RetryWithErrorOutputParser(BaseOutputParser): +class RetryWithErrorOutputParser(BaseOutputParser[T]): """Wraps a parser and tries to fix parsing errors. Does this by passing the original prompt, the completion, AND the error @@ -85,20 +87,20 @@ class RetryWithErrorOutputParser(BaseOutputParser): LLM, which in theory should give it more information on how to fix it. """ - parser: BaseOutputParser + parser: BaseOutputParser[T] retry_chain: LLMChain @classmethod def from_llm( cls, llm: BaseLanguageModel, - parser: BaseOutputParser, + parser: BaseOutputParser[T], prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT, - ) -> RetryWithErrorOutputParser: + ) -> RetryWithErrorOutputParser[T]: chain = LLMChain(llm=llm, prompt=prompt) return cls(parser=parser, retry_chain=chain) - def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any: + def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: try: parsed_completion = self.parser.parse(completion) except OutputParserException as e: @@ -109,7 +111,7 @@ class RetryWithErrorOutputParser(BaseOutputParser): return parsed_completion - def parse(self, completion: str) -> Any: + def parse(self, completion: str) -> T: raise NotImplementedError( "This OutputParser can only be called by the `parse_with_prompt` method." ) diff --git a/langchain/output_parsers/structured.py b/langchain/output_parsers/structured.py index c77f865a..566e0885 100644 --- a/langchain/output_parsers/structured.py +++ b/langchain/output_parsers/structured.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import List +from typing import Any, List from pydantic import BaseModel @@ -37,7 +37,7 @@ class StructuredOutputParser(BaseOutputParser): ) return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str) - def parse(self, text: str) -> BaseModel: + def parse(self, text: str) -> Any: json_string = text.split("```json")[1].strip().strip("```").strip() try: json_obj = json.loads(json_string) diff --git a/langchain/schema.py b/langchain/schema.py index 2acb4273..8678b1a8 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -2,7 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, List, NamedTuple, Optional +from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar from pydantic import BaseModel, Extra, Field, root_validator @@ -327,15 +327,17 @@ class BaseRetriever(ABC): Memory = BaseMemory +T = TypeVar("T") -class BaseOutputParser(BaseModel, ABC): + +class BaseOutputParser(BaseModel, ABC, Generic[T]): """Class to parse the output of an LLM call. Output parsers help structure language model responses. """ @abstractmethod - def parse(self, text: str) -> Any: + def parse(self, text: str) -> T: """Parse the output of an LLM call. A method which takes in a string (assumed output of language model ) diff --git a/tests/unit_tests/output_parsers/test_pydantic_parser.py b/tests/unit_tests/output_parsers/test_pydantic_parser.py index 85bc3387..5acf88f5 100644 --- a/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -46,7 +46,9 @@ DEF_EXPECTED_RESULT = TestModel( def test_pydantic_output_parser() -> None: """Test PydanticOutputParser.""" - pydantic_parser = PydanticOutputParser(pydantic_object=TestModel) + pydantic_parser: PydanticOutputParser[TestModel] = PydanticOutputParser( + pydantic_object=TestModel + ) result = pydantic_parser.parse(DEF_RESULT) print("parse_result:", result) @@ -56,7 +58,9 @@ def test_pydantic_output_parser() -> None: def test_pydantic_output_parser_fail() -> None: """Test PydanticOutputParser where completion result fails schema validation.""" - pydantic_parser = PydanticOutputParser(pydantic_object=TestModel) + pydantic_parser: PydanticOutputParser[TestModel] = PydanticOutputParser( + pydantic_object=TestModel + ) try: pydantic_parser.parse(DEF_RESULT_FAIL)