"""Streamlit front end for a Formula 1 retrieval chatbot.

Renders a text box and a Submit button, forwards each question to a
``RetrievalAssistant`` kept in ``st.session_state``, and shows the
conversation history newest-first with ``streamlit_chat.message``.
"""
import streamlit as st
from streamlit_chat import message

from database import get_redis_connection
from chatbot import RetrievalAssistant, Message

# Initialise database

## Initialise Redis connection
# NOTE(review): redis_client is not referenced anywhere else in this file;
# presumably the call is kept to open/validate the connection at start-up --
# confirm before removing it.
redis_client = get_redis_connection()

# Set instruction

# System prompt requiring Question and Year to be extracted from the user.
# It is sent only once, as the first message of a new conversation (see the
# Submit handler below).
system_prompt = '''
You are a helpful Formula 1 knowledge base assistant. You need to capture a Question and Year from each customer.
The Question is their query on Formula 1, and the Year is the year of the applicable Formula 1 season.
Think about this step by step:
- The user will ask a Question
- You will ask them for the Year if their question didn't include a Year
- Once you have the Year, say "searching for answers".

Example:

User: I'd like to know the cost cap for a power unit

A: Certainly, what year would you like this for?

User: 2023 please.

A: Searching for answers.
'''

### CHATBOT APP

st.set_page_config(
    page_title="Streamlit Chat - Demo",
    page_icon=":robot:"
)

st.title('Formula 1 Chatbot')
st.subheader("Help us help you learn about Formula 1")

# Parallel histories: 'past' holds the user's questions and 'generated' the
# assistant's replies; entry i of one belongs with entry i of the other.
if 'generated' not in st.session_state:
    st.session_state['generated'] = []

if 'past' not in st.session_state:
    st.session_state['past'] = []


def query(question):
    """Send messages to this session's assistant and return its response.

    Args:
        question: despite its name, the list of message payloads built with
            ``Message(...).message()`` by the Submit handler below.

    Returns:
        Whatever ``RetrievalAssistant.ask_assistant`` returns; the caller
        reads its ``'content'`` key.
    """
    return st.session_state['chat'].ask_assistant(question)


prompt = st.text_input("What do you want to know: ", key="input")

if st.button('Submit', key='generationSubmit'):
    if not prompt.strip():
        # Don't spend a model call (or add an empty chat bubble) on a blank question.
        st.warning("Please enter a question before submitting.")
    else:
        messages = []
        # First submission of the session: create the assistant and prepend
        # the system prompt to this request.
        # NOTE(review): 'chat' is stored before query() runs, so if this first
        # call raises, later submissions will not resend the system prompt --
        # confirm whether RetrievalAssistant keeps history from a failed call.
        if 'chat' not in st.session_state:
            st.session_state['chat'] = RetrievalAssistant()
            messages.append(Message('system', system_prompt).message())

        messages.append(Message('user', prompt).message())

        response = query(messages)
        # Extract the reply BEFORE touching the history, so a malformed
        # response cannot leave 'past' and 'generated' out of step.
        answer = response['content']

        st.session_state.past.append(prompt)
        st.session_state.generated.append(answer)

# Show the newest exchange first. Each entry's widget keys ("<i>" for the
# reply, "<i>_user" for the question) are derived from its list index.
history = enumerate(zip(st.session_state['generated'], st.session_state['past']))
for i, (answer_text, question_text) in reversed(list(history)):
    message(answer_text, key=str(i))
    message(question_text, is_user=True, key=f"{i}_user")