|
|
|
@ -1,10 +1,11 @@
|
|
|
|
|
"""Chain that first uses an LLM to generate multiple items then loops over them."""
|
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
|
|
|
|
|
|
from typing import Dict, List, Any
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Extra
|
|
|
|
|
from pydantic import BaseModel, Extra, root_validator
|
|
|
|
|
|
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.prompts.base import ListOutputParser
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LLMForLoopChain(Chain, BaseModel):
|
|
|
|
@ -38,6 +39,17 @@ class LLMForLoopChain(Chain, BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
return [self.output_key]
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
def validate_output_parser(cls, values: Dict) -> Dict:
|
|
|
|
|
"""Validate that the correct inputs exist for all chains."""
|
|
|
|
|
chain = values["llm_chain"]
|
|
|
|
|
if not isinstance(chain.prompt.output_parser, ListOutputParser):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"The OutputParser on the base prompt should be of type "
|
|
|
|
|
f"ListOutputParser, got {type(chain.prompt.output_parser)}"
|
|
|
|
|
)
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
def run_list(self, **kwargs: Any) -> List[str]:
|
|
|
|
|
"""Get list from LLM chain and then run chain on each item."""
|
|
|
|
|
output_items = self.llm_chain.predict_and_parse(**kwargs)
|
|
|
|
|