Mirror of https://github.com/arc53/DocsGPT, synced 2024-11-05 21:21:02 +00:00
Added llm model variable

This commit is contained in:
  parent 43622e7ab1
  commit c9d24b8f42
app.py
@@ -28,21 +28,17 @@ from werkzeug.utils import secure_filename
 from error import bad_request
 from worker import ingest_worker
 from core.settings import settings
 import celeryconfig
 
 # os.environ["LANGCHAIN_HANDLER"] = "langchain"
 
-if os.getenv("LLM_NAME") is not None:
-    llm_choice = os.getenv("LLM_NAME")
-else:
-    llm_choice = "openai_chat"
-
 if os.getenv("EMBEDDINGS_NAME") is not None:
     embeddings_choice = os.getenv("EMBEDDINGS_NAME")
 else:
     embeddings_choice = "openai_text-embedding-ada-002"
 
-if llm_choice == "manifest":
+if settings.LLM_NAME == "manifest":
     from manifest import Manifest
     from langchain.llms.manifest import ManifestWrapper
 
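The env-var branching deleted here is what settings.LLM_NAME now covers: pydantic's BaseSettings reads matching environment variables itself. A minimal sketch of that pattern, assuming pydantic v1 (the real Settings class lives in core/settings.py and carries more fields):

from pydantic import BaseSettings

class Settings(BaseSettings):
    # Read from the LLM_NAME environment variable when set (matching is
    # case-insensitive in pydantic v1), otherwise fall back to the default,
    # reproducing the removed os.getenv("LLM_NAME") if/else.
    LLM_NAME: str = "openai_chat"

settings = Settings()
print(settings.LLM_NAME)  # "openai_chat" unless LLM_NAME is set in the environment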
@@ -122,7 +118,7 @@ def ingest(self, directory, formats, name_job, filename, user):
 
 @app.route("/")
 def home():
-    return render_template("index.html", api_key_set=api_key_set, llm_choice=llm_choice,
+    return render_template("index.html", api_key_set=api_key_set, llm_choice=settings.LLM_NAME,
                            embeddings_choice=embeddings_choice)
 
 
@@ -182,7 +178,7 @@ def api_answer():
 
     q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
                               template_format="jinja2")
-    if llm_choice == "openai_chat":
+    if settings.LLM_NAME == "openai_chat":
         # llm = ChatOpenAI(openai_api_key=api_key, model_name="gpt-4")
         llm = ChatOpenAI(openai_api_key=api_key)
         messages_combine = [
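For context, the q_prompt line above uses LangChain's jinja2 template format rather than the default f-string style (this requires the jinja2 package). A small, self-contained sketch with a made-up template string; the real template_quest is defined elsewhere in app.py:

from langchain.prompts import PromptTemplate

# Hypothetical template text; jinja2 formatting means variables are written
# as {{ var }} instead of {var}.
template_quest = "Given {{ context }}, answer the question: {{ question }}"
q_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=template_quest,
    template_format="jinja2",
)
print(q_prompt.format(context="the indexed docs", question="What does DocsGPT do?"))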
@@ -195,16 +191,16 @@ def api_answer():
             HumanMessagePromptTemplate.from_template("{question}")
         ]
         p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
-    elif llm_choice == "openai":
+    elif settings.LLM_NAME == "openai":
         llm = OpenAI(openai_api_key=api_key, temperature=0)
-    elif llm_choice == "manifest":
+    elif settings.LLM_NAME == "manifest":
         llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048})
-    elif llm_choice == "huggingface":
+    elif settings.LLM_NAME == "huggingface":
         llm = HuggingFaceHub(repo_id="bigscience/bloom", huggingfacehub_api_token=api_key)
-    elif llm_choice == "cohere":
+    elif settings.LLM_NAME == "cohere":
         llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
 
-    if llm_choice == "openai_chat":
+    if settings.LLM_NAME == "openai_chat":
         question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
         doc_chain = load_qa_chain(llm, chain_type="map_reduce", combine_prompt=p_chat_combine)
         chain = ConversationalRetrievalChain(
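All of the branches above dispatch on the same settings.LLM_NAME value. Condensed into a standalone helper for illustration (build_llm is hypothetical, not part of the repo; the manifest branch is omitted because it needs a running Manifest client):

from langchain.chat_models import ChatOpenAI
from langchain.llms import Cohere, HuggingFaceHub, OpenAI

def build_llm(name: str, api_key: str):
    # Same name -> constructor mapping as the if/elif chain in api_answer().
    if name == "openai_chat":
        return ChatOpenAI(openai_api_key=api_key)
    elif name == "openai":
        return OpenAI(openai_api_key=api_key, temperature=0)
    elif name == "huggingface":
        return HuggingFaceHub(repo_id="bigscience/bloom", huggingfacehub_api_token=api_key)
    elif name == "cohere":
        return Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
    raise ValueError(f"unsupported LLM_NAME: {name}")

llm = build_llm("openai_chat", api_key="sk-...")  # placeholder key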
core/settings.py
@@ -3,6 +3,7 @@ from pathlib import Path
 
 
 class Settings(BaseSettings):
+    LLM_NAME: str = "openai_chat"
     openai_token: str
 
 
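With the new field, the model is switched per environment rather than per code change. A usage sketch, assuming pydantic v1 semantics (openai_token has no default, so it must be supplied as well, and the env vars must be set before the Settings instance is created):

import os

os.environ.setdefault("OPENAI_TOKEN", "sk-...")  # placeholder; required field
os.environ["LLM_NAME"] = "cohere"                # overrides the "openai_chat" default

from core.settings import Settings

print(Settings().LLM_NAME)  # -> "cohere"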