From 790ed8be6933bbe624e19f353fb464526a6746ac Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Sat, 18 Nov 2023 14:42:22 -0800
Subject: [PATCH] update multi index templates (#13569)

---
 .../rag_multi_index_fusion/chain.py | 34 ++++++-------
 .../rag_multi_index_router/chain.py | 51 ++++++++++++-------
 2 files changed, 51 insertions(+), 34 deletions(-)

diff --git a/templates/rag-multi-index-fusion/rag_multi_index_fusion/chain.py b/templates/rag-multi-index-fusion/rag_multi_index_fusion/chain.py
index 1ffd30f1d4..a6b10dc6f2 100644
--- a/templates/rag-multi-index-fusion/rag_multi_index_fusion/chain.py
+++ b/templates/rag-multi-index-fusion/rag_multi_index_fusion/chain.py
@@ -1,3 +1,5 @@
+from operator import itemgetter
+
 import numpy as np
 from langchain.chat_models import ChatOpenAI
 from langchain.embeddings import OpenAIEmbeddings
@@ -11,7 +13,6 @@ from langchain.retrievers import (
 )
 from langchain.schema import StrOutputParser
 from langchain.schema.runnable import (
-    RunnableLambda,
     RunnableParallel,
     RunnablePassthrough,
 )
@@ -51,14 +52,6 @@ def fuse_retrieved_docs(input):
     ]
 
 
-retriever_map = {
-    "medical paper": pubmed,
-    "scientific paper": arxiv,
-    "public company finances report": sec,
-    "general": wiki,
-}
-
-
 def format_named_docs(named_docs):
     return "\n\n".join(
         f"Source: {source}\n\n{doc.page_content}" for source, doc in named_docs
@@ -83,19 +76,26 @@ class Question(BaseModel):
     __root__: str
 
 
+answer_chain = (
+    {
+        "question": itemgetter("question"),
+        "sources": lambda x: format_named_docs(x["sources"]),
+    }
+    | prompt
+    | ChatOpenAI(model="gpt-3.5-turbo-1106")
+    | StrOutputParser()
+).with_config(run_name="answer")
 chain = (
     (
         RunnableParallel(
             {"question": RunnablePassthrough(), "sources": retrieve_all}
         ).with_config(run_name="add_sources")
-        | RunnablePassthrough.assign(
-            sources=(
-                RunnableLambda(fuse_retrieved_docs) | format_named_docs
-            ).with_config(run_name="fuse_and_format")
-        ).with_config(run_name="update_sources")
-        | prompt
-        | ChatOpenAI(model="gpt-3.5-turbo-1106")
-        | StrOutputParser()
+        | RunnablePassthrough.assign(sources=fuse_retrieved_docs).with_config(
+            run_name="fuse"
+        )
+        | RunnablePassthrough.assign(answer=answer_chain).with_config(
+            run_name="add_answer"
+        )
     )
     .with_config(run_name="QA with fused results")
     .with_types(input_type=Question)
diff --git a/templates/rag-multi-index-router/rag_multi_index_router/chain.py b/templates/rag-multi-index-router/rag_multi_index_router/chain.py
index 2882e2ec21..cbdb6a3e2d 100644
--- a/templates/rag-multi-index-router/rag_multi_index_router/chain.py
+++ b/templates/rag-multi-index-router/rag_multi_index_router/chain.py
@@ -28,7 +28,7 @@ wiki = WikipediaRetriever(top_k_results=5, doc_content_chars_max=2000).with_conf
     run_name="wiki"
 )
 
-llm = ChatOpenAI(model="gpt-3.5-turbo-1106")
+llm = ChatOpenAI(model="gpt-3.5-turbo")
 
 
 class Search(BaseModel):
@@ -45,18 +45,29 @@ class Search(BaseModel):
     )
 
 
-classifier = llm.bind(
-    functions=[convert_pydantic_to_openai_function(Search)],
-    function_call={"name": "Search"},
-) | PydanticAttrOutputFunctionsParser(
-    pydantic_schema=Search, attr_name="question_resource"
+retriever_name = {
+    "medical paper": "PubMed",
+    "scientific paper": "ArXiv",
+    "public company finances report": "SEC filings (Kay AI)",
+    "general": "Wikipedia",
+}
+
+classifier = (
+    llm.bind(
+        functions=[convert_pydantic_to_openai_function(Search)],
+        function_call={"name": "Search"},
+    )
+    | PydanticAttrOutputFunctionsParser(
+        pydantic_schema=Search, attr_name="question_resource"
+    )
+    | retriever_name.get
 )
 
 retriever_map = {
-    "medical paper": pubmed,
-    "scientific paper": arxiv,
-    "public company finances report": sec,
-    "general": wiki,
+    "PubMed": pubmed,
+    "ArXiv": arxiv,
+    "SEC filings (Kay AI)": sec,
+    "Wikipedia": wiki,
 }
 
 router_retriever = RouterRunnable(runnables=retriever_map)
@@ -79,17 +90,23 @@ class Question(BaseModel):
     __root__: str
 
 
+retriever_chain = (
+    {"input": itemgetter("question"), "key": itemgetter("retriever_choice")}
+    | router_retriever
+    | format_docs
+).with_config(run_name="retrieve")
+answer_chain = (
+    {"sources": retriever_chain, "question": itemgetter("question")}
+    | prompt
+    | llm
+    | StrOutputParser()
+)
 chain = (
     (
         RunnableParallel(
-            {"input": RunnablePassthrough(), "key": classifier}
+            question=RunnablePassthrough(), retriever_choice=classifier
         ).with_config(run_name="classify")
-        | RunnableParallel(
-            {"question": itemgetter("input"), "sources": router_retriever | format_docs}
-        ).with_config(run_name="retrieve")
-        | prompt
-        | llm
-        | StrOutputParser()
+        | RunnablePassthrough.assign(answer=answer_chain).with_config(run_name="answer")
     )
     .with_config(run_name="QA with router")
     .with_types(input_type=Question)