mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
95dc90609e
Replaced `from langchain.prompts` with `from langchain_core.prompts` where it is appropriate. Most of the changes go to `langchain_experimental`. Similar to #20348.
193 lines
6.8 KiB
Python
193 lines
6.8 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, List, Mapping, Optional, cast
|
|
|
|
from langchain.chains import LLMChain
|
|
from langchain.chains.base import Chain
|
|
from langchain.schema.language_model import BaseLanguageModel
|
|
from langchain_core.callbacks.manager import (
|
|
CallbackManagerForChainRun,
|
|
)
|
|
from langchain_core.prompts.prompt import PromptTemplate
|
|
|
|
from langchain_experimental.recommenders.amazon_personalize import AmazonPersonalize
|
|
|
|
# Default prompt text used to summarize the Amazon Personalize response.
# The single placeholder ``{result}`` receives the recommendation payload.
# NOTE(review): the mirrored copy of this file lost all leading whitespace, so
# any indentation the <result> lines originally had could not be recovered —
# confirm against the upstream file if exact prompt whitespace matters.
SUMMARIZE_PROMPT_QUERY = """
Summarize the recommended items for a user from the items list in tag <result> below.
Make correlation into the items in the list and provide a summary.
<result>
{result}
</result>
"""
|
|
|
|
# Default summarization prompt; its only input variable is "result", which
# the chain fills with the response returned by Amazon Personalize.
SUMMARIZE_PROMPT = PromptTemplate(
    input_variables=["result"], template=SUMMARIZE_PROMPT_QUERY
)
|
|
|
|
# Output key under which the intermediate steps are returned (only when the
# chain is configured with return_intermediate_steps=True).
INTERMEDIATE_STEPS_KEY = "intermediate_steps"

# Input Key Names to be used
# All of these are optional keys looked up in the chain's ``inputs`` mapping.
USER_ID_INPUT_KEY = "user_id"
ITEM_ID_INPUT_KEY = "item_id"
INPUT_LIST_INPUT_KEY = "input_list"
FILTER_ARN_INPUT_KEY = "filter_arn"
FILTER_VALUES_INPUT_KEY = "filter_values"
CONTEXT_INPUT_KEY = "context"
PROMOTIONS_INPUT_KEY = "promotions"
METADATA_COLUMNS_INPUT_KEY = "metadata_columns"

# Output key holding the final result; also the name of the variable the
# summarization prompt is fed with.
RESULT_OUTPUT_KEY = "result"
|
|
|
|
|
|
class AmazonPersonalizeChain(Chain):
    """Chain for retrieving recommendations from Amazon Personalize,
    and summarizing them.

    With ``return_direct=True`` the raw Amazon Personalize response is
    returned as the result and the summarization chain is not invoked;
    otherwise the response is passed to the summarization chain and its
    output is returned. It can also be used in sequential chains for
    working with the output of Amazon Personalize.

    Example:
        .. code-block:: python

            chain = AmazonPersonalizeChain.from_llm(
                llm=agent_llm, client=personalize_lg, return_direct=True
            )
            response = chain.run({'user_id':'1'})
            response = chain.run({'user_id':'1', 'item_id':'234'})
    """

    # Client used to call Amazon Personalize.
    client: AmazonPersonalize
    # Chain that turns the Personalize response into a summary.
    summarization_chain: LLMChain
    # When True, return the Personalize response itself instead of a summary.
    return_direct: bool = False
    # When True, also return the recorded steps under "intermediate_steps".
    return_intermediate_steps: bool = False
    # When True, call get_personalized_ranking instead of get_recommendations.
    is_ranking_recipe: bool = False

    @property
    def input_keys(self) -> List[str]:
        """Return an empty list: every input key is optional, so none is
        required.

        :meta private:
        """
        return []

    @property
    def output_keys(self) -> List[str]:
        """Will always return result key.

        :meta private:
        """
        return [RESULT_OUTPUT_KEY]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        client: AmazonPersonalize,
        prompt_template: PromptTemplate = SUMMARIZE_PROMPT,
        is_ranking_recipe: bool = False,
        **kwargs: Any,
    ) -> AmazonPersonalizeChain:
        """Initializes the Personalize Chain with LLMAgent, Personalize Client,
        Prompts to be used

        Args:
            llm: BaseLanguageModel: The LLM to be used in the Chain
            client: AmazonPersonalize: The client created to support
                invoking AmazonPersonalize
            prompt_template: PromptTemplate: The prompt template which can be
                invoked with the output from Amazon Personalize
            is_ranking_recipe: bool: default: False: specifies
                if the trained recipe is USER_PERSONALIZED_RANKING
            **kwargs: Forwarded unchanged to the chain constructor
                (for example ``return_direct`` or
                ``return_intermediate_steps``).

        Returns:
            A chain whose summarization step is an ``LLMChain`` built from
            ``llm`` and ``prompt_template``.

        Example:
            .. code-block:: python

                chain = AmazonPersonalizeChain.from_llm(
                    llm=agent_llm, client=personalize_lg, return_direct=True
                )
                response = chain.run({'user_id':'1'})
                response = chain.run({'user_id':'1', 'item_id':'234'})

                PROMPT_QUERY = "Summarize recommendations in {result}"
                PROMPT_TEMPLATE = PromptTemplate(
                    input_variables=["result"], template=PROMPT_QUERY
                )
                chain = AmazonPersonalizeChain.from_llm(
                    llm=agent_llm,
                    client=personalize_lg,
                    prompt_template=PROMPT_TEMPLATE,
                )
        """
        summarization_chain = LLMChain(llm=llm, prompt=prompt_template)

        return cls(
            summarization_chain=summarization_chain,
            client=client,
            is_ranking_recipe=is_ranking_recipe,
            **kwargs,
        )

    def _call(
        self,
        inputs: Mapping[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Retrieves recommendations by invoking Amazon Personalize,
        and invokes an LLM using the default/overridden
        prompt template with the output from Amazon Personalize

        Args:
            inputs: Mapping [str, Any] : Provide input identifiers in a map.
                For example - {'user_id':'1'} or
                {'user_id':'1', 'item_id':'123'}. The keys read are
                user_id, item_id, input_list, filter_arn, filter_values,
                promotions, context and metadata_columns; a key that is
                absent is passed on as None. item_id and promotions are
                only used when is_ranking_recipe is False; input_list is
                only used when it is True.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A dict with the final result under "result" (the raw Personalize
            response when return_direct is True, otherwise the summarization
            chain's output) and, when return_intermediate_steps is True, the
            recorded steps under "intermediate_steps".

        Raises:
            ValueError: If is_ranking_recipe is True and either user_id or
                input_list is missing from ``inputs``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()

        user_id = inputs.get(USER_ID_INPUT_KEY)
        item_id = inputs.get(ITEM_ID_INPUT_KEY)
        input_list = inputs.get(INPUT_LIST_INPUT_KEY)
        filter_arn = inputs.get(FILTER_ARN_INPUT_KEY)
        filter_values = inputs.get(FILTER_VALUES_INPUT_KEY)
        promotions = inputs.get(PROMOTIONS_INPUT_KEY)
        context = inputs.get(CONTEXT_INPUT_KEY)
        metadata_columns = inputs.get(METADATA_COLUMNS_INPUT_KEY)

        intermediate_steps: List = []
        # Stored as a plain string (not a one-element set literal) so the
        # recorded steps stay JSON-serializable.
        intermediate_steps.append("Calling Amazon Personalize")

        if self.is_ranking_recipe:
            # Fail fast instead of sending the literal string "None" as the
            # user id, or a missing item list, to the ranking call.
            if user_id is None:
                raise ValueError(
                    f"'{USER_ID_INPUT_KEY}' is required when is_ranking_recipe=True"
                )
            if input_list is None:
                raise ValueError(
                    f"'{INPUT_LIST_INPUT_KEY}' is required when "
                    "is_ranking_recipe=True"
                )
            response = self.client.get_personalized_ranking(
                user_id=str(user_id),
                input_list=cast(List[str], input_list),
                filter_arn=filter_arn,
                filter_values=filter_values,
                context=context,
                metadata_columns=metadata_columns,
            )
        else:
            response = self.client.get_recommendations(
                user_id=user_id,
                item_id=item_id,
                filter_arn=filter_arn,
                filter_values=filter_values,
                context=context,
                promotions=promotions,
                metadata_columns=metadata_columns,
            )

        _run_manager.on_text("Call to Amazon Personalize complete \n")

        if self.return_direct:
            final_result = response
        else:
            # The prompt variable shares its name with the output key.
            result = self.summarization_chain(
                {RESULT_OUTPUT_KEY: response}, callbacks=callbacks
            )
            final_result = result[self.summarization_chain.output_key]

        intermediate_steps.append({"context": response})
        chain_result: Dict[str, Any] = {RESULT_OUTPUT_KEY: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
        return chain_result

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "amazon_personalize_chain"
|