mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add type inference for output parsers (#2769)
Currently, the output type of a number of OutputParser's `parse` methods is `Any` when it can in fact be inferred. This PR makes BaseOutputParser use a generic type and fixes the output types of the following parsers: - `PydanticOutputParser` - `OutputFixingParser` - `RetryOutputParser` - `RetryWithErrorOutputParser` The output of the `StructuredOutputParser` is corrected from `BaseModel` to `Any` since there are no type guarantees provided by the parser. Fixes issue #2715
This commit is contained in:
parent
789cc314c5
commit
59d054308c
@ -1,30 +1,32 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import TypeVar
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
|
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
class OutputFixingParser(BaseOutputParser):
|
|
||||||
|
class OutputFixingParser(BaseOutputParser[T]):
|
||||||
"""Wraps a parser and tries to fix parsing errors."""
|
"""Wraps a parser and tries to fix parsing errors."""
|
||||||
|
|
||||||
parser: BaseOutputParser
|
parser: BaseOutputParser[T]
|
||||||
retry_chain: LLMChain
|
retry_chain: LLMChain
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
cls,
|
cls,
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
parser: BaseOutputParser,
|
parser: BaseOutputParser[T],
|
||||||
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
|
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
|
||||||
) -> OutputFixingParser:
|
) -> OutputFixingParser[T]:
|
||||||
chain = LLMChain(llm=llm, prompt=prompt)
|
chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
return cls(parser=parser, retry_chain=chain)
|
return cls(parser=parser, retry_chain=chain)
|
||||||
|
|
||||||
def parse(self, completion: str) -> Any:
|
def parse(self, completion: str) -> T:
|
||||||
try:
|
try:
|
||||||
parsed_completion = self.parser.parse(completion)
|
parsed_completion = self.parser.parse(completion)
|
||||||
except OutputParserException as e:
|
except OutputParserException as e:
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Type, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
|
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
|
||||||
from langchain.schema import BaseOutputParser, OutputParserException
|
from langchain.schema import BaseOutputParser, OutputParserException
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
class PydanticOutputParser(BaseOutputParser):
|
|
||||||
pydantic_object: Any
|
|
||||||
|
|
||||||
def parse(self, text: str) -> BaseModel:
|
class PydanticOutputParser(BaseOutputParser[T]):
|
||||||
|
pydantic_object: Type[T]
|
||||||
|
|
||||||
|
def parse(self, text: str) -> T:
|
||||||
try:
|
try:
|
||||||
# Greedy search for 1st json candidate.
|
# Greedy search for 1st json candidate.
|
||||||
match = re.search(
|
match = re.search(
|
||||||
@ -38,6 +40,6 @@ class PydanticOutputParser(BaseOutputParser):
|
|||||||
if "type" in reduced_schema:
|
if "type" in reduced_schema:
|
||||||
del reduced_schema["type"]
|
del reduced_schema["type"]
|
||||||
# Ensure json in context is well-formed with double quotes.
|
# Ensure json in context is well-formed with double quotes.
|
||||||
schema = json.dumps(reduced_schema)
|
schema_str = json.dumps(reduced_schema)
|
||||||
|
|
||||||
return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema)
|
return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import TypeVar
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
@ -34,28 +34,30 @@ NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
|
|||||||
NAIVE_COMPLETION_RETRY_WITH_ERROR
|
NAIVE_COMPLETION_RETRY_WITH_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
class RetryOutputParser(BaseOutputParser):
|
|
||||||
|
class RetryOutputParser(BaseOutputParser[T]):
|
||||||
"""Wraps a parser and tries to fix parsing errors.
|
"""Wraps a parser and tries to fix parsing errors.
|
||||||
|
|
||||||
Does this by passing the original prompt and the completion to another
|
Does this by passing the original prompt and the completion to another
|
||||||
LLM, and telling it the completion did not satisfy criteria in the prompt.
|
LLM, and telling it the completion did not satisfy criteria in the prompt.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
parser: BaseOutputParser
|
parser: BaseOutputParser[T]
|
||||||
retry_chain: LLMChain
|
retry_chain: LLMChain
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
cls,
|
cls,
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
parser: BaseOutputParser,
|
parser: BaseOutputParser[T],
|
||||||
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
|
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
|
||||||
) -> RetryOutputParser:
|
) -> RetryOutputParser[T]:
|
||||||
chain = LLMChain(llm=llm, prompt=prompt)
|
chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
return cls(parser=parser, retry_chain=chain)
|
return cls(parser=parser, retry_chain=chain)
|
||||||
|
|
||||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
|
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||||
try:
|
try:
|
||||||
parsed_completion = self.parser.parse(completion)
|
parsed_completion = self.parser.parse(completion)
|
||||||
except OutputParserException:
|
except OutputParserException:
|
||||||
@ -66,7 +68,7 @@ class RetryOutputParser(BaseOutputParser):
|
|||||||
|
|
||||||
return parsed_completion
|
return parsed_completion
|
||||||
|
|
||||||
def parse(self, completion: str) -> Any:
|
def parse(self, completion: str) -> T:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"This OutputParser can only be called by the `parse_with_prompt` method."
|
"This OutputParser can only be called by the `parse_with_prompt` method."
|
||||||
)
|
)
|
||||||
@ -75,7 +77,7 @@ class RetryOutputParser(BaseOutputParser):
|
|||||||
return self.parser.get_format_instructions()
|
return self.parser.get_format_instructions()
|
||||||
|
|
||||||
|
|
||||||
class RetryWithErrorOutputParser(BaseOutputParser):
|
class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||||
"""Wraps a parser and tries to fix parsing errors.
|
"""Wraps a parser and tries to fix parsing errors.
|
||||||
|
|
||||||
Does this by passing the original prompt, the completion, AND the error
|
Does this by passing the original prompt, the completion, AND the error
|
||||||
@ -85,20 +87,20 @@ class RetryWithErrorOutputParser(BaseOutputParser):
|
|||||||
LLM, which in theory should give it more information on how to fix it.
|
LLM, which in theory should give it more information on how to fix it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
parser: BaseOutputParser
|
parser: BaseOutputParser[T]
|
||||||
retry_chain: LLMChain
|
retry_chain: LLMChain
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
cls,
|
cls,
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
parser: BaseOutputParser,
|
parser: BaseOutputParser[T],
|
||||||
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
|
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
|
||||||
) -> RetryWithErrorOutputParser:
|
) -> RetryWithErrorOutputParser[T]:
|
||||||
chain = LLMChain(llm=llm, prompt=prompt)
|
chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
return cls(parser=parser, retry_chain=chain)
|
return cls(parser=parser, retry_chain=chain)
|
||||||
|
|
||||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
|
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||||
try:
|
try:
|
||||||
parsed_completion = self.parser.parse(completion)
|
parsed_completion = self.parser.parse(completion)
|
||||||
except OutputParserException as e:
|
except OutputParserException as e:
|
||||||
@ -109,7 +111,7 @@ class RetryWithErrorOutputParser(BaseOutputParser):
|
|||||||
|
|
||||||
return parsed_completion
|
return parsed_completion
|
||||||
|
|
||||||
def parse(self, completion: str) -> Any:
|
def parse(self, completion: str) -> T:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"This OutputParser can only be called by the `parse_with_prompt` method."
|
"This OutputParser can only be called by the `parse_with_prompt` method."
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import List
|
from typing import Any, List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -37,7 +37,7 @@ class StructuredOutputParser(BaseOutputParser):
|
|||||||
)
|
)
|
||||||
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
|
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
|
||||||
|
|
||||||
def parse(self, text: str) -> BaseModel:
|
def parse(self, text: str) -> Any:
|
||||||
json_string = text.split("```json")[1].strip().strip("```").strip()
|
json_string = text.split("```json")[1].strip().strip("```").strip()
|
||||||
try:
|
try:
|
||||||
json_obj = json.loads(json_string)
|
json_obj = json.loads(json_string)
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List, NamedTuple, Optional
|
from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, root_validator
|
from pydantic import BaseModel, Extra, Field, root_validator
|
||||||
|
|
||||||
@ -327,15 +327,17 @@ class BaseRetriever(ABC):
|
|||||||
|
|
||||||
Memory = BaseMemory
|
Memory = BaseMemory
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
class BaseOutputParser(BaseModel, ABC):
|
|
||||||
|
class BaseOutputParser(BaseModel, ABC, Generic[T]):
|
||||||
"""Class to parse the output of an LLM call.
|
"""Class to parse the output of an LLM call.
|
||||||
|
|
||||||
Output parsers help structure language model responses.
|
Output parsers help structure language model responses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parse(self, text: str) -> Any:
|
def parse(self, text: str) -> T:
|
||||||
"""Parse the output of an LLM call.
|
"""Parse the output of an LLM call.
|
||||||
|
|
||||||
A method which takes in a string (assumed output of language model )
|
A method which takes in a string (assumed output of language model )
|
||||||
|
@ -46,7 +46,9 @@ DEF_EXPECTED_RESULT = TestModel(
|
|||||||
def test_pydantic_output_parser() -> None:
|
def test_pydantic_output_parser() -> None:
|
||||||
"""Test PydanticOutputParser."""
|
"""Test PydanticOutputParser."""
|
||||||
|
|
||||||
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)
|
pydantic_parser: PydanticOutputParser[TestModel] = PydanticOutputParser(
|
||||||
|
pydantic_object=TestModel
|
||||||
|
)
|
||||||
|
|
||||||
result = pydantic_parser.parse(DEF_RESULT)
|
result = pydantic_parser.parse(DEF_RESULT)
|
||||||
print("parse_result:", result)
|
print("parse_result:", result)
|
||||||
@ -56,7 +58,9 @@ def test_pydantic_output_parser() -> None:
|
|||||||
def test_pydantic_output_parser_fail() -> None:
|
def test_pydantic_output_parser_fail() -> None:
|
||||||
"""Test PydanticOutputParser where completion result fails schema validation."""
|
"""Test PydanticOutputParser where completion result fails schema validation."""
|
||||||
|
|
||||||
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)
|
pydantic_parser: PydanticOutputParser[TestModel] = PydanticOutputParser(
|
||||||
|
pydantic_object=TestModel
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pydantic_parser.parse(DEF_RESULT_FAIL)
|
pydantic_parser.parse(DEF_RESULT_FAIL)
|
||||||
|
Loading…
Reference in New Issue
Block a user