Merge pull request #163 from arc53/chat-prompts

chat prompts
This commit is contained in:
Pavel 2023-03-08 22:07:08 +04:00 committed by GitHub
commit 377070e3a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 53 additions and 18 deletions

View File

@ -1,25 +1,31 @@
import os
import json
import os
import traceback
import dotenv
import requests
from flask import Flask, request, render_template
from langchain import FAISS
from langchain.llms import OpenAIChat
from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings, CohereEmbeddings, \
HuggingFaceInstructEmbeddings
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from error import bad_request
os.environ["LANGCHAIN_HANDLER"] = "langchain"
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
if os.getenv("LLM_NAME") is not None:
llm_choice = os.getenv("LLM_NAME")
else:
llm_choice = "openai"
llm_choice = "openai_chat"
if os.getenv("EMBEDDINGS_NAME") is not None:
embeddings_choice = os.getenv("EMBEDDINGS_NAME")
@ -47,15 +53,21 @@ if platform.system() == "Windows":
# loading the .env file
dotenv.load_dotenv()
with open("combine_prompt.txt", "r") as f:
with open("prompts/combine_prompt.txt", "r") as f:
template = f.read()
with open("combine_prompt_hist.txt", "r") as f:
with open("prompts/combine_prompt_hist.txt", "r") as f:
template_hist = f.read()
with open("question_prompt.txt", "r") as f:
with open("prompts/question_prompt.txt", "r") as f:
template_quest = f.read()
with open("prompts/chat_combine_prompt.txt", "r") as f:
chat_combine_template = f.read()
with open("prompts/chat_reduce_prompt.txt", "r") as f:
chat_reduce_template = f.read()
if os.getenv("API_KEY") is not None:
api_key_set = True
else:
@ -98,7 +110,7 @@ def api_answer():
vectorstore = ""
else:
vectorstore = ""
#vectorstore = "outputs/inputs/"
# vectorstore = "outputs/inputs/"
# loading the index and the store and the prompt template
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
if embeddings_choice == "openai_text-embedding-ada-002":
@ -123,9 +135,20 @@ def api_answer():
q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
template_format="jinja2")
if llm_choice == "openai":
llm = OpenAIChat(openai_api_key=api_key, temperature=0)
#llm = OpenAI(openai_api_key=api_key, temperature=0)
if llm_choice == "openai_chat":
llm = ChatOpenAI(openai_api_key=api_key)
messages_combine = [
SystemMessagePromptTemplate.from_template(chat_combine_template),
HumanMessagePromptTemplate.from_template("{question}")
]
p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
messages_reduce = [
SystemMessagePromptTemplate.from_template(chat_reduce_template),
HumanMessagePromptTemplate.from_template("{question}")
]
p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
elif llm_choice == "openai":
llm = OpenAI(openai_api_key=api_key, temperature=0)
elif llm_choice == "manifest":
llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048})
elif llm_choice == "huggingface":
@ -133,13 +156,19 @@ def api_answer():
elif llm_choice == "cohere":
llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
combine_prompt=c_prompt, question_prompt=q_prompt)
if llm_choice == "openai_chat":
chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
k=4,
chain_type_kwargs={"question_prompt": p_chat_reduce,
"combine_prompt": p_chat_combine})
result = chain({"query": question})
else:
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
combine_prompt=c_prompt, question_prompt=q_prompt)
chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=4)
result = chain({"query": question})
chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=4)
# fetch the answer
result = chain({"query": question})
print(result)
# some formatting for the frontend
result['answer'] = result['result']
@ -215,7 +244,6 @@ def api_feedback():
return {"status": 'ok'}
# handling CORS
@app.after_request
def after_request(response):

View File

@ -0,0 +1,4 @@
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples if possible.
Use the following pieces of context to help answer the users question.
----------------
{summaries}

View File

@ -0,0 +1,3 @@
Use the following portion of a long document to see if any of the text is relevant to answer the question.
{context}
Provide all relevant text to the question verbatim. Summarize if needed. If nothing relevant return "-".