# langchain/templates/rag-supabase/rag_supabase/chain.py
import os
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores.supabase import SupabaseVectorStore
from supabase.client import create_client
# --- Supabase vector store setup ---------------------------------------
# Both variables must be set before import; fail fast with a clear message
# instead of passing None into create_client (which raises an opaque error).
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
if supabase_url is None or supabase_key is None:
    raise EnvironmentError(
        "SUPABASE_URL and SUPABASE_SERVICE_KEY environment variables must "
        "be set to use the rag-supabase template."
    )

supabase = create_client(supabase_url, supabase_key)

# NOTE(review): embeddings here must match the model used when the
# "documents" table was populated — verify against the ingestion script.
embeddings = OpenAIEmbeddings()

vectorstore = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    # Postgres function performing the similarity search; presumably created
    # by the template's setup SQL — confirm it exists in the target project.
    query_name="match_documents",
)
retriever = vectorstore.as_retriever()
# Prompt that restricts the model to answering from the retrieved context.
template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
model = ChatOpenAI()

# Fan out the incoming question: one branch fetches context via the
# retriever, the other passes the raw question through unchanged. The
# resulting dict feeds the prompt, then the chat model, then a parser
# that extracts the plain string answer.
_inputs = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
)
chain = _inputs | prompt | model | StrOutputParser()
class Question(BaseModel):
    """Typed input schema: the chain accepts a bare question string."""

    __root__: str


# Attach the input type so API docs / the playground render a proper schema.
chain = chain.with_types(input_type=Question)