From ce31b3e72e8a90cca89653ca8f69f4427286c04e Mon Sep 17 00:00:00 2001 From: Gustav von Zitzewitz Date: Fri, 12 May 2023 15:44:26 +0200 Subject: [PATCH] add usage/costs --- README.md | 2 +- app.py | 25 +++++++++++++++++++------ constants.py | 6 ++++++ utils.py | 46 +++++++++++++++++++++++++++++++++------------- 4 files changed, 59 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 62e4a91..1f62845 100644 --- a/README.md +++ b/README.md @@ -17,4 +17,4 @@ This is an app that let's you ask questions about any data source by leveraging - As default context this git repository is taken so you can directly start asking question about its functionality without chosing an own data source. - To run locally or deploy somewhere, execute `cp .env.template .env` and set necessary keys in the newly created secrets file. Other options are manually setting of environment variables, or creating a `.streamlit/secrets.toml` file and storing credentials there. - Your data won't load? Feel free to open an Issue or PR and contribute! 
-- Finally, yes, Chad in `DataChad` refers to the well-known [meme](https://www.google.com/search?q=chad+meme) +- Yes, Chad in `DataChad` refers to the well-known [meme](https://www.google.com/search?q=chad+meme) diff --git a/app.py b/app.py index cd8b980..6100d0e 100644 --- a/app.py +++ b/app.py @@ -9,6 +9,7 @@ from constants import ( DEFAULT_DATA_SOURCE, OPENAI_HELP, PAGE_ICON, + USAGE_HELP, ) from utils import ( authenticate, @@ -33,14 +34,16 @@ st.markdown( ) # Initialise session state variables -if "chat_history" not in st.session_state: - st.session_state["chat_history"] = [] -if "generated" not in st.session_state: - st.session_state["generated"] = [] if "past" not in st.session_state: st.session_state["past"] = [] +if "usage" not in st.session_state: + st.session_state["usage"] = {} +if "generated" not in st.session_state: + st.session_state["generated"] = [] if "auth_ok" not in st.session_state: st.session_state["auth_ok"] = False +if "chat_history" not in st.session_state: + st.session_state["chat_history"] = [] if "data_source" not in st.session_state: st.session_state["data_source"] = "" if "uploaded_file" not in st.session_state: @@ -84,6 +87,7 @@ with st.sidebar: clear_button = st.button("Clear Conversation", key="clear") + # the chain can only be initialized after authentication is OK if "chain" not in st.session_state: build_chain_and_clear_history(DEFAULT_DATA_SOURCE) @@ -115,7 +119,6 @@ if uploaded_file and uploaded_file != st.session_state["uploaded_file"]: delete_uploaded_file(uploaded_file) st.session_state["uploaded_file"] = uploaded_file - # container for chat history response_container = st.container() # container for text box @@ -131,9 +134,19 @@ with container: st.session_state["past"].append(user_input) st.session_state["generated"].append(output) - if st.session_state["generated"]: with response_container: for i in range(len(st.session_state["generated"])): message(st.session_state["past"][i], is_user=True, key=str(i) + "_user") 
message(st.session_state["generated"][i], key=str(i)) + + +# Usage sidebar +# Put at the end to display even the first input +with st.sidebar: + if st.session_state["usage"]: + st.divider() + st.title("Usage", help=USAGE_HELP) + col1, col2 = st.columns(2) + col1.metric("Total Tokens", st.session_state["usage"]["total_tokens"]) + col2.metric("Total Costs in $", st.session_state["usage"]["total_cost"]) diff --git a/constants.py b/constants.py index 413144a..682d09d 100644 --- a/constants.py +++ b/constants.py @@ -23,3 +23,9 @@ You can create an ActiveLoops account (including 500GB of free database storage) Once you are logged in, you find the API token [here](https://app.activeloop.ai/profile/gustavz/apitoken).\n The organisation name is your username, or you can create new organisations [here](https://app.activeloop.ai/organization/new/create) """ + +USAGE_HELP = f""" +These are the accumulated OpenAI API usage metrics.\n +The app uses '{MODEL}' for chat and 'text-embedding-ada-002' for embeddings.\n +Learn more about OpenAI's pricing [here](https://openai.com/pricing#language-models) +""" diff --git a/utils.py b/utils.py index 2a87e90..b6d8204 100644 --- a/utils.py +++ b/utils.py @@ -7,6 +7,7 @@ import sys import deeplake import openai import streamlit as st +from langchain.callbacks import get_openai_callback from langchain.chains import ConversationalRetrievalChain from langchain.chat_models import ChatOpenAI from langchain.document_loaders import ( @@ -101,7 +102,7 @@ def save_uploaded_file(uploaded_file): file = open(file_path, "wb") file.write(file_bytes) file.close() - logger.info(f"Saved {file_path}") + logger.info(f"Saved: {file_path}") return file_path @@ -110,7 +111,7 @@ def delete_uploaded_file(uploaded_file): file_path = DATA_PATH / uploaded_file.name if os.path.exists(DATA_PATH): os.remove(file_path) - logger.info(f"Removed {file_path}") + logger.info(f"Removed: {file_path}") def load_git(data_source): @@ -152,14 +153,14 @@ def 
load_any_data_source(data_source): loader = None if is_dir: loader = DirectoryLoader(data_source, recursive=True, silent_errors=True) - if is_git: + elif is_git: return load_git(data_source) - if is_web: + elif is_web: if is_pdf: loader = OnlinePDFLoader(data_source) else: loader = WebBaseLoader(data_source) - if is_file: + elif is_file: if is_text: loader = TextLoader(data_source) elif is_notebook: @@ -179,7 +180,7 @@ def load_any_data_source(data_source): if loader: text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) docs = loader.load_and_split(text_splitter) - logger.info(f"Loaded {len(docs)} document chucks") + logger.info(f"Loaded: {len(docs)} document chunks") return docs error_msg = f"Failed to load {data_source}" @@ -205,7 +206,7 @@ def setup_vector_store(data_source): dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}" if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]): with st.spinner("Loading vector store..."): - logger.info(f"{dataset_path} exists -> loading") + logger.info(f"Dataset '{dataset_path}' exists -> loading") vector_store = DeepLake( dataset_path=dataset_path, read_only=True, @@ -213,8 +214,10 @@ def setup_vector_store(data_source): token=st.session_state["activeloop_token"], ) else: - with st.spinner("Reading, embedding and uploading data to hub..."): - logger.info(f"{dataset_path} does not exist -> uploading") + with st.spinner( + "Reading, embedding and uploading data to hub..." 
+ ), get_openai_callback() as cb: + logger.info(f"Dataset '{dataset_path}' does not exist -> uploading") docs = load_any_data_source(data_source) vector_store = DeepLake.from_documents( docs, @@ -222,6 +225,7 @@ def setup_vector_store(data_source): dataset_path=f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}", token=st.session_state["activeloop_token"], ) + update_usage(cb) return vector_store @@ -247,7 +251,7 @@ def get_chain(data_source): verbose=True, max_tokens_limit=3375, ) - logger.info(f"{data_source} is ready to go!") + logger.info(f"Data source '{data_source}' is ready to go!") return chain @@ -258,12 +262,28 @@ def build_chain_and_clear_history(data_source): st.session_state["chat_history"] = [] +def update_usage(cb): + # Accumulate API call usage via callbacks + logger.info(f"Usage: {cb}") + callback_properties = [ + "total_tokens", + "prompt_tokens", + "completion_tokens", + "total_cost", + ] + for prop in callback_properties: + value = getattr(cb, prop, 0) + st.session_state["usage"].setdefault(prop, 0) + st.session_state["usage"][prop] += value + + def generate_response(prompt): # call the chain to generate responses and add them to the chat history - with st.spinner("Generating response"): + with st.spinner("Generating response"), get_openai_callback() as cb: response = st.session_state["chain"]( {"question": prompt, "chat_history": st.session_state["chat_history"]} ) - logger.info(f"{response=}") - st.session_state["chat_history"].append((prompt, response["answer"])) + update_usage(cb) + logger.info(f"Response: '{response}'") + st.session_state["chat_history"].append((prompt, response["answer"])) return response["answer"]