langchain/templates/rewrite-retrieve-read/rewrite_retrieve_read/chain.py

60 lines
1.3 KiB
Python
Raw Normal View History

from langchain.chat_models import ChatOpenAI
2023-10-27 02:44:30 +00:00
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema.output_parser import StrOutputParser
2023-10-27 02:44:30 +00:00
from langchain.schema.runnable import RunnablePassthrough
from langchain.utilities import DuckDuckGoSearchAPIWrapper
# Prompt for the final "read" step: answer using only the web-search results
# injected as {context}. Named `answer_template` (not `template`) so it is not
# confused with the rewrite-step prompt that this module later binds to the
# module-level name `template`.
answer_template = """Answer the users question based only on the following context:
<context>
{context}
</context>
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(answer_template)
# Answering model; temperature=0 to minimise sampling randomness.
model = ChatOpenAI(temperature=0)
# Web search backend used by `retriever`.
search = DuckDuckGoSearchAPIWrapper()
def retriever(query):
    """Search the web for *query* with the module-level DuckDuckGo wrapper.

    Returns exactly what ``search.run`` returns (in LangChain's
    ``DuckDuckGoSearchAPIWrapper`` this is a single string of result
    snippets); the chain feeds it into the answer prompt as ``{context}``.
    """
    results = search.run(query)
    return results
# Rewrite step: ask the LLM to turn the user's question (bound to {x}) into a
# better web-search query, terminated by a `**` marker that `_parse` removes.
template = (
    "Provide a better search query for "
    "web search engine to answer the given question, end "
    "the queries with **. Question: "
    "{x} Answer:"
)
rewrite_prompt = ChatPromptTemplate.from_template(template)
# Parser to remove the `**`
2023-10-29 22:50:09 +00:00
def _parse(text):
return text.strip("**")
2023-10-29 22:50:09 +00:00
# Question -> rewritten search query: fill the rewrite prompt, call a
# dedicated ChatOpenAI instance (separate from `model`), take its reply as
# plain text, and drop the `**` marker via `_parse`.
rewriter = (
    rewrite_prompt
    | ChatOpenAI(temperature=0)
    | StrOutputParser()
    | _parse
)
# Rewrite -> Retrieve -> Read:
#   * "context": the raw question is wrapped as {"x": question}, rewritten
#     into a search query by `rewriter`, then searched via `retriever`;
#   * "question": the original input, passed through unchanged.
# Both fill the answer prompt; the model's reply is returned as a string.
chain = (
    dict(
        context=dict(x=RunnablePassthrough()) | rewriter | retriever,
        question=RunnablePassthrough(),
    )
    | prompt
    | model
    | StrOutputParser()
)
# Add input type for playground: the chain's input is one bare string.
class Question(BaseModel):
    # Pydantic v1 custom root type: the model is a plain `str` (the user's
    # question) rather than an object with named fields.
    __root__: str


# Attach the explicit input schema; `with_types` yields a new runnable, which
# is rebound to `chain`.
chain = chain.with_types(input_type=Question)