diff --git a/application/app.py b/application/app.py
index 1a024b3..407e26b 100644
--- a/application/app.py
+++ b/application/app.py
@@ -1,25 +1,31 @@
-import os
 import json
+import os
 import traceback
 
 import dotenv
 import requests
 from flask import Flask, request, render_template
 from langchain import FAISS
-from langchain.llms import OpenAIChat
 from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
 from langchain.chains.question_answering import load_qa_chain
+from langchain.chat_models import ChatOpenAI
 from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings, CohereEmbeddings, \
     HuggingFaceInstructEmbeddings
 from langchain.prompts import PromptTemplate
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+
 from error import bad_request
 
-os.environ["LANGCHAIN_HANDLER"] = "langchain"
+# os.environ["LANGCHAIN_HANDLER"] = "langchain"
 
 if os.getenv("LLM_NAME") is not None:
     llm_choice = os.getenv("LLM_NAME")
 else:
-    llm_choice = "openai"
+    llm_choice = "openai_chat"
 
 if os.getenv("EMBEDDINGS_NAME") is not None:
     embeddings_choice = os.getenv("EMBEDDINGS_NAME")
@@ -47,15 +53,21 @@ if platform.system() == "Windows":
 # loading the .env file
 dotenv.load_dotenv()
 
-with open("combine_prompt.txt", "r") as f:
+with open("prompts/combine_prompt.txt", "r") as f:
     template = f.read()
 
-with open("combine_prompt_hist.txt", "r") as f:
+with open("prompts/combine_prompt_hist.txt", "r") as f:
     template_hist = f.read()
 
-with open("question_prompt.txt", "r") as f:
+with open("prompts/question_prompt.txt", "r") as f:
     template_quest = f.read()
 
+with open("prompts/chat_combine_prompt.txt", "r") as f:
+    chat_combine_template = f.read()
+
+with open("prompts/chat_reduce_prompt.txt", "r") as f:
+    chat_reduce_template = f.read()
+
 if os.getenv("API_KEY") is not None:
     api_key_set = True
 else:
@@ -98,7 +110,7 @@ def api_answer():
         vectorstore = ""
     else:
         vectorstore = ""
-    #vectorstore = "outputs/inputs/"
+    # vectorstore = "outputs/inputs/"
     # loading the index and the store and the prompt template
     # Note if you have used other embeddings than OpenAI, you need to change the embeddings
     if embeddings_choice == "openai_text-embedding-ada-002":
@@ -123,9 +135,20 @@
     q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
                               template_format="jinja2")
 
-    if llm_choice == "openai":
-        llm = OpenAIChat(openai_api_key=api_key, temperature=0)
-        #llm = OpenAI(openai_api_key=api_key, temperature=0)
+    if llm_choice == "openai_chat":
+        llm = ChatOpenAI(openai_api_key=api_key)
+        messages_combine = [
+            SystemMessagePromptTemplate.from_template(chat_combine_template),
+            HumanMessagePromptTemplate.from_template("{question}")
+        ]
+        p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
+        messages_reduce = [
+            SystemMessagePromptTemplate.from_template(chat_reduce_template),
+            HumanMessagePromptTemplate.from_template("{question}")
+        ]
+        p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
+    elif llm_choice == "openai":
+        llm = OpenAI(openai_api_key=api_key, temperature=0)
     elif llm_choice == "manifest":
         llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048})
     elif llm_choice == "huggingface":
@@ -133,13 +156,19 @@
     elif llm_choice == "cohere":
         llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
 
-    qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
-                             combine_prompt=c_prompt, question_prompt=q_prompt)
-
-    chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=4)
+    if llm_choice == "openai_chat":
+        chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
+                                           k=4,
+                                           chain_type_kwargs={"question_prompt": p_chat_reduce,
+                                                              "combine_prompt": p_chat_combine})
+        result = chain({"query": question})
+    else:
+        qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
+                                 combine_prompt=c_prompt, question_prompt=q_prompt)
+        chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=4)
+        result = chain({"query": question})
 
-    # fetch the answer
-    result = chain({"query": question})
+    print(result)
 
     # some formatting for the frontend
     result['answer'] = result['result']
@@ -215,7 +244,6 @@ def api_feedback():
 
     return {"status": 'ok'}
 
-
 # handling CORS
 @app.after_request
 def after_request(response):
diff --git a/application/prompts/chat_combine_prompt.txt b/application/prompts/chat_combine_prompt.txt
new file mode 100644
index 0000000..608d2a8
--- /dev/null
+++ b/application/prompts/chat_combine_prompt.txt
@@ -0,0 +1,4 @@
+You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples if possible.
+Use the following pieces of context to help answer the users question.
+----------------
+{summaries}
\ No newline at end of file
diff --git a/application/prompts/chat_reduce_prompt.txt b/application/prompts/chat_reduce_prompt.txt
new file mode 100644
index 0000000..04a673d
--- /dev/null
+++ b/application/prompts/chat_reduce_prompt.txt
@@ -0,0 +1,3 @@
+Use the following portion of a long document to see if any of the text is relevant to answer the question.
+{context}
+Provide all relevant text to the question verbatim. Summarize if needed. If nothing relevant return "-".
\ No newline at end of file
diff --git a/application/combine_prompt.txt b/application/prompts/combine_prompt.txt
similarity index 100%
rename from application/combine_prompt.txt
rename to application/prompts/combine_prompt.txt
diff --git a/application/combine_prompt_hist.txt b/application/prompts/combine_prompt_hist.txt
similarity index 100%
rename from application/combine_prompt_hist.txt
rename to application/prompts/combine_prompt_hist.txt
diff --git a/application/question_prompt.txt b/application/prompts/question_prompt.txt
similarity index 100%
rename from application/question_prompt.txt
rename to application/prompts/question_prompt.txt
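For reviewers, a minimal standalone sketch of the new openai_chat path follows. It assumes the langchain release this PR targets and a FAISS index previously persisted with FAISS.save_local; the index path, API key, and question are illustrative placeholders, not values taken from this PR.

# Sketch of the openai_chat branch added in this PR, under the assumptions above.
from langchain import FAISS, VectorDBQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

# The system prompts now live in the prompts/ folder introduced by this PR.
with open("prompts/chat_combine_prompt.txt") as f:
    chat_combine_template = f.read()
with open("prompts/chat_reduce_prompt.txt") as f:
    chat_reduce_template = f.read()

# map_reduce applies the reduce prompt to each retrieved chunk ({context}),
# then the combine prompt to the per-chunk outputs ({summaries}); both stages
# also receive the raw question as the human message.
p_chat_reduce = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(chat_reduce_template),
    HumanMessagePromptTemplate.from_template("{question}"),
])
p_chat_combine = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(chat_combine_template),
    HumanMessagePromptTemplate.from_template("{question}"),
])

docsearch = FAISS.load_local("outputs/inputs/", OpenAIEmbeddings())  # placeholder index path
chain = VectorDBQA.from_chain_type(
    llm=ChatOpenAI(openai_api_key="sk-..."),  # placeholder key
    chain_type="map_reduce",
    vectorstore=docsearch,
    k=4,
    chain_type_kwargs={"question_prompt": p_chat_reduce,
                       "combine_prompt": p_chat_combine},
)
result = chain({"query": "How do I run the app?"})
print(result["result"])

Using VectorDBQA.from_chain_type with chain_type_kwargs lets the chat prompts reach the underlying map_reduce chain without building it via load_qa_chain, which is why the openai_chat branch in the diff skips the explicit qa_chain construction that the other LLM choices still use.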