# Mirror of https://github.com/openai/openai-cookbook
# Synced 2024-11-19 15:25:37 +00:00 (84 lines, 2.3 KiB, Python)
import streamlit as st
|
|
from streamlit_chat import message
|
|
|
|
from database import get_redis_connection
|
|
from chatbot import RetrievalAssistant, Message
|
|
|
|
# Database initialisation

# Open the Redis connection once, at module level, when the script runs.
# NOTE(review): nothing else in this script reads `redis_client`; presumably
# it is kept for the connection side effect or for other modules -- confirm.
redis_client = get_redis_connection()
|
|
|
|
# Set instruction
|
|
|
|
# System prompt requiring Question and Year to be extracted from the user
|
|
system_prompt = '''
|
|
You are a helpful Formula 1 knowledge base assistant. You need to capture a Question and Year from each customer.
|
|
The Question is their query on Formula 1, and the Year is the year of the applicable Formula 1 season.
|
|
Think about this step by step:
|
|
- The user will ask a Question
|
|
- You will ask them for the Year if their question didn't include a Year
|
|
- Once you have the Year, say "searching for answers".
|
|
|
|
Example:
|
|
|
|
User: I'd like to know the cost cap for a power unit
|
|
|
|
Assistant: Certainly, what year would you like this for?
|
|
|
|
User: 2023 please.
|
|
|
|
Assistant: Searching for answers.
|
|
'''
|
|
|
|
### CHATBOT APP

# Page title/icon, followed by the visible heading and sub-heading.
st.set_page_config(page_title="Streamlit Chat - Demo", page_icon=":robot:")

st.title('Formula 1 Chatbot')
st.subheader("Help us help you learn about Formula 1")

# Seed the two transcript lists if the session does not have them yet:
# 'generated' collects assistant replies, 'past' collects user prompts.
for _history_key in ('generated', 'past'):
    if _history_key not in st.session_state:
        st.session_state[_history_key] = []
|
|
|
|
def query(question):
    """Pass *question* to the session's chat assistant and return its reply.

    *question* is handed unchanged to ``ask_assistant`` on the object stored
    under ``st.session_state['chat']``; the Submit handler in this script
    passes a list of message payloads and stores the assistant before calling.
    """
    assistant = st.session_state['chat']
    return assistant.ask_assistant(question)
|
|
|
|
# Free-text question box; `prompt` holds whatever the user has typed.
prompt = st.text_input("What do you want to know: ", key="input")

if st.button('Submit', key='generationSubmit'):
    if not prompt.strip():
        # Nothing to ask: skip the assistant call so that no blank turn is
        # sent or recorded in the transcript.
        st.warning("Please enter a question before submitting.")
    else:
        messages = []

        # First submit of the session: create the assistant and lead with the
        # system prompt. Later submits send only the new user turn.
        # NOTE(review): presumably the assistant keeps earlier turns itself --
        # confirm against RetrievalAssistant.
        if 'chat' not in st.session_state:
            st.session_state['chat'] = RetrievalAssistant()
            messages.append(Message('system', system_prompt).message())

        messages.append(Message('user', prompt).message())

        response = query(messages)

        # Read the reply text before touching the transcript so a malformed
        # response cannot leave 'past' and 'generated' with different lengths.
        answer = response['content']
        st.session_state.past.append(prompt)
        st.session_state.generated.append(answer)
|
|
|
|
# Render the transcript newest-first: for each exchange the assistant's reply
# is drawn, then the user prompt stored at the same index.
replies = st.session_state['generated']
if replies:
    questions = st.session_state['past']
    for idx in reversed(range(len(replies))):
        message(replies[idx], key=str(idx))
        message(questions[idx], is_user=True, key=f'{idx}_user')
|