openai-cookbook/apps/chatbot-kickstarter/chat.py

84 lines
2.3 KiB
Python
Raw Normal View History

import streamlit as st
from streamlit_chat import message
from database import get_redis_connection
from chatbot import RetrievalAssistant, Message
# Initialise database
# Initialise Redis connection.
# `get_redis_connection` is imported from the project's `database` module.
# The call sits at module level, so it runs every time this script executes.
# NOTE(review): `redis_client` is not referenced again in this file; it is
# presumably kept for the connection side effect or for debugging - confirm.
redis_client = get_redis_connection()
# Set instruction
# System prompt sent as the first message of a new conversation. It tells the
# assistant to collect two things from the user - a Question and a Year - to
# ask for the Year when it is missing, and to say "searching for answers" once
# both are known. The text is runtime data: keep it byte-for-byte stable.
system_prompt = """
You are a helpful Formula 1 knowledge base assistant. You need to capture a Question and Year from each customer.
The Question is their query on Formula 1, and the Year is the year of the applicable Formula 1 season.
Think about this step by step:
- The user will ask a Question
- You will ask them for the Year if their question didn't include a Year
- Once you have the Year, say "searching for answers".
Example:
User: I'd like to know the cost cap for a power unit
Assistant: Certainly, what year would you like this for?
User: 2023 please.
Assistant: Searching for answers.
"""
### CHATBOT APP
# Browser-tab metadata for the page.
st.set_page_config(page_title="Streamlit Chat - Demo", page_icon=":robot:")

# Visible page heading and tagline.
st.title('Formula 1 Chatbot')
st.subheader("Help us help you learn about Formula 1")

# Seed the two parallel history lists on the first run of a session only, so
# later reruns keep what is already stored:
#   'generated' - assistant replies, 'past' - the user prompts that led to them.
if 'generated' not in st.session_state:
    st.session_state['generated'] = []

if 'past' not in st.session_state:
    st.session_state['past'] = []
def query(question):
    """Hand ``question`` to the assistant kept in Streamlit session state.

    Args:
        question: Forwarded unchanged to ``ask_assistant``. The caller in this
            script passes a list built from ``Message(...).message()`` values
            (new messages for this turn), not a bare string.

    Returns:
        Whatever ``ask_assistant`` returns; the caller in this script reads
        its ``'content'`` entry.

    Note:
        Expects ``st.session_state['chat']`` to exist already; the submit
        handler in this script creates it before calling this function.
    """
    assistant = st.session_state['chat']
    return assistant.ask_assistant(question)
# Free-text question box; its current value is read on every rerun.
prompt = st.text_input("What do you want to know: ", "", key="input")

# `st.button` is truthy only on the rerun triggered by the click, so this
# branch handles exactly one submitted turn.
if st.button('Submit', key='generationSubmit'):
    if not prompt.strip():
        # Nothing typed (or whitespace only): do not send a blank user turn to
        # the assistant or record an empty exchange in the history.
        st.warning("Please enter a question before submitting.")
    else:
        # Only the messages that are new for this turn are collected here.
        messages = []

        # Initialization: the first submission of a session creates the
        # assistant and prepends the system prompt to the outgoing messages.
        if 'chat' not in st.session_state:
            st.session_state['chat'] = RetrievalAssistant()
            messages.append(Message('system', system_prompt).message())

        messages.append(Message('user', prompt).message())

        response = query(messages)

        # Append to both lists only after the assistant has answered, so the
        # two histories stay the same length and index-aligned for rendering.
        st.session_state.past.append(prompt)
        st.session_state.generated.append(response['content'])
# Render the conversation newest-first. For each turn the assistant reply is
# drawn above the user prompt that produced it; the widget keys are derived
# from the turn index so every chat bubble has a unique key.
if st.session_state['generated']:
    replies = st.session_state['generated']
    for turn in reversed(range(len(replies))):
        message(replies[turn], key=str(turn))
        message(st.session_state['past'][turn], is_user=True, key=str(turn) + '_user')