commit 0b78480977
parent 6b6737613a
Alex committed 2023-05-25 15:14:47 +01:00
2 changed files with 21 additions and 2 deletions


@@ -27,6 +27,7 @@ from langchain.prompts.chat import (
 )
 from pymongo import MongoClient
 from werkzeug.utils import secure_filename
+from langchain.llms import GPT4All
 from core.settings import settings
 from error import bad_request
@@ -195,6 +196,9 @@ def api_answer():
         llm = HuggingFaceHub(repo_id="bigscience/bloom", huggingfacehub_api_token=api_key)
     elif settings.LLM_NAME == "cohere":
         llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
+    elif settings.LLM_NAME == "gpt4all":
+        llm = GPT4All(model="/Users/alextu/Library/Application Support/nomic.ai/GPT4All/",
+                      backend='gpt4all-j-v1.3-groovy')
     else:
         raise ValueError("unknown LLM model")
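Note that the committed model path points at the GPT4All application's data directory on the author's machine; the langchain GPT4All wrapper runs inference locally rather than calling a hosted API, and normally takes the path of a downloaded model file. A minimal standalone sketch against langchain 0.0.179, where the model path and n_threads value are illustrative assumptions, not values from this commit:

from langchain.llms import GPT4All

# Hypothetical path: point this at a downloaded GPT4All model file.
llm = GPT4All(
    model="/path/to/ggml-gpt4all-j-v1.3-groovy.bin",
    n_threads=4,  # CPU threads used for local inference
)
print(llm("What does this repository do?"))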
@@ -210,6 +214,19 @@ def api_answer():
         # result = chain({"question": question, "chat_history": chat_history})
         # generate async with async generate method
         result = run_async_chain(chain, question, chat_history)
+    elif settings.LLM_NAME == "gpt4all":
+        question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
+        doc_chain = load_qa_chain(llm, chain_type="map_reduce", combine_prompt=p_chat_combine)
+        chain = ConversationalRetrievalChain(
+            retriever=docsearch.as_retriever(k=2),
+            question_generator=question_generator,
+            combine_docs_chain=doc_chain,
+        )
+        chat_history = []
+        # result = chain({"question": question, "chat_history": chat_history})
+        # generate async with async generate method
+        result = run_async_chain(chain, question, chat_history)
     else:
         qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
                                  combine_prompt=chat_combine_template, question_prompt=q_prompt)
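The new gpt4all branch mirrors the existing branch above it: an LLMChain condenses the follow-up question, a map_reduce QA chain combines the retrieved documents, and a ConversationalRetrievalChain ties both to the retriever. A condensed synchronous sketch of the same wiring; llm, docsearch, p_chat_combine, and question are assumed to already be in scope, as they are inside api_answer():

from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT

# Rephrase the follow-up question using the conversation so far.
question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
# Answer over the retrieved documents with a map_reduce combine step.
doc_chain = load_qa_chain(llm, chain_type="map_reduce", combine_prompt=p_chat_combine)
chain = ConversationalRetrievalChain(
    retriever=docsearch.as_retriever(k=2),
    question_generator=question_generator,
    combine_docs_chain=doc_chain,
)
result = chain({"question": question, "chat_history": []})
print(result["answer"])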


@@ -26,7 +26,8 @@ ecdsa==0.18.0
 entrypoints==0.4
 faiss-cpu==1.7.3
 filelock==3.9.0
-Flask==2.3.2
+Flask==2.2.3
+Flask-Cors==3.0.10
 frozenlist==1.3.3
 geojson==2.5.0
 greenlet==2.0.2
@@ -39,7 +40,8 @@ Jinja2==3.1.2
 jmespath==1.0.1
 joblib==1.2.0
 kombu==5.2.4
-langchain==0.0.126
+langchain==0.0.179
+loguru==0.6.0
 lxml==4.9.2
 MarkupSafe==2.1.2
 marshmallow==3.19.0
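The pins move langchain from 0.0.126 to 0.0.179, presumably to pick up the GPT4All wrapper imported above, and step Flask back from 2.3.2 to 2.2.3. After reinstalling from the pinned file (pip install -r requirements.txt), a quick Python check confirms the environment matches; the expected version strings come straight from this diff:

# Verify the installed versions match the pins in requirements.txt.
import flask
import langchain

assert langchain.__version__ == "0.0.179", langchain.__version__
assert flask.__version__ == "2.2.3", flask.__version__
print("pinned versions installed")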