langchain/templates/rewrite-retrieve-read/rewrite_retrieve_read/chain.py

43 lines
1.1 KiB
Python
Raw Normal View History

from langchain.chat_models import ChatOpenAI
2023-10-27 02:44:30 +00:00
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
2023-10-27 02:44:30 +00:00
from langchain.schema.runnable import RunnablePassthrough
from langchain.utilities import DuckDuckGoSearchAPIWrapper
# Final "read" step: the chat model must answer using only the text that
# the web-search retriever places in {context}.
template = (
    "Answer the users question based only on the following context:\n"
    "<context>\n"
    "{context}\n"
    "</context>\n"
    "Question: {question}\n"
)
prompt = ChatPromptTemplate.from_template(template)

# Deterministic (temperature 0) chat model used to produce the final answer.
model = ChatOpenAI(temperature=0)

# Web-search backend queried by `retriever` below.
search = DuckDuckGoSearchAPIWrapper()
def retriever(query):
    """Run a DuckDuckGo web search for *query* via the module-level `search`.

    The returned value is whatever ``search.run`` yields; the chain feeds it
    into the answer prompt as ``{context}``.
    """
    results = search.run(query)
    return results
# "Rewrite" step prompt: ask the model to turn the user's question ({x})
# into a better web-search query, terminated by the `**` marker that
# `_parse` strips off afterwards.
template = (
    "Provide a better search query for "
    "web search engine to answer the given question, end "
    "the queries with **. Question: "
    "{x} Answer:"
)
rewrite_prompt = ChatPromptTemplate.from_template(template)
# Parser to remove the `**`
def _parse(text):
return text.strip("**")
# Query-rewriting sub-chain: fill the rewrite prompt, ask a separate
# temperature-0 chat model for a better query, take its text output, and
# strip the `**` end marker with `_parse`.
rewriter = (
    rewrite_prompt
    | ChatOpenAI(temperature=0)
    | StrOutputParser()
    | _parse
)
# Rewrite -> retrieve -> read pipeline. The chain input is passed through
# unchanged to both branches (presumably the user's question as a plain
# string — confirm against callers):
#   * "context":  wrapped as {"x": input} for the rewrite prompt, rewritten
#                 into a search query, then sent to the web retriever;
#   * "question": forwarded as-is into the answer prompt.
chain = (
    {
        "context": {"x": RunnablePassthrough()} | rewriter | retriever,
        "question": RunnablePassthrough(),
    }
    | prompt
    | model
    | StrOutputParser()
)