From 205be538a33cc43fcf101a969743cd12ec190d99 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 12 Feb 2023 17:58:54 +0000 Subject: [PATCH] fix dbqa, with new chain type, also fix for doc export --- application/app.py | 19 +++++++++++++++---- scripts/parser/open_ai_func.py | 11 +++++++++-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/application/app.py b/application/app.py index aa9089ed..c114c63b 100644 --- a/application/app.py +++ b/application/app.py @@ -5,8 +5,8 @@ import datetime from flask import Flask, request, render_template # os.environ["LANGCHAIN_HANDLER"] = "langchain" import faiss -from langchain import OpenAI -from langchain.chains import VectorDBQAWithSourcesChain +from langchain import OpenAI, VectorDBQA +from langchain.chains.question_answering import load_qa_chain from langchain.prompts import PromptTemplate import requests @@ -69,11 +69,22 @@ def api_answer(): c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template) # create a chain with the prompt template and the store - chain = VectorDBQAWithSourcesChain.from_llm(llm=OpenAI(openai_api_key=api_key, temperature=0), vectorstore=store, combine_prompt=c_prompt) + #chain = VectorDBQA.from_llm(llm=OpenAI(openai_api_key=api_key, temperature=0), vectorstore=store, combine_prompt=c_prompt) + # chain = VectorDBQA.from_chain_type(llm=OpenAI(openai_api_key=api_key, temperature=0), chain_type='map_reduce', + # vectorstore=store) + + qa_chain = load_qa_chain(OpenAI(openai_api_key=api_key, temperature=0), chain_type="map_reduce", + combine_prompt=c_prompt) + chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=store) + + + # fetch the answer - result = chain({"question": question}) + result = chain({"query": question}) + print(result) # some formatting for the frontend + result['answer'] = result['result'] result['answer'] = result['answer'].replace("\\n", "
") result['answer'] = result['answer'].replace("SOURCES:", "") # mock result diff --git a/scripts/parser/open_ai_func.py b/scripts/parser/open_ai_func.py index 00c57be9..cbd947ee 100644 --- a/scripts/parser/open_ai_func.py +++ b/scripts/parser/open_ai_func.py @@ -31,13 +31,20 @@ def call_openai_api(docs): print("Error on ", i) print("Saving progress") print(f"stopped at {c1} out of {len(docs)}") - store.save_local("outputs") + faiss.write_index(store.index, "docs.index") + store.index = None + with open("faiss_store.pkl", "wb") as f: + pickle.dump(store, f) print("Sleeping for 10 seconds and trying again") time.sleep(10) store.add_texts([i.page_content], metadatas=[i.metadata]) c1 += 1 - store.save_local("outputs") + + faiss.write_index(store.index, "docs.index") + store.index = None + with open("faiss_store.pkl", "wb") as f: + pickle.dump(store, f) def get_user_permission(docs): # Function to ask user permission to call the OpenAI api and spend their OpenAI funds.