Mirror of https://github.com/hwchase17/langchain (synced 2024-11-18 09:25:54 +00:00)
experimental[patch]: SmartLLMChain Output Key Customization (#14466)
**Description** The `SmartLLMChain` output key was hard-coded to "resolution". Unfortunately, this prevents using multiple `SmartLLMChain`s in a `SequentialChain` because of colliding output keys. This change simply adds the option to customize the output key, allowing sequential chaining. The default behavior is unchanged. Now, it's possible to do the following:

```
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_experimental.smart_llm import SmartLLMChain
from langchain.chains import SequentialChain

joke_prompt = PromptTemplate(
    input_variables=["content"],
    template="Tell me a joke about {content}.",
)
review_prompt = PromptTemplate(
    input_variables=["scale", "joke"],
    template="Rate the following joke from 1 to {scale}: {joke}",
)

llm = ChatOpenAI(temperature=0.9, model_name="gpt-4-32k")
joke_chain = SmartLLMChain(llm=llm, prompt=joke_prompt, output_key="joke")
review_chain = SmartLLMChain(llm=llm, prompt=review_prompt, output_key="review")

chain = SequentialChain(
    chains=[joke_chain, review_chain],
    input_variables=["content", "scale"],
    output_variables=["review"],
    verbose=True,
)
response = chain.run({"content": "chickens", "scale": "10"})
print(response)
```

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
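Not part of the PR itself, but a minimal sketch of how the new `output_key` field combines with the existing `return_intermediate_steps` flag, assuming only the behavior shown in the diff below; the model and prompt settings are illustrative:

```
# A sketch, not from the PR: with return_intermediate_steps=True, the custom
# output key replaces "resolution" both in output_keys and in the result dict.
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_experimental.smart_llm import SmartLLMChain

joke_prompt = PromptTemplate(
    input_variables=["content"],
    template="Tell me a joke about {content}.",
)
chain = SmartLLMChain(
    llm=ChatOpenAI(temperature=0.9, model_name="gpt-4-32k"),  # requires OPENAI_API_KEY
    prompt=joke_prompt,
    output_key="joke",
    return_intermediate_steps=True,
)
print(chain.output_keys)  # ['ideas', 'critique', 'joke']
result = chain({"content": "chickens"})
# The outputs include 'ideas', 'critique', and 'joke' instead of 'resolution'.
```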
This commit is contained in:
parent 0797358c1b
commit 6da0cfea0e
```
@@ -66,6 +66,7 @@ class SmartLLMChain(Chain):
     prompt: BasePromptTemplate
     """Prompt object to use."""
+    output_key: str = "resolution"
     ideation_llm: Optional[BaseLanguageModel] = None
     """LLM to use in ideation step. If None given, 'llm' will be used."""
     critique_llm: Optional[BaseLanguageModel] = None
     """LLM to use in critique step. If None given, 'llm' will be used."""
@@ -132,8 +133,8 @@ class SmartLLMChain(Chain):
     def output_keys(self) -> List[str]:
         """Defines the output keys."""
         if self.return_intermediate_steps:
-            return ["ideas", "critique", "resolution"]
-        return ["resolution"]
+            return ["ideas", "critique", self.output_key]
+        return [self.output_key]

     def prep_prompts(
         self,
@@ -169,8 +170,8 @@ class SmartLLMChain(Chain):
         self.history.critique = critique
         resolution = self._resolve(stop, run_manager)
         if self.return_intermediate_steps:
-            return {"ideas": ideas, "critique": critique, "resolution": resolution}
-        return {"resolution": resolution}
+            return {"ideas": ideas, "critique": critique, self.output_key: resolution}
+        return {self.output_key: resolution}

     def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str:
         """Between steps, only the LLM result text is passed, not the LLMResult object.
```
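A short usage note on the design choice (a sketch assuming the diff above is the whole change): the new field defaults to "resolution", so existing callers that read that key are unaffected:

```
# Sketch: with no output_key argument, the chain still exposes "resolution".
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_experimental.smart_llm import SmartLLMChain

prompt = PromptTemplate(
    input_variables=["content"],
    template="Tell me a joke about {content}.",
)
default_chain = SmartLLMChain(llm=ChatOpenAI(temperature=0.9), prompt=prompt)
print(default_chain.output_keys)  # ['resolution']
```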