harrison/use_output_parser
Harrison Chase 2 years ago
parent 3ef44f41b7
commit 9966fd0e05

@ -1,10 +1,11 @@
"""Chain that first uses an LLM to generate multiple items then loops over them."""
from typing import Any, Dict, List
from typing import Dict, List, Any
from pydantic import BaseModel, Extra
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts.base import ListOutputParser
class LLMForLoopChain(Chain, BaseModel):
@ -38,6 +39,17 @@ class LLMForLoopChain(Chain, BaseModel):
"""
return [self.output_key]
@root_validator()
def validate_output_parser(cls, values: Dict) -> Dict:
"""Validate that the correct inputs exist for all chains."""
chain = values["llm_chain"]
if not isinstance(chain.prompt.output_parser, ListOutputParser):
raise ValueError(
f"The OutputParser on the base prompt should be of type "
f"ListOutputParser, got {type(chain.prompt.output_parser)}"
)
return values
def run_list(self, **kwargs: Any) -> List[str]:
"""Get list from LLM chain and then run chain on each item."""
output_items = self.llm_chain.predict_and_parse(**kwargs)

@ -1,8 +1,8 @@
"""BasePrompt schema definition."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator, Extra
from pydantic import BaseModel, Extra, root_validator
from langchain.formatting import formatter
@ -37,6 +37,13 @@ class BaseOutputParser(ABC):
"""Parse the output of an LLM call."""
class ListOutputParser(ABC):
"""Class to parse the output of an LLM call to a list."""
@abstractmethod
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
class BasePromptTemplate(BaseModel, ABC):
"""Base prompt should expose the format method, returning a prompt."""

Loading…
Cancel
Save