mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
introduce output parser (#250)
This commit is contained in:
parent
b4762dfff0
commit
db58032973
@ -1,5 +1,5 @@
|
||||
"""Chain that just formats a prompt and calls an LLM."""
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
@ -78,3 +78,11 @@ class LLMChain(Chain, BaseModel):
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return self(kwargs)[self.output_key]
|
||||
|
||||
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Call predict and then parse the results."""
|
||||
result = self.predict(**kwargs)
|
||||
if self.prompt.output_parser is not None:
|
||||
return self.prompt.output_parser.parse(result)
|
||||
else:
|
||||
return result
|
||||
|
@ -2,10 +2,10 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.formatting import formatter
|
||||
|
||||
@ -32,11 +32,35 @@ def check_valid_template(
|
||||
raise ValueError("Invalid prompt schema.")
|
||||
|
||||
|
||||
class BaseOutputParser(ABC):
|
||||
"""Class to parse the output of an LLM call."""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Parse the output of an LLM call."""
|
||||
|
||||
|
||||
class ListOutputParser(ABC):
|
||||
"""Class to parse the output of an LLM call to a list."""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> List[str]:
|
||||
"""Parse the output of an LLM call."""
|
||||
|
||||
|
||||
class BasePromptTemplate(BaseModel, ABC):
|
||||
"""Base prompt should expose the format method, returning a prompt."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
"""How to parse the output of calling an LLM on this formatted prompt."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator()
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
|
@ -1,11 +1,22 @@
|
||||
"""Test LLM chain."""
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BaseOutputParser
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
class FakeOutputParser(BaseOutputParser):
|
||||
"""Fake output parser class for testing."""
|
||||
|
||||
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Parse by splitting."""
|
||||
return text.split()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_llm_chain() -> LLMChain:
|
||||
"""Fake LLM chain for testing purposes."""
|
||||
@ -34,3 +45,14 @@ def test_predict_method(fake_llm_chain: LLMChain) -> None:
|
||||
"""Test predict method works."""
|
||||
output = fake_llm_chain.predict(bar="baz")
|
||||
assert output == "foo"
|
||||
|
||||
|
||||
def test_predict_and_parse() -> None:
|
||||
"""Test parsing ability."""
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["foo"], template="{foo}", output_parser=FakeOutputParser()
|
||||
)
|
||||
llm = FakeLLM(queries={"foo": "foo bar"})
|
||||
chain = LLMChain(prompt=prompt, llm=llm)
|
||||
output = chain.predict_and_parse(foo="foo")
|
||||
assert output == ["foo", "bar"]
|
||||
|
Loading…
Reference in New Issue
Block a user