langchain/templates/rag-aws-bedrock/rag_aws_bedrock/chain.py
Piyush Jain 5545de0466
Updated the Bedrock rag template (#12462)
Updates the bedrock rag template.
- Removes pinecone and replaces with FAISS as the vector store
- Fixes the environment variables, setting defaults
- Adds a `main.py` test file for quick sanity testing
- Updates README.md with correct instructions
2023-10-27 17:02:28 -07:00

49 lines
1.3 KiB
Python

import os
from langchain.embeddings import BedrockEmbeddings
from langchain.llms.bedrock import Bedrock
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import FAISS
# Get region and profile from env.
# Both values are shared by every Bedrock client created below, so the LLM
# and the embeddings model resolve the same AWS region and credentials.
# NOTE(review): with AWS_PROFILE unset this forces a profile literally named
# "default"; if only non-profile credentials exist (e.g. an instance role),
# boto3 may reject it -- consider falling back to None instead. Confirm.
region = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
profile = os.environ.get("AWS_PROFILE", "default")

# Set LLM and embeddings.
# Claude v2 on Bedrock; max_tokens_to_sample bounds how many tokens the
# model generates per call.
model = Bedrock(
    model_id="anthropic.claude-v2",
    region_name=region,
    credentials_profile_name=profile,
    model_kwargs={"max_tokens_to_sample": 200},
)
# Titan text embeddings. Give it the same region/profile as the LLM so it
# does not silently fall back to a different boto3 region/credential chain.
bedrock_embeddings = BedrockEmbeddings(
    model_id="amazon.titan-embed-text-v1",
    region_name=region,
    credentials_profile_name=profile,
)
# Add to vectorDB: build a FAISS vector store from a single sample text,
# embedding it with the Bedrock embeddings client defined above.
vectorstore = FAISS.from_texts(
    ["harrison worked at kensho"],
    embedding=bedrock_embeddings,
)
# Get retriever from vectorstore (created exactly once).
retriever = vectorstore.as_retriever()
# RAG prompt. The chain below fills {context} from the retriever's output
# and {question} from the caller's original input; the model is told to
# answer using only that context.
template = (
    "Answer the question based only on the following context:\n"
    "{context}\n"
    "Question: {question}\n"
)
prompt = ChatPromptTemplate.from_template(template)
# RAG pipeline (LCEL):
#   1. RunnableParallel fans the input out: "context" is produced by the
#      retriever, while "question" is forwarded unchanged by the passthrough.
#   2. The resulting dict fills the prompt template.
#   3. The formatted prompt is sent to the Bedrock LLM.
#   4. StrOutputParser turns the model output into a plain string.
chain = (
    RunnableParallel(
        {
            "context": retriever,
            "question": RunnablePassthrough(),
        }
    )
    | prompt
    | model
    | StrOutputParser()
)