From db58032973367bbb6ef1c6403ebef183c9f4347b Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Sat, 3 Dec 2022 13:28:07 -0800
Subject: [PATCH] introduce output parser (#250)

---
 langchain/chains/llm.py             | 10 +++++++++-
 langchain/prompts/base.py           | 28 ++++++++++++++++++++++++++--
 tests/unit_tests/chains/test_llm.py | 22 ++++++++++++++++++++++
 3 files changed, 57 insertions(+), 3 deletions(-)

diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py
index 58514bf1..54462852 100644
--- a/langchain/chains/llm.py
+++ b/langchain/chains/llm.py
@@ -1,5 +1,5 @@
 """Chain that just formats a prompt and calls an LLM."""
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
 
 from pydantic import BaseModel, Extra
 
@@ -78,3 +78,11 @@ class LLMChain(Chain, BaseModel):
                 completion = llm.predict(adjective="funny")
         """
         return self(kwargs)[self.output_key]
+
+    def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
+        """Call predict and then parse the results."""
+        result = self.predict(**kwargs)
+        if self.prompt.output_parser is not None:
+            return self.prompt.output_parser.parse(result)
+        else:
+            return result
diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py
index 26b6cc23..b5ba37eb 100644
--- a/langchain/prompts/base.py
+++ b/langchain/prompts/base.py
@@ -2,10 +2,10 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 import yaml
-from pydantic import BaseModel, root_validator
+from pydantic import BaseModel, Extra, root_validator
 
 from langchain.formatting import formatter
 
@@ -32,11 +32,35 @@ def check_valid_template(
         raise ValueError("Invalid prompt schema.")
 
 
+class BaseOutputParser(ABC):
+    """Class to parse the output of an LLM call."""
+
+    @abstractmethod
+    def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
+        """Parse the output of an LLM call."""
+
+
+class ListOutputParser(BaseOutputParser):
+    """Class to parse the output of an LLM call to a list."""
+
+    @abstractmethod
+    def parse(self, text: str) -> List[str]:
+        """Parse the output of an LLM call."""
+
+
 class BasePromptTemplate(BaseModel, ABC):
     """Base prompt should expose the format method, returning a prompt."""
 
     input_variables: List[str]
     """A list of the names of the variables the prompt template expects."""
+    output_parser: Optional[BaseOutputParser] = None
+    """How to parse the output of calling an LLM on this formatted prompt."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
 
     @root_validator()
     def validate_variable_names(cls, values: Dict) -> Dict:
diff --git a/tests/unit_tests/chains/test_llm.py b/tests/unit_tests/chains/test_llm.py
index 425713ff..65b29ddf 100644
--- a/tests/unit_tests/chains/test_llm.py
+++ b/tests/unit_tests/chains/test_llm.py
@@ -1,11 +1,22 @@
 """Test LLM chain."""
+from typing import Dict, List, Union
+
 import pytest
 
 from langchain.chains.llm import LLMChain
+from langchain.prompts.base import BaseOutputParser
 from langchain.prompts.prompt import PromptTemplate
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
 
+class FakeOutputParser(BaseOutputParser):
+    """Fake output parser class for testing."""
+
+    def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
+        """Parse by splitting."""
+        return text.split()
+
+
 @pytest.fixture
 def fake_llm_chain() -> LLMChain:
     """Fake LLM chain for testing purposes."""
@@ -34,3 +45,14 @@ def test_predict_method(fake_llm_chain: LLMChain) -> None:
     """Test predict method works."""
     output = fake_llm_chain.predict(bar="baz")
     assert output == "foo"
+
+
+def test_predict_and_parse() -> None:
+    """Test parsing ability."""
+    prompt = PromptTemplate(
+        input_variables=["foo"], template="{foo}", output_parser=FakeOutputParser()
+    )
+    llm = FakeLLM(queries={"foo": "foo bar"})
+    chain = LLMChain(prompt=prompt, llm=llm)
+    output = chain.predict_and_parse(foo="foo")
+    assert output == ["foo", "bar"]