Merge pull request #241 from arc53/feature/history

Feature/history
Alex committed via GitHub, 1 year ago
commit 1800e51b19

@@ -23,6 +23,7 @@ from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
AIMessagePromptTemplate,
)
from pymongo import MongoClient
from werkzeug.utils import secure_filename
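
Review note: the only addition in this hunk is the AIMessagePromptTemplate import; it is used further down to replay earlier model answers as assistant messages in the combine prompt. A minimal sketch of where it fits, assuming the langchain chat prompt classes imported above (the literal strings are placeholder turns, not values from this PR):

from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    AIMessagePromptTemplate,
)

# Alternate human/assistant templates for past turns, then end with the live question.
messages = [
    SystemMessagePromptTemplate.from_template("You are a helpful assistant."),
    HumanMessagePromptTemplate.from_template("What is DocsGPT?"),          # earlier user turn
    AIMessagePromptTemplate.from_template("It answers questions about documentation."),  # earlier model turn
    HumanMessagePromptTemplate.from_template("{question}"),                # current question
]
chat_prompt = ChatPromptTemplate.from_messages(messages)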
@@ -107,6 +108,8 @@ def run_async_chain(chain, question, chat_history):
result["answer"] = answer
return result
@celery.task(bind=True)
def ingest(self, directory, formats, name_job, filename, user):
@@ -164,16 +167,6 @@ def api_answer():
docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))
# create a prompt template
if history:
history = json.loads(history)
template_temp = template_hist.replace("{historyquestion}", history[0]).replace("{historyanswer}",
history[1])
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template_temp,
template_format="jinja2")
else:
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template,
template_format="jinja2")
q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
template_format="jinja2")
if settings.LLM_NAME == "openai_chat":
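
Review note: the deleted block above was the old history mechanism: a single [question, answer] pair was string-substituted into template_hist before building a per-request jinja2 PromptTemplate. A sketch of that removed pattern, with template_hist and the history value filled in as assumptions (the real template lives elsewhere in the repo):

from langchain.prompts import PromptTemplate

# Hypothetical stand-ins for template_hist and the incoming history pair.
template_hist = (
    "Previous question: {historyquestion}\n"
    "Previous answer: {historyanswer}\n"
    "Summaries: {{ summaries }}\nQuestion: {{ question }}"
)
history = ["What is DocsGPT?", "It answers questions about documentation."]

# Plain str.replace, because the jinja2 template only knows summaries/question.
template_temp = (template_hist
                 .replace("{historyquestion}", history[0])
                 .replace("{historyanswer}", history[1]))
c_prompt = PromptTemplate(input_variables=["summaries", "question"],
                          template=template_temp, template_format="jinja2")

The hunk below replaces this with real chat messages, so history is no longer limited to a single turn.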
@@ -182,6 +175,18 @@ def api_answer():
SystemMessagePromptTemplate.from_template(chat_combine_template),
HumanMessagePromptTemplate.from_template("{question}")
]
if history:
tokens_current_history = 0
tokens_max_history = 500
#count tokens in history
for i in history:
if "prompt" in i and "response" in i:
tokens_batch = llm.get_num_tokens(i["prompt"]) + llm.get_num_tokens(i["response"])
if tokens_current_history + tokens_batch < tokens_max_history:
tokens_current_history += tokens_batch
messages_combine.append(HumanMessagePromptTemplate.from_template(i["prompt"]))
messages_combine.append(AIMessagePromptTemplate.from_template(i["response"]))
p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
elif settings.LLM_NAME == "openai":
llm = OpenAI(openai_api_key=api_key, temperature=0)
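
Review note: the new history handling caps replayed conversation at roughly 500 tokens. Each stored turn is a dict with prompt and response keys, its size is measured with llm.get_num_tokens, and a pair is appended as human/assistant message templates only if it still fits the budget (oversized pairs are skipped, not truncated). A self-contained sketch of the same budgeting, using tiktoken as an assumed stand-in for the model's token counter:

import tiktoken

ENC = tiktoken.get_encoding("cl100k_base")  # stand-in for llm.get_num_tokens

def count_tokens(text: str) -> int:
    return len(ENC.encode(text))

def trim_history(history, max_tokens=500):
    # Keep only the prompt/response pairs that fit the running token budget;
    # like the loop above, pairs that would overflow are skipped rather than breaking.
    kept, total = [], 0
    for turn in history:
        if "prompt" in turn and "response" in turn:
            batch = count_tokens(turn["prompt"]) + count_tokens(turn["response"])
            if total + batch < max_tokens:
                total += batch
                kept.append(turn)
    return kept

history = [
    {"prompt": "What is DocsGPT?", "response": "It answers questions about documentation."},
    {"prompt": "Does it keep chat history?", "response": "Yes, within a token budget."},
]
print(trim_history(history))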
@@ -208,7 +213,7 @@ def api_answer():
result = run_async_chain(chain, question, chat_history)
else:
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
combine_prompt=c_prompt, question_prompt=q_prompt)
combine_prompt=chat_combine_template, question_prompt=q_prompt)
chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=3)
result = chain({"query": question})
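
Review note: on the non-chat path the map_reduce QA chain now receives chat_combine_template as its combine prompt in place of the deleted c_prompt. load_qa_chain's map_reduce variant takes prompt template objects for both prompts, so the sketch below builds explicit templates; keys, the vector store path, and the prompt strings are placeholders rather than the repository's actual values:

from langchain.llms import OpenAI
from langchain.embeddings import CohereEmbeddings
from langchain.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.chains import VectorDBQA

docsearch = FAISS.load_local("vectorstore_dir", CohereEmbeddings(cohere_api_key="COHERE_KEY"))
combine_prompt = PromptTemplate(input_variables=["summaries", "question"],
                                template="{{ summaries }}\n\nQuestion: {{ question }}",
                                template_format="jinja2")
question_prompt = PromptTemplate(input_variables=["context", "question"],
                                 template="{{ context }}\n\nQuestion: {{ question }}",
                                 template_format="jinja2")

llm = OpenAI(openai_api_key="OPENAI_KEY", temperature=0)
# The map step answers per retrieved chunk; the reduce step combines those answers.
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
                         combine_prompt=combine_prompt, question_prompt=question_prompt)
chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=3)
result = chain({"query": "What is DocsGPT?"})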

@@ -8,7 +8,7 @@ services:
ports:
- "5173:5173"
depends_on:
- backend
- backend
backend:
build: ./application

@@ -7,6 +7,7 @@ export function fetchAnswerApi(
question: string,
apiKey: string,
selectedDocs: Doc,
history: Array<any> = [],
): Promise<Answer> {
let namePath = selectedDocs.name;
if (selectedDocs.language === namePath) {
@@ -37,7 +38,7 @@ export function fetchAnswerApi(
question: question,
api_key: apiKey,
embeddings_key: apiKey,
history: localStorage.getItem('chatHistory'),
history: history,
active_docs: docPath,
}),
})

@@ -19,6 +19,7 @@ export const fetchAnswer = createAsyncThunk<
question,
state.preference.apiKey,
state.preference.selectedDocs!,
state.conversation.queries,
);
return answer;
});
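
Review note: with this change the thunk forwards state.conversation.queries into fetchAnswerApi, so the request body carries the conversation as structured JSON instead of the raw localStorage string it sent before. On the backend that should arrive as a list of objects with prompt/response keys, matching the token-counting loop added in the Python hunk above. A hypothetical Flask-side sketch of parsing that payload (the route name and response shape are assumptions, not code from this diff):

from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route("/api/answer", methods=["POST"])
def api_answer():
    data = request.get_json()
    question = data["question"]
    # history arrives as [{"prompt": "...", "response": "..."}, ...]
    history = data.get("history") or []
    turns = [t for t in history if "prompt" in t and "response" in t]
    return jsonify({"answer": f"received {len(turns)} prior turns for: {question}"})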
