experimental[patch]: SmartLLMChain Output Key Customization (#14466)

**Description**
The `SmartLLMChain` was was fixed to output key "resolution".
Unfortunately, this prevents the ability to use multiple `SmartLLMChain`
in a `SequentialChain` because of colliding output keys. This change
simply gives the option the customize the output key to allow for
sequential chaining. The default behavior is the same as the current
behavior.

Now, it's possible to do the following:
```
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_experimental.smart_llm import SmartLLMChain
from langchain.chains import SequentialChain

joke_prompt = PromptTemplate(
    input_variables=["content"],
    template="Tell me a joke about {content}.",
)
review_prompt = PromptTemplate(
    input_variables=["scale", "joke"],
    template="Rate the following joke from 1 to {scale}: {joke}"
)

llm = ChatOpenAI(temperature=0.9, model_name="gpt-4-32k")
joke_chain = SmartLLMChain(llm=llm, prompt=joke_prompt, output_key="joke")
review_chain = SmartLLMChain(llm=llm, prompt=review_prompt, output_key="review")

chain = SequentialChain(
    chains=[joke_chain, review_chain],
    input_variables=["content", "scale"],
    output_variables=["review"],
    verbose=True
)
response = chain.run({"content": "chickens", "scale": "10"})
print(response)
```

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Anish Nag 2023-12-08 13:55:51 -08:00 committed by GitHub
parent 0797358c1b
commit 6da0cfea0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -66,6 +66,7 @@ class SmartLLMChain(Chain):
prompt: BasePromptTemplate prompt: BasePromptTemplate
"""Prompt object to use.""" """Prompt object to use."""
output_key: str = "resolution"
ideation_llm: Optional[BaseLanguageModel] = None ideation_llm: Optional[BaseLanguageModel] = None
"""LLM to use in ideation step. If None given, 'llm' will be used.""" """LLM to use in ideation step. If None given, 'llm' will be used."""
critique_llm: Optional[BaseLanguageModel] = None critique_llm: Optional[BaseLanguageModel] = None
@ -132,8 +133,8 @@ class SmartLLMChain(Chain):
def output_keys(self) -> List[str]: def output_keys(self) -> List[str]:
"""Defines the output keys.""" """Defines the output keys."""
if self.return_intermediate_steps: if self.return_intermediate_steps:
return ["ideas", "critique", "resolution"] return ["ideas", "critique", self.output_key]
return ["resolution"] return [self.output_key]
def prep_prompts( def prep_prompts(
self, self,
@ -169,8 +170,8 @@ class SmartLLMChain(Chain):
self.history.critique = critique self.history.critique = critique
resolution = self._resolve(stop, run_manager) resolution = self._resolve(stop, run_manager)
if self.return_intermediate_steps: if self.return_intermediate_steps:
return {"ideas": ideas, "critique": critique, "resolution": resolution} return {"ideas": ideas, "critique": critique, self.output_key: resolution}
return {"resolution": resolution} return {self.output_key: resolution}
def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str: def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str:
"""Between steps, only the LLM result text is passed, not the LLMResult object. """Between steps, only the LLM result text is passed, not the LLMResult object.