From 6da0cfea0ed677ad5cb45f24a4543a215e081554 Mon Sep 17 00:00:00 2001
From: Anish Nag
Date: Fri, 8 Dec 2023 13:55:51 -0800
Subject: [PATCH] experimental[patch]: SmartLLMChain Output Key Customization (#14466)

**Description**
The `SmartLLMChain` output key was fixed to "resolution". Unfortunately, this prevents using multiple `SmartLLMChain`s in a `SequentialChain` because of colliding output keys. This change simply gives the option to customize the output key to allow for sequential chaining. The default behavior is unchanged (the output key remains "resolution").

Now, it's possible to do the following:
```
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_experimental.smart_llm import SmartLLMChain
from langchain.chains import SequentialChain

joke_prompt = PromptTemplate(
    input_variables=["content"],
    template="Tell me a joke about {content}.",
)
review_prompt = PromptTemplate(
    input_variables=["scale", "joke"],
    template="Rate the following joke from 1 to {scale}: {joke}",
)

llm = ChatOpenAI(temperature=0.9, model_name="gpt-4-32k")
joke_chain = SmartLLMChain(llm=llm, prompt=joke_prompt, output_key="joke")
review_chain = SmartLLMChain(llm=llm, prompt=review_prompt, output_key="review")

chain = SequentialChain(
    chains=[joke_chain, review_chain],
    input_variables=["content", "scale"],
    output_variables=["review"],
    verbose=True,
)
response = chain.run({"content": "chickens", "scale": "10"})
print(response)
```

---------

Co-authored-by: Erick Friis
---
 .../langchain_experimental/smart_llm/base.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/libs/experimental/langchain_experimental/smart_llm/base.py b/libs/experimental/langchain_experimental/smart_llm/base.py
index 8301c5df53..d9d8929cb0 100644
--- a/libs/experimental/langchain_experimental/smart_llm/base.py
+++ b/libs/experimental/langchain_experimental/smart_llm/base.py
@@ -66,6 +66,7 @@ class SmartLLMChain(Chain):
 
     prompt: BasePromptTemplate
     """Prompt object to use."""
+    output_key: str = "resolution"
     ideation_llm: Optional[BaseLanguageModel] = None
     """LLM to use in ideation step. If None given, 'llm' will be used."""
     critique_llm: Optional[BaseLanguageModel] = None
@@ -132,8 +133,8 @@ class SmartLLMChain(Chain):
     def output_keys(self) -> List[str]:
         """Defines the output keys."""
         if self.return_intermediate_steps:
-            return ["ideas", "critique", "resolution"]
-        return ["resolution"]
+            return ["ideas", "critique", self.output_key]
+        return [self.output_key]
 
     def prep_prompts(
         self,
@@ -169,8 +170,8 @@ class SmartLLMChain(Chain):
         self.history.critique = critique
         resolution = self._resolve(stop, run_manager)
         if self.return_intermediate_steps:
-            return {"ideas": ideas, "critique": critique, "resolution": resolution}
-        return {"resolution": resolution}
+            return {"ideas": ideas, "critique": critique, self.output_key: resolution}
+        return {self.output_key: resolution}
 
     def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str:
         """Between steps, only the LLM result text is passed, not the LLMResult object.