|
|
@ -43,6 +43,10 @@ from worker import ingest_worker
|
|
|
|
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
|
|
|
|
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
if settings.LLM_NAME == "gpt4":
|
|
|
|
|
|
|
|
gpt_model = 'gpt-4'
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
gpt_model = 'gpt-3.5-turbo'
|
|
|
|
|
|
|
|
|
|
|
|
if settings.LLM_NAME == "manifest":
|
|
|
|
if settings.LLM_NAME == "manifest":
|
|
|
|
from manifest import Manifest
|
|
|
|
from manifest import Manifest
|
|
|
@ -195,7 +199,7 @@ def complete_stream(question, docsearch, chat_history, api_key):
|
|
|
|
messages_combine.append({"role": "user", "content": i["prompt"]})
|
|
|
|
messages_combine.append({"role": "user", "content": i["prompt"]})
|
|
|
|
messages_combine.append({"role": "system", "content": i["response"]})
|
|
|
|
messages_combine.append({"role": "system", "content": i["response"]})
|
|
|
|
messages_combine.append({"role": "user", "content": question})
|
|
|
|
messages_combine.append({"role": "user", "content": question})
|
|
|
|
completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", engine=settings.AZURE_DEPLOYMENT_NAME,
|
|
|
|
completion = openai.ChatCompletion.create(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
|
|
|
|
messages=messages_combine, stream=True, max_tokens=500, temperature=0)
|
|
|
|
messages=messages_combine, stream=True, max_tokens=500, temperature=0)
|
|
|
|
|
|
|
|
|
|
|
|
for line in completion:
|
|
|
|
for line in completion:
|
|
|
@ -208,26 +212,27 @@ def complete_stream(question, docsearch, chat_history, api_key):
|
|
|
|
yield f"data: {data}\n\n"
|
|
|
|
yield f"data: {data}\n\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.route("/stream", methods=["POST", "GET"])
|
|
|
|
@app.route("/stream", methods=["POST"])
|
|
|
|
def stream():
|
|
|
|
def stream():
|
|
|
|
|
|
|
|
data = request.get_json()
|
|
|
|
# get parameter from url question
|
|
|
|
# get parameter from url question
|
|
|
|
question = request.args.get("question")
|
|
|
|
question = data["question"]
|
|
|
|
history = request.args.get("history")
|
|
|
|
history = data["history"]
|
|
|
|
# history to json object from string
|
|
|
|
# history to json object from string
|
|
|
|
history = json.loads(history)
|
|
|
|
history = json.loads(history)
|
|
|
|
|
|
|
|
|
|
|
|
# check if active_docs is set
|
|
|
|
# check if active_docs is set
|
|
|
|
|
|
|
|
|
|
|
|
if not api_key_set:
|
|
|
|
if not api_key_set:
|
|
|
|
api_key = request.args.get("api_key")
|
|
|
|
api_key = data["api_key"]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
api_key = settings.API_KEY
|
|
|
|
api_key = settings.API_KEY
|
|
|
|
if not embeddings_key_set:
|
|
|
|
if not embeddings_key_set:
|
|
|
|
embeddings_key = request.args.get("embeddings_key")
|
|
|
|
embeddings_key = data["embeddings_key"]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
embeddings_key = settings.EMBEDDINGS_KEY
|
|
|
|
embeddings_key = settings.EMBEDDINGS_KEY
|
|
|
|
if "active_docs" in request.args:
|
|
|
|
if "active_docs" in data:
|
|
|
|
vectorstore = get_vectorstore({"active_docs": request.args.get("active_docs")})
|
|
|
|
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
vectorstore = ""
|
|
|
|
vectorstore = ""
|
|
|
|
docsearch = get_docsearch(vectorstore, embeddings_key)
|
|
|
|
docsearch = get_docsearch(vectorstore, embeddings_key)
|
|
|
@ -279,7 +284,7 @@ def api_answer():
|
|
|
|
)
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
logger.debug("plain OpenAI")
|
|
|
|
logger.debug("plain OpenAI")
|
|
|
|
llm = ChatOpenAI(openai_api_key=api_key) # optional parameter: model_name="gpt-4"
|
|
|
|
llm = ChatOpenAI(openai_api_key=api_key, model_name=gpt_model) # optional parameter: model_name="gpt-4"
|
|
|
|
messages_combine = [SystemMessagePromptTemplate.from_template(chat_combine_template)]
|
|
|
|
messages_combine = [SystemMessagePromptTemplate.from_template(chat_combine_template)]
|
|
|
|
if history:
|
|
|
|
if history:
|
|
|
|
tokens_current_history = 0
|
|
|
|
tokens_current_history = 0
|
|
|
@ -597,4 +602,4 @@ def after_request(response):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
app.run(debug=True, port=5001)
|
|
|
|
app.run(debug=True, port=7091)
|
|
|
|