Update to gpt4all version 1.0.1. Implement the streaming version of the completions endpoint. Add an OpenAI Python client test for the new streaming functionality. (#1129)

Co-authored-by: Brandon <bbeiler@ridgelineintl.com>
Brandon Beiler 2023-07-05 23:17:30 -04:00 committed by GitHub
parent affd0af51f
commit fb576fbd7e
3 changed files with 94 additions and 36 deletions
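
For orientation before the diff: when stream=true is set on a request, the new endpoint responds with Server-Sent Events, one "data: {...}" line per generated token. A minimal sketch of consuming that stream directly over HTTP (the URL and model name are assumptions, not part of this commit):

import json
import requests

# Assumed local address for the gpt4all-api server; adjust to your deployment.
url = "http://localhost:4891/v1/completions"
payload = {"model": "gpt4all-j-v1.3-groovy", "prompt": "Who is Michael Jordan?",
           "max_tokens": 50, "stream": True}

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            chunk = json.loads(line[len(b"data: "):])
            print(chunk["choices"][0]["text"], end="", flush=True)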


@@ -1,6 +1,9 @@
import json
from fastapi import APIRouter, Depends, Response, Security, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Dict
from typing import List, Dict, Iterable, AsyncIterable
import logging
from uuid import uuid4
from api_v1.settings import settings
@@ -10,6 +13,7 @@ import time
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
class CompletionRequest(BaseModel):
@@ -28,10 +32,13 @@ class CompletionChoice(BaseModel):
logprobs: float
finish_reason: str
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class CompletionResponse(BaseModel):
id: str
object: str = 'text_completion'
@@ -41,46 +48,81 @@ class CompletionResponse(BaseModel):
usage: CompletionUsage
class CompletionStreamResponse(BaseModel):
id: str
object: str = 'text_completion'
created: int
model: str
choices: List[CompletionChoice]
router = APIRouter(prefix="/completions", tags=["Completion Endpoints"])
def stream_completion(output: Iterable, base_response: CompletionStreamResponse):
"""
Streams a GPT4All output to the client.
Args:
output: The output of GPT4All.generate(), which is an iterable of tokens.
base_response: The base response object, which is cloned and modified for each token.
Returns:
A Generator of CompletionStreamResponse objects, which are serialized to JSON Event Stream format.
"""
for token in output:
chunk = base_response.copy()
chunk.choices = [dict(CompletionChoice(
text=token,
index=0,
logprobs=-1,
finish_reason=''
))]
yield f"data: {json.dumps(dict(chunk))}\n\n"
@router.post("/", response_model=CompletionResponse)
async def completions(request: CompletionRequest):
'''
Completes a GPT4All model response.
'''
# global model
if request.stream:
raise NotImplementedError("Streaming is not yet implemented")
model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
output = model.generate(prompt=request.prompt,
n_predict = request.max_tokens,
top_k = 20,
top_p = request.top_p,
temp=request.temperature,
n_batch = 1024,
repeat_penalty = 1.2,
repeat_last_n = 10,
context_erase = 0)
return CompletionResponse(
id=str(uuid4()),
created=time.time(),
model=request.model,
choices=[dict(CompletionChoice(
text=output,
index=0,
logprobs=-1,
finish_reason='stop'
))],
usage={
'prompt_tokens': 0, #TODO how to compute this?
'completion_tokens': 0,
'total_tokens': 0
}
)
n_predict=request.max_tokens,
streaming=request.stream,
top_k=20,
top_p=request.top_p,
temp=request.temperature,
n_batch=1024,
repeat_penalty=1.2,
repeat_last_n=10)
# If streaming, we need to return a StreamingResponse
if request.stream:
base_chunk = CompletionStreamResponse(
id=str(uuid4()),
created=time.time(),
model=request.model,
choices=[]
)
return StreamingResponse((response for response in stream_completion(output, base_chunk)),
media_type="text/event-stream")
else:
return CompletionResponse(
id=str(uuid4()),
created=time.time(),
model=request.model,
choices=[dict(CompletionChoice(
text=output,
index=0,
logprobs=-1,
finish_reason='stop'
))],
usage={
'prompt_tokens': 0, #TODO how to compute this?
'completion_tokens': 0,
'total_tokens': 0
}
)
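
For context on the generate() call above: in gpt4all 1.0.1, streaming=True makes generate() return an iterator of tokens rather than a single string, which is why the same call feeds both the streaming and non-streaming branches. A standalone sketch (the model name is an assumption):

from gpt4all import GPT4All

model = GPT4All(model_name="ggml-gpt4all-j-v1.3-groovy")  # assumed model name

# Default (streaming=False): the whole completion comes back as one string.
text = model.generate(prompt="Who is Michael Jordan?", n_predict=50)

# streaming=True: a token iterator, as consumed by stream_completion() above.
for token in model.generate(prompt="Who is Michael Jordan?", n_predict=50,
                            streaming=True):
    print(token, end="", flush=True)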


@@ -23,6 +23,25 @@ def test_completion():
assert len(response['choices'][0]['text']) > len(prompt)
print(response)
def test_streaming_completion():
model = "gpt4all-j-v1.3-groovy"
prompt = "Who is Michael Jordan?"
tokens = []
for resp in openai.Completion.create(
model=model,
prompt=prompt,
max_tokens=50,
temperature=0.28,
top_p=0.95,
n=1,
echo=True,
stream=True):
tokens.append(resp.choices[0].text)
assert (len(tokens) > 0)
assert (len("".join(tokens)) > len(prompt))
# def test_chat_completions():
# model = "gpt4all-j-v1.3-groovy"
# prompt = "Who is Michael Jordan?"
@@ -30,6 +49,3 @@ def test_completion():
# model=model,
# messages=[]
# )
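
These tests assume the openai client has already been pointed at the local API server earlier in the test module (outside the hunks shown here); a typical setup, with the URL being an assumption:

import openai

# Assumed local server address; match it to your docker-compose / uvicorn setup.
openai.api_base = "http://localhost:4891/v1"
openai.api_key = "not needed for a local server"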


@@ -5,6 +5,6 @@ requests>=2.24.0
ujson>=2.0.2
fastapi>=0.95.0
Jinja2>=3.0
gpt4all==0.2.3
gpt4all==1.0.1
pytest
openai