refactor to enable mode and model selection
parent bcd4395907
commit 56cd7e3ba5
@@ -0,0 +1,77 @@
import streamlit as st
from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
from langchain.chains import ConversationalRetrievalChain

from datachad.constants import PAGE_ICON
from datachad.database import get_vector_store
from datachad.models import get_model
from datachad.utils import logger


def get_chain() -> ConversationalRetrievalChain:
    # create the chain that will be called to generate responses
    vector_store = get_vector_store()
    retriever = vector_store.as_retriever()
    # The search params "fetch_k" and "k" define how many documents are pulled
    # from the hub and how many are selected after document matching to build
    # the context that is fed to the model together with your prompt
    search_kwargs = {
        "maximal_marginal_relevance": True,
        "distance_metric": "cos",
        "fetch_k": st.session_state["fetch_k"],
        "k": st.session_state["k"],
    }
    retriever.search_kwargs.update(search_kwargs)
    model = get_model()
    chain = ConversationalRetrievalChain.from_llm(
        model,
        retriever=retriever,
        chain_type="stuff",
        verbose=True,
        # limit the maximum number of used tokens
        # to avoid running into the model's token limit of 4096
        max_tokens_limit=st.session_state["max_tokens"],
    )
    return chain


def update_chain() -> None:
    # Build the chain with parameters from the session state and store it back.
    # Also delete the chat history so old context does not confuse the bot.
    try:
        st.session_state["chain"] = get_chain()
        st.session_state["chat_history"] = []
        msg = f"Data source '{st.session_state['data_source']}' is ready to go!"
        logger.info(msg)
        st.info(msg, icon=PAGE_ICON)
    except Exception as e:
        msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}"
        logger.error(msg)
        st.error(msg, icon=PAGE_ICON)


def update_usage(cb: OpenAICallbackHandler) -> None:
    # Accumulate API call usage via callbacks
    logger.info(f"Usage: {cb}")
    callback_properties = [
        "total_tokens",
        "prompt_tokens",
        "completion_tokens",
        "total_cost",
    ]
    for prop in callback_properties:
        value = getattr(cb, prop, 0)
        st.session_state["usage"].setdefault(prop, 0)
        st.session_state["usage"][prop] += value


def generate_response(prompt: str) -> str:
    # call the chain to generate a response and add it to the chat history
    with st.spinner("Generating response"), get_openai_callback() as cb:
        response = st.session_state["chain"](
            {"question": prompt, "chat_history": st.session_state["chat_history"]}
        )
        update_usage(cb)
    logger.info(f"Response: '{response}'")
    st.session_state["chat_history"].append((prompt, response["answer"]))
    return response["answer"]
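
Usage sketch (not part of the commit): assuming the app's sidebar has already populated the session-state keys referenced above ("data_source", "fetch_k", "k", "max_tokens", "model", "temperature"), and assuming this module is importable as datachad.chain (the file path is not shown in this view), the pieces wire together roughly like this:

import streamlit as st
from datachad.chain import generate_response, update_chain

st.session_state.setdefault("usage", {})  # update_usage() accumulates into this dict
if st.session_state.get("data_source"):
    update_chain()  # build the chain and reset the chat history
    st.write(generate_response("What is this data source about?"))
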
@@ -0,0 +1,51 @@
import os
import re

import deeplake
import streamlit as st
from langchain.vectorstores import DeepLake, VectorStore

from datachad.constants import DATA_PATH
from datachad.loader import load_data_source
from datachad.models import MODES, get_embeddings
from datachad.utils import logger


def get_dataset_path() -> str:
    # replace all non-word characters with dashes
    # to get a string that can be used to create a new dataset
    dataset_name = re.sub(r"\W+", "-", st.session_state["data_source"])
    dataset_name = re.sub(r"--+", "-", dataset_name).strip("-")
    if st.session_state["mode"] == MODES.LOCAL:
        if not os.path.exists(DATA_PATH):
            os.makedirs(DATA_PATH)
        dataset_path = str(DATA_PATH / dataset_name)
    else:
        dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{dataset_name}-{st.session_state['chunk_size']}"
    return dataset_path
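
# For example (hypothetical input): a data source like
# "https://github.com/user/repo" becomes "https-github-com-user-repo"
# after the substitutions above, which is a valid dataset name.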


def get_vector_store() -> VectorStore:
    # either load existing vector store or upload a new one to the hub
    embeddings = get_embeddings()
    dataset_path = get_dataset_path()
    if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
        with st.spinner("Loading vector store..."):
            logger.info(f"Dataset '{dataset_path}' exists -> loading")
            vector_store = DeepLake(
                dataset_path=dataset_path,
                read_only=True,
                embedding_function=embeddings,
                token=st.session_state["activeloop_token"],
            )
    else:
        with st.spinner("Reading, embedding and uploading data to hub..."):
            logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
            docs = load_data_source()
            vector_store = DeepLake.from_documents(
                docs,
                embeddings,
                dataset_path=dataset_path,
                token=st.session_state["activeloop_token"],
            )
    return vector_store
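
Illustrative query sketch (not part of the commit; get_vector_store is normally only called from get_chain above):

store = get_vector_store()
docs = store.as_retriever().get_relevant_documents("example query")
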
@@ -0,0 +1,111 @@
from dataclasses import dataclass

import streamlit as st
from langchain.base_language import BaseLanguageModel
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
from langchain.llms import GPT4All, LlamaCpp

from datachad.utils import logger


class Enum:
    # minimal enum-style base class: exposes all public class attributes
    @classmethod
    def values(cls):
        return [v for k, v in cls.__dict__.items() if not k.startswith("_")]

    @classmethod
    def dict(cls):
        return {k: v for k, v in cls.__dict__.items() if not k.startswith("_")}
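
# For example, with the subclasses below:
#   MODES.values() -> ["OpenAI", "Local"]
#   MODES.dict()   -> {"OPENAI": "OpenAI", "LOCAL": "Local"}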


@dataclass
class Model:
    name: str
    mode: str
    embedding: str
    path: str = None  # for local models only

    def __str__(self):
        return self.name


class MODES(Enum):
    OPENAI = "OpenAI"
    LOCAL = "Local"


class EMBEDDINGS(Enum):
    OPENAI = "openai"
    HUGGINGFACE = "all-MiniLM-L6-v2"


class MODELS(Enum):
    GPT35TURBO = Model("gpt-3.5-turbo", MODES.OPENAI, EMBEDDINGS.OPENAI)
    GPT4 = Model("gpt-4", MODES.OPENAI, EMBEDDINGS.OPENAI)
    LLAMACPP = Model(
        "LLAMA", MODES.LOCAL, EMBEDDINGS.HUGGINGFACE, "models/llamacpp.bin"
    )
    GPT4ALL = Model(
        "GPT4All", MODES.LOCAL, EMBEDDINGS.HUGGINGFACE, "models/gpt4all.bin"
    )

    @classmethod
    def for_mode(cls, mode):
        return [v for v in cls.values() if isinstance(v, Model) and v.mode == mode]
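
# e.g. MODELS.for_mode(MODES.LOCAL) returns the LLAMACPP and GPT4ALL models,
# which makes it easy to populate a model selectbox for the chosen mode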


def get_model() -> BaseLanguageModel:
    match st.session_state["model"].name:
        case MODELS.GPT35TURBO.name | MODELS.GPT4.name:
            model = ChatOpenAI(
                model_name=st.session_state["model"].name,
                temperature=st.session_state["temperature"],
                openai_api_key=st.session_state["openai_api_key"],
            )
        case MODELS.LLAMACPP.name:
            model = LlamaCpp(
                model_path=st.session_state["model"].path,
                n_ctx=st.session_state["model_n_ctx"],
                temperature=st.session_state["temperature"],
                verbose=True,
            )
        case MODELS.GPT4ALL.name:
            model = GPT4All(
                model=st.session_state["model"].path,
                n_ctx=st.session_state["model_n_ctx"],
                backend="gptj",
                temp=st.session_state["temperature"],
                verbose=True,
            )
        # Add more models as needed
        case _:
            msg = f"Model {st.session_state['model']} not supported!"
            logger.error(msg)
            st.error(msg)
            st.stop()
    return model


def get_embeddings() -> Embeddings:
    match st.session_state["model"].embedding:
        case EMBEDDINGS.OPENAI:
            embeddings = OpenAIEmbeddings(
                disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
            )
        case EMBEDDINGS.HUGGINGFACE:
            embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS.HUGGINGFACE)
        # Add more embeddings as needed
        case _:
            msg = f"Embeddings {st.session_state['embeddings']} not supported!"
            logger.error(msg)
            st.error(msg)
            st.stop()
    return embeddings
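
A sketch of how a sidebar could drive mode and model selection with these helpers (illustrative; the widget labels are assumptions, not part of this commit):

import streamlit as st
from datachad.models import MODELS, MODES

st.session_state["mode"] = st.selectbox("Mode", MODES.values())
st.session_state["model"] = st.selectbox(
    "Model", MODELS.for_mode(st.session_state["mode"])
)
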
@@ -0,0 +1,104 @@
import logging
import os
import sys

import deeplake
import openai
import streamlit as st
from dotenv import load_dotenv

from datachad.constants import APP_NAME, DATA_PATH, PAGE_ICON

# load environment variables
load_dotenv()

logger = logging.getLogger(APP_NAME)


def configure_logger(debug: int = 0) -> None:
    # boilerplate code to enable logging in the streamlit app console
    log_level = logging.DEBUG if debug == 1 else logging.INFO
    logger.setLevel(log_level)

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setLevel(log_level)

    formatter = logging.Formatter("%(message)s")

    stream_handler.setFormatter(formatter)

    logger.addHandler(stream_handler)
    logger.propagate = False


configure_logger(0)


def authenticate(
    openai_api_key: str, activeloop_token: str, activeloop_org_name: str
) -> None:
    # Validate that all credentials are set and correct.
    # Check for env variables to enable local dev and deployments with shared credentials.
    openai_api_key = (
        openai_api_key
        or os.environ.get("OPENAI_API_KEY")
        or st.secrets.get("OPENAI_API_KEY")
    )
    activeloop_token = (
        activeloop_token
        or os.environ.get("ACTIVELOOP_TOKEN")
        or st.secrets.get("ACTIVELOOP_TOKEN")
    )
    activeloop_org_name = (
        activeloop_org_name
        or os.environ.get("ACTIVELOOP_ORG_NAME")
        or st.secrets.get("ACTIVELOOP_ORG_NAME")
    )
    if not (openai_api_key and activeloop_token and activeloop_org_name):
        st.session_state["auth_ok"] = False
        st.error("Credentials neither set nor stored", icon=PAGE_ICON)
        return
    try:
        # Try to access openai and deeplake
        with st.spinner("Authenticating..."):
            openai.api_key = openai_api_key
            openai.Model.list()
            deeplake.exists(
                f"hub://{activeloop_org_name}/DataChad-Authentication-Check",
                token=activeloop_token,
            )
    except Exception as e:
        logger.error(f"Authentication failed with {e}")
        st.session_state["auth_ok"] = False
        st.error("Authentication failed", icon=PAGE_ICON)
        return
    # store credentials in the session state
    st.session_state["auth_ok"] = True
    st.session_state["openai_api_key"] = openai_api_key
    st.session_state["activeloop_token"] = activeloop_token
    st.session_state["activeloop_org_name"] = activeloop_org_name
    logger.info("Authentication successful!")
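
# Note: thanks to the fallbacks above, the credential inputs may stay empty
# when OPENAI_API_KEY, ACTIVELOOP_TOKEN and ACTIVELOOP_ORG_NAME are set as
# environment variables or in .streamlit/secrets.toml.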


def save_uploaded_file() -> str:
    # streamlit uploaded files need to be stored locally
    # before being embedded and uploaded to the hub
    uploaded_file = st.session_state["uploaded_file"]
    if not os.path.exists(DATA_PATH):
        os.makedirs(DATA_PATH)
    file_path = str(DATA_PATH / uploaded_file.name)
    uploaded_file.seek(0)
    with open(file_path, "wb") as file:
        file.write(uploaded_file.read())
    logger.info(f"Saved: {file_path}")
    return file_path


def delete_uploaded_file() -> None:
    # clean up locally stored files
    file_path = DATA_PATH / st.session_state["uploaded_file"].name
    if os.path.exists(file_path):
        os.remove(file_path)
        logger.info(f"Removed: {file_path}")
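
Usage sketch for the upload helpers (illustrative; the file_uploader key is an assumption):

import streamlit as st
from datachad.utils import delete_uploaded_file, save_uploaded_file

st.file_uploader("Upload file", key="uploaded_file")
if st.session_state.get("uploaded_file"):
    file_path = save_uploaded_file()  # persist locally for embedding
    # ... embed and upload to the hub, then ...
    delete_uploaded_file()  # remove the local copy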