Mirror of https://github.com/nomic-ai/gpt4all (synced 2024-11-04 12:00:10 +00:00)
Update to gpt4all version 1.0.1. Implement the streaming version of the completions endpoint and an OpenAI Python client test for the new streaming functionality. (#1129)
Co-authored-by: Brandon <bbeiler@ridgelineintl.com>
parent: affd0af51f
commit: fb576fbd7e
@@ -1,6 +1,9 @@
+import json
 from fastapi import APIRouter, Depends, Response, Security, status
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
-from typing import List, Dict
+from typing import List, Dict, Iterable, AsyncIterable
 import logging
 from uuid import uuid4
 from api_v1.settings import settings
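The new imports wire json, FastAPI's StreamingResponse, and Iterable into the route module so that token-by-token output can be sent as server-sent events. As a minimal standalone sketch of that pattern only (the demo app, route name, and token list below are illustrative and not part of this commit):

import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/demo-stream")
def demo_stream():
    # Stand-in tokens; in the real route these come from GPT4All.generate().
    tokens = ["Hello", ",", " world"]

    def event_stream():
        for token in tokens:
            # Each chunk is framed as a server-sent event: "data: <json>\n\n".
            yield f"data: {json.dumps({'text': token})}\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")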
@@ -10,6 +13,7 @@ import time
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
 
 ### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
 
 class CompletionRequest(BaseModel):
@@ -28,10 +32,13 @@ class CompletionChoice(BaseModel):
     logprobs: float
     finish_reason: str
 
 
 class CompletionUsage(BaseModel):
     prompt_tokens: int
     completion_tokens: int
     total_tokens: int
 
 
 class CompletionResponse(BaseModel):
     id: str
     object: str = 'text_completion'
@@ -41,46 +48,81 @@ class CompletionResponse(BaseModel):
     usage: CompletionUsage
 
 
+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = 'text_completion'
+    created: int
+    model: str
+    choices: List[CompletionChoice]
+
+
 router = APIRouter(prefix="/completions", tags=["Completion Endpoints"])
 
 
+def stream_completion(output: Iterable, base_response: CompletionStreamResponse):
+    """
+    Streams a GPT4All output to the client.
+
+    Args:
+        output: The output of GPT4All.generate(), which is an iterable of tokens.
+        base_response: The base response object, which is cloned and modified for each token.
+
+    Returns:
+        A Generator of CompletionStreamResponse objects, which are serialized to JSON Event Stream format.
+    """
+    for token in output:
+        chunk = base_response.copy()
+        chunk.choices = [dict(CompletionChoice(
+            text=token,
+            index=0,
+            logprobs=-1,
+            finish_reason=''
+        ))]
+        yield f"data: {json.dumps(dict(chunk))}\n\n"
+
+
 @router.post("/", response_model=CompletionResponse)
 async def completions(request: CompletionRequest):
     '''
     Completes a GPT4All model response.
     '''
 
-    # global model
-    if request.stream:
-        raise NotImplementedError("Streaming is not yet implements")
-
     model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
 
     output = model.generate(prompt=request.prompt,
-                            n_predict = request.max_tokens,
-                            top_k = 20,
-                            top_p = request.top_p,
-                            temp=request.temperature,
-                            n_batch = 1024,
-                            repeat_penalty = 1.2,
-                            repeat_last_n = 10,
-                            context_erase = 0)
+                            n_predict=request.max_tokens,
+                            streaming=request.stream,
+                            top_k=20,
+                            top_p=request.top_p,
+                            temp=request.temperature,
+                            n_batch=1024,
+                            repeat_penalty=1.2,
+                            repeat_last_n=10)
 
-    return CompletionResponse(
-        id=str(uuid4()),
-        created=time.time(),
-        model=request.model,
-        choices=[dict(CompletionChoice(
-            text=output,
-            index=0,
-            logprobs=-1,
-            finish_reason='stop'
-        ))],
-        usage={
-            'prompt_tokens': 0, #TODO how to compute this?
-            'completion_tokens': 0,
-            'total_tokens': 0
-        }
-    )
+    # If streaming, we need to return a StreamingResponse
+    if request.stream:
+        base_chunk = CompletionStreamResponse(
+            id=str(uuid4()),
+            created=time.time(),
+            model=request.model,
+            choices=[]
+        )
+        return StreamingResponse((response for response in stream_completion(output, base_chunk)),
+                                 media_type="text/event-stream")
+    else:
+        return CompletionResponse(
+            id=str(uuid4()),
+            created=time.time(),
+            model=request.model,
+            choices=[dict(CompletionChoice(
+                text=output,
+                index=0,
+                logprobs=-1,
+                finish_reason='stop'
+            ))],
+            usage={
+                'prompt_tokens': 0, #TODO how to compute this?
+                'completion_tokens': 0,
+                'total_tokens': 0
+            }
+        )
@@ -23,6 +23,25 @@ def test_completion():
     assert len(response['choices'][0]['text']) > len(prompt)
     print(response)
 
+
+def test_streaming_completion():
+    model = "gpt4all-j-v1.3-groovy"
+    prompt = "Who is Michael Jordan?"
+    tokens = []
+    for resp in openai.Completion.create(
+            model=model,
+            prompt=prompt,
+            max_tokens=50,
+            temperature=0.28,
+            top_p=0.95,
+            n=1,
+            echo=True,
+            stream=True):
+        tokens.append(resp.choices[0].text)
+
+    assert (len(tokens) > 0)
+    assert (len("".join(tokens)) > len(prompt))
+
 # def test_chat_completions():
 #     model = "gpt4all-j-v1.3-groovy"
 #     prompt = "Who is Michael Jordan?"
@@ -30,6 +49,3 @@ def test_completion():
 #         model=model,
 #         messages=[]
 #     )
-
-
-
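The new test drives the endpoint through the pre-1.0 OpenAI Python client, so the client has to be pointed at the local FastAPI server instead of api.openai.com. A sketch of that setup follows; the base URL and placeholder key are assumptions that do not appear in this diff:

import openai

# Assumed values: the local server does not check the key, and the port depends
# on how the FastAPI app is run; neither is shown in this commit.
openai.api_base = "http://localhost:4891/v1"
openai.api_key = "not-needed"

response = openai.Completion.create(
    model="gpt4all-j-v1.3-groovy",
    prompt="Who is Michael Jordan?",
    max_tokens=50,
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].text, end="", flush=True)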
@@ -5,6 +5,6 @@ requests>=2.24.0
 ujson>=2.0.2
 fastapi>=0.95.0
 Jinja2>=3.0
-gpt4all==0.2.3
+gpt4all==1.0.1
 pytest
 openai
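The bump from gpt4all 0.2.3 to 1.0.1 is what provides the streaming=True flag used in the generate() call above. A minimal sketch of that call in isolation, with the model name as a placeholder (the API reads the real one from api_v1.settings):

from gpt4all import GPT4All

# Placeholder model name; substitute whichever model the server is configured to load.
model = GPT4All(model_name="ggml-gpt4all-j-v1.3-groovy")

# With streaming=True, generate() yields tokens one at a time instead of returning one string.
for token in model.generate(prompt="Who is Michael Jordan?", n_predict=50, streaming=True):
    print(token, end="", flush=True)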