mirror of https://github.com/arc53/DocsGPT
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
36 lines
1.3 KiB
Python
36 lines
1.3 KiB
Python
9 months ago
|
from application.llm.base import BaseLLM
|
||
|
|
||
|
class LlamaCpp(BaseLLM):
    """LLM backend that answers questions with a local ``llama_cpp`` model.

    The loaded model is stored in the module-level global ``llama`` (set by
    ``__init__``), so every instance uses whichever model was loaded most
    recently.
    """

    # Upper bound on the number of tokens generated per completion.
    _MAX_TOKENS = 150
    # Marker that ends the prompt; any text up to and including it is
    # stripped from the non-streaming answer.
    _ANSWER_MARKER = '### Answer \n'

    def __init__(self, api_key, llm_name='/Users/pavel/Desktop/docsgpt/application/models/orca-test.bin'):
        """Load the model file at ``llm_name`` into the module-level ``llama``.

        Args:
            api_key: Accepted for signature compatibility; not used here.
            llm_name: Filesystem path of the model file handed to
                ``llama_cpp.Llama`` as ``model_path``.
        """
        global llama
        # Imported lazily so that ``llama_cpp`` is only required when this
        # backend is actually instantiated.
        from llama_cpp import Llama

        # NOTE(review): the model is reloaded on every instantiation; if
        # instances are created per request this is likely expensive.
        llama = Llama(model_path=llm_name)

    @staticmethod
    def _build_prompt(messages):
        """Build the instruction/context/answer prompt from ``messages``.

        The first message's ``content`` is used as the context and the last
        message's ``content`` as the user question. ``messages`` must be
        non-empty; an empty list raises ``IndexError``.
        """
        context = messages[0]['content']
        user_question = messages[-1]['content']
        return f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

    def gen(self, model, engine, messages, stream=False, **kwargs):
        """Return the full completion text for ``messages``.

        Args:
            model: Unused; the globally loaded model is always queried.
            engine: Unused.
            messages: Non-empty sequence of mappings with a ``content`` key.
            stream: Unused; this method never streams.
            **kwargs: Ignored.

        Returns:
            The text of the first choice, with anything up to and including
            the answer marker removed.
        """
        prompt = self._build_prompt(messages)
        result = llama(prompt, max_tokens=self._MAX_TOKENS, echo=False)
        return result['choices'][0]['text'].split(self._ANSWER_MARKER)[-1]

    def gen_stream(self, model, engine, messages, stream=True, **kwargs):
        """Yield completion text fragments for ``messages`` as they arrive.

        Args:
            model: Unused; the globally loaded model is always queried.
            engine: Unused.
            messages: Non-empty sequence of mappings with a ``content`` key.
            stream: Kept for signature compatibility and ignored. Streaming
                is always requested, because a non-streaming result is not a
                sequence of chunks and cannot be consumed by the loop below.
            **kwargs: Ignored.

        Yields:
            The ``text`` of every choice in every streamed chunk.
        """
        prompt = self._build_prompt(messages)
        chunks = llama(prompt, max_tokens=self._MAX_TOKENS, echo=False, stream=True)
        for chunk in chunks:
            for choice in chunk['choices']:
                yield choice['text']
|