introduce output parser (#250)

This commit is contained in:
Harrison Chase 2022-12-03 13:28:07 -08:00 committed by GitHub
parent b4762dfff0
commit db58032973
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 57 additions and 3 deletions

View File

@@ -1,5 +1,5 @@
"""Chain that just formats a prompt and calls an LLM."""
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from pydantic import BaseModel, Extra
@@ -78,3 +78,11 @@ class LLMChain(Chain, BaseModel):
completion = llm.predict(adjective="funny")
"""
return self(kwargs)[self.output_key]
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
    """Run ``predict`` and feed the completion through the prompt's parser.

    The formatted-prompt kwargs are forwarded unchanged to ``predict``.
    When the prompt has no ``output_parser`` configured, the raw completion
    string is returned as-is.
    """
    completion = self.predict(**kwargs)
    parser = self.prompt.output_parser
    if parser is None:
        return completion
    return parser.parse(completion)

View File

@@ -2,10 +2,10 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
import yaml
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, Extra, root_validator
from langchain.formatting import formatter
@@ -32,11 +32,35 @@ def check_valid_template(
raise ValueError("Invalid prompt schema.")
class BaseOutputParser(ABC):
    """Abstract interface for converting raw LLM text into structured output."""

    @abstractmethod
    def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
        """Turn the raw completion *text* into a string, list, or mapping."""
class ListOutputParser(ABC):
    """Interface for parsers that turn an LLM completion into a list of strings.

    NOTE(review): this mirrors ``BaseOutputParser`` but does not subclass it,
    so list parsers will not satisfy ``isinstance(x, BaseOutputParser)`` checks
    — confirm whether that is intentional.
    """

    @abstractmethod
    def parse(self, text: str) -> List[str]:
        """Convert the raw completion *text* into a list of strings."""
class BasePromptTemplate(BaseModel, ABC):
"""Base prompt should expose the format method, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:

View File

@@ -1,11 +1,22 @@
"""Test LLM chain."""
from typing import Dict, List, Union
import pytest
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BaseOutputParser
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
class FakeOutputParser(BaseOutputParser):
    """Trivial parser used only in tests: whitespace-splits the completion."""

    def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
        """Return *text* split on whitespace."""
        tokens = text.split()
        return tokens
@pytest.fixture
def fake_llm_chain() -> LLMChain:
"""Fake LLM chain for testing purposes."""
@@ -34,3 +45,14 @@ def test_predict_method(fake_llm_chain: LLMChain) -> None:
"""Test predict method works."""
output = fake_llm_chain.predict(bar="baz")
assert output == "foo"
def test_predict_and_parse() -> None:
    """Test parsing ability."""
    # Prompt carries the parser; the fake LLM echoes a fixed completion.
    template = PromptTemplate(
        input_variables=["foo"], template="{foo}", output_parser=FakeOutputParser()
    )
    fake_llm = FakeLLM(queries={"foo": "foo bar"})
    llm_chain = LLMChain(prompt=template, llm=fake_llm)
    result = llm_chain.predict_and_parse(foo="foo")
    assert result == ["foo", "bar"]