from __future__ import annotations from typing import Any, Dict, List, Mapping, Optional, cast from langchain.chains import LLMChain from langchain.chains.base import Chain from langchain.schema.language_model import BaseLanguageModel from langchain_core.callbacks.manager import ( CallbackManagerForChainRun, ) from langchain_core.prompts.prompt import PromptTemplate from langchain_experimental.recommenders.amazon_personalize import AmazonPersonalize SUMMARIZE_PROMPT_QUERY = """ Summarize the recommended items for a user from the items list in tag below. Make correlation into the items in the list and provide a summary. {result} """ SUMMARIZE_PROMPT = PromptTemplate( input_variables=["result"], template=SUMMARIZE_PROMPT_QUERY ) INTERMEDIATE_STEPS_KEY = "intermediate_steps" # Input Key Names to be used USER_ID_INPUT_KEY = "user_id" ITEM_ID_INPUT_KEY = "item_id" INPUT_LIST_INPUT_KEY = "input_list" FILTER_ARN_INPUT_KEY = "filter_arn" FILTER_VALUES_INPUT_KEY = "filter_values" CONTEXT_INPUT_KEY = "context" PROMOTIONS_INPUT_KEY = "promotions" METADATA_COLUMNS_INPUT_KEY = "metadata_columns" RESULT_OUTPUT_KEY = "result" class AmazonPersonalizeChain(Chain): """Chain for retrieving recommendations from Amazon Personalize, and summarizing them. It only returns recommendations if return_direct=True. It can also be used in sequential chains for working with the output of Amazon Personalize. Example: .. code-block:: python chain = PersonalizeChain.from_llm(llm=agent_llm, client=personalize_lg, return_direct=True)\n response = chain.run({'user_id':'1'})\n response = chain.run({'user_id':'1', 'item_id':'234'}) """ client: AmazonPersonalize summarization_chain: LLMChain return_direct: bool = False return_intermediate_steps: bool = False is_ranking_recipe: bool = False @property def input_keys(self) -> List[str]: """This returns an empty list since not there are optional input_keys and none is required. :meta private: """ return [] @property def output_keys(self) -> List[str]: """Will always return result key. :meta private: """ return [RESULT_OUTPUT_KEY] @classmethod def from_llm( cls, llm: BaseLanguageModel, client: AmazonPersonalize, prompt_template: PromptTemplate = SUMMARIZE_PROMPT, is_ranking_recipe: bool = False, **kwargs: Any, ) -> AmazonPersonalizeChain: """Initializes the Personalize Chain with LLMAgent, Personalize Client, Prompts to be used Args: llm: BaseLanguageModel: The LLM to be used in the Chain client: AmazonPersonalize: The client created to support invoking AmazonPersonalize prompt_template: PromptTemplate: The prompt template which can be invoked with the output from Amazon Personalize is_ranking_recipe: bool: default: False: specifies if the trained recipe is USER_PERSONALIZED_RANKING Example: .. code-block:: python chain = PersonalizeChain.from_llm(llm=agent_llm, client=personalize_lg, return_direct=True)\n response = chain.run({'user_id':'1'})\n response = chain.run({'user_id':'1', 'item_id':'234'}) RANDOM_PROMPT_QUERY=" Summarize recommendations in {result}" chain = PersonalizeChain.from_llm(llm=agent_llm, client=personalize_lg, prompt_template=PROMPT_TEMPLATE)\n """ summarization_chain = LLMChain(llm=llm, prompt=prompt_template) return cls( summarization_chain=summarization_chain, client=client, is_ranking_recipe=is_ranking_recipe, **kwargs, ) def _call( self, inputs: Mapping[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, Any]: """Retrieves recommendations by invoking Amazon Personalize, and invokes an LLM using the default/overridden prompt template with the output from Amazon Personalize Args: inputs: Mapping [str, Any] : Provide input identifiers in a map. For example - {'user_id','1'} or {'user_id':'1', 'item_id':'123'}. You can also pass the filter_arn, filter_values as an input. """ _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() callbacks = _run_manager.get_child() user_id = inputs.get(USER_ID_INPUT_KEY) item_id = inputs.get(ITEM_ID_INPUT_KEY) input_list = inputs.get(INPUT_LIST_INPUT_KEY) filter_arn = inputs.get(FILTER_ARN_INPUT_KEY) filter_values = inputs.get(FILTER_VALUES_INPUT_KEY) promotions = inputs.get(PROMOTIONS_INPUT_KEY) context = inputs.get(CONTEXT_INPUT_KEY) metadata_columns = inputs.get(METADATA_COLUMNS_INPUT_KEY) intermediate_steps: List = [] intermediate_steps.append({"Calling Amazon Personalize"}) if self.is_ranking_recipe: response = self.client.get_personalized_ranking( user_id=str(user_id), input_list=cast(List[str], input_list), filter_arn=filter_arn, filter_values=filter_values, context=context, metadata_columns=metadata_columns, ) else: response = self.client.get_recommendations( user_id=user_id, item_id=item_id, filter_arn=filter_arn, filter_values=filter_values, context=context, promotions=promotions, metadata_columns=metadata_columns, ) _run_manager.on_text("Call to Amazon Personalize complete \n") if self.return_direct: final_result = response else: result = self.summarization_chain( {RESULT_OUTPUT_KEY: response}, callbacks=callbacks ) final_result = result[self.summarization_chain.output_key] intermediate_steps.append({"context": response}) chain_result: Dict[str, Any] = {RESULT_OUTPUT_KEY: final_result} if self.return_intermediate_steps: chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps return chain_result @property def _chain_type(self) -> str: return "amazon_personalize_chain"