2023-05-11 09:38:30 +00:00
|
|
|
from langchain.agents import Tool
|
|
|
|
import pandas as pd
|
|
|
|
import streamlit as st
|
|
|
|
from streamlit_chat import message
|
|
|
|
|
|
|
|
from database import get_redis_connection
|
2023-05-12 11:53:06 +00:00
|
|
|
from assistant import (
|
|
|
|
answer_user_question,
|
|
|
|
initiate_agent,
|
|
|
|
answer_question_hyde,
|
|
|
|
ask_gpt,
|
|
|
|
)
|
2023-05-11 09:38:30 +00:00
|
|
|
|
|
|
|
# --- DATABASE ---
# Open the Redis connection; being top-level code, this runs on every
# Streamlit script run.
# NOTE(review): redis_client is not referenced anywhere else in the visible
# code; confirm whether get_redis_connection() is needed only for its side
# effects.
redis_client = get_redis_connection()


# --- CHATBOT APP: GENERAL SETTINGS ---
PAGE_TITLE: str = "Knowledge Retrieval Bot"
PAGE_ICON: str = "🤖"

# Streamlit requires set_page_config to precede any other command that
# renders output, so it comes before the page header below.
st.set_page_config(page_icon=PAGE_ICON, page_title=PAGE_TITLE)

# Page header.
st.title("Wiki Chatbot")
st.subheader("Learn things - random things!")
|
|
|
|
|
|
|
|
# Sidebar control: which retrieval strategy the "Search" tool should use.
# The chosen option label is returned as a string.
add_selectbox = st.sidebar.selectbox(
    label="What kind of search?",
    options=("Standard vector search", "HyDE"),
)
|
|
|
|
|
|
|
|
# Tools exposed to the agent.
# "Search" answers via the retrieval function matching the sidebar choice:
# answer_user_question for "Standard vector search", answer_question_hyde
# for anything else. "Ask" delegates to ask_gpt.
if add_selectbox == "Standard vector search":
    _search_func = answer_user_question
else:
    _search_func = answer_question_hyde

tools = [
    Tool(
        name="Search",
        func=_search_func,
        description=(
            "Useful for when you need to answer general knowledge questions. "
            "Input should be a fully formed question."
        ),
    ),
    Tool(
        name="Ask",
        func=ask_gpt,
        description=(
            "Useful if the question is not general knowledge. "
            "Input should be a fully formed question."
        ),
    ),
]
|
|
|
|
|
|
|
|
# Per-session chat history: "past" collects the user's prompts and
# "generated" the matching agent responses. Each list is created empty
# only the first time, so history survives Streamlit reruns.
for _history_key in ("generated", "past"):
    if _history_key not in st.session_state:
        st.session_state[_history_key] = []
|
|
|
|
|
|
|
|
|
|
|
|
def query(question):
    """Forward *question* to the assistant held in ``st.session_state["chat"]``.

    Returns whatever that object's ``ask_assistant`` method returns.

    NOTE(review): nothing in this file stores a "chat" entry in session state
    or calls this function (the Submit handler uses "agent" instead), so as
    written it would fail on a missing key; confirm whether it is dead code.
    """
    assistant = st.session_state["chat"]
    return assistant.ask_assistant(question)
|
|
|
|
|
|
|
|
|
|
|
|
# Question input and Submit handling.
prompt = st.text_input("What do you want to know: ", "", key="input")

if st.button("Submit", key="generationSubmit"):
    if not prompt.strip():
        # Avoid spending an agent/LLM call on an empty question.
        st.warning("Please enter a question first.")
    else:
        with st.spinner("Thinking..."):
            # `tools` (and so the Search function picked in the sidebar) is
            # only handed to initiate_agent when the agent is created. Caching
            # the agent unconditionally would keep the first-run search mode
            # forever, so rebuild it whenever the sidebar choice changes.
            # NOTE(review): rebuilding may discard any memory initiate_agent
            # sets up — confirm that is acceptable on a mode switch.
            if (
                "agent" not in st.session_state
                or st.session_state.get("agent_search_mode") != add_selectbox
            ):
                st.session_state["agent"] = initiate_agent(tools)
                st.session_state["agent_search_mode"] = add_selectbox

            response = st.session_state["agent"].run(prompt)

        # Record the exchange; "past" and "generated" stay index-aligned.
        st.session_state.past.append(prompt)
        st.session_state.generated.append(response)
|
|
|
|
|
|
|
|
# Render the conversation (newest exchange first) and the latest search results.
if st.session_state["generated"]:
    generated = st.session_state["generated"]
    past = st.session_state["past"]

    # Keys are derived from the exchange index so each message widget keeps
    # the same identity across reruns.
    for i in reversed(range(len(generated))):
        message(generated[i], key=str(i))
        message(past[i], is_user=True, key=str(i) + "_user")

    with st.expander("See search results"):
        # NOTE(review): results.csv is presumably written by the Search tool's
        # functions in `assistant`; it may be missing or empty if the agent has
        # only used the Ask tool so far, which previously crashed the page.
        try:
            results = list(pd.read_csv("results.csv")["result"])
        except (FileNotFoundError, pd.errors.EmptyDataError):
            st.info("No search results yet.")
        else:
            st.write(results)
|