Update to gpt4all version 1.0.1, implement the streaming version of the completions endpoint, and add an OpenAI Python client test for the new streaming functionality. (#1129)

Co-authored-by: Brandon <bbeiler@ridgelineintl.com>
Brandon Beiler 2023-07-05 23:17:30 -04:00 committed by GitHub
parent affd0af51f
commit fb576fbd7e
3 changed files with 94 additions and 36 deletions
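
The new endpoint streams OpenAI-style server-sent events: one data: {...} line per generated token, each terminated by a blank line. As a quick orientation before the diffs, here is a minimal sketch of reading that raw stream with requests; the host, port, and model name are assumptions for a locally running gpt4all API, not values from this commit:

import json
import requests

# Assumed address of a locally running gpt4all API; adjust to your deployment.
URL = "http://localhost:4891/v1/completions"

payload = {
    "model": "gpt4all-j-v1.3-groovy",
    "prompt": "Who is Michael Jordan?",
    "max_tokens": 50,
    "stream": True,
}

# stream=True keeps the HTTP connection open so frames arrive as generated.
with requests.post(URL, json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            chunk = json.loads(line[len(b"data: "):])
            # Each chunk carries one token in choices[0].text.
            print(chunk["choices"][0]["text"], end="", flush=True)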

View File

@@ -1,6 +1,9 @@
+import json
from fastapi import APIRouter, Depends, Response, Security, status
+from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
-from typing import List, Dict
+from typing import List, Dict, Iterable, AsyncIterable
import logging
from uuid import uuid4
from api_v1.settings import settings
@@ -10,6 +13,7 @@ import time
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
class CompletionRequest(BaseModel):
@@ -28,10 +32,13 @@ class CompletionChoice(BaseModel):
    logprobs: float
    finish_reason: str

class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

class CompletionResponse(BaseModel):
    id: str
    object: str = 'text_completion'
@@ -41,46 +48,81 @@ class CompletionResponse(BaseModel):
    usage: CompletionUsage

+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = 'text_completion'
+    created: int
+    model: str
+    choices: List[CompletionChoice]

router = APIRouter(prefix="/completions", tags=["Completion Endpoints"])

+def stream_completion(output: Iterable, base_response: CompletionStreamResponse):
+    """
+    Streams a GPT4All output to the client.
+
+    Args:
+        output: The output of GPT4All.generate(), which is an iterable of tokens.
+        base_response: The base response object, which is cloned and modified for each token.
+
+    Returns:
+        A Generator of CompletionStreamResponse objects, which are serialized to JSON Event Stream format.
+    """
+    for token in output:
+        chunk = base_response.copy()
+        chunk.choices = [dict(CompletionChoice(
+            text=token,
+            index=0,
+            logprobs=-1,
+            finish_reason=''
+        ))]
+        yield f"data: {json.dumps(dict(chunk))}\n\n"

@router.post("/", response_model=CompletionResponse)
async def completions(request: CompletionRequest):
    '''
    Completes a GPT4All model response.
    '''
-    # global model
-    if request.stream:
-        raise NotImplementedError("Streaming is not yet implements")
    model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)

    output = model.generate(prompt=request.prompt,
-                            n_predict = request.max_tokens,
-                            top_k = 20,
-                            top_p = request.top_p,
-                            temp=request.temperature,
-                            n_batch = 1024,
-                            repeat_penalty = 1.2,
-                            repeat_last_n = 10,
-                            context_erase = 0)
-    return CompletionResponse(
-        id=str(uuid4()),
-        created=time.time(),
-        model=request.model,
-        choices=[dict(CompletionChoice(
-            text=output,
-            index=0,
-            logprobs=-1,
-            finish_reason='stop'
-        ))],
-        usage={
-            'prompt_tokens': 0, #TODO how to compute this?
-            'completion_tokens': 0,
-            'total_tokens': 0
-        }
-    )
+                            n_predict=request.max_tokens,
+                            streaming=request.stream,
+                            top_k=20,
+                            top_p=request.top_p,
+                            temp=request.temperature,
+                            n_batch=1024,
+                            repeat_penalty=1.2,
+                            repeat_last_n=10)
+
+    # If streaming, we need to return a StreamingResponse
+    if request.stream:
+        base_chunk = CompletionStreamResponse(
+            id=str(uuid4()),
+            created=time.time(),
+            model=request.model,
+            choices=[]
+        )
+        return StreamingResponse((response for response in stream_completion(output, base_chunk)),
+                                 media_type="text/event-stream")
+    else:
+        return CompletionResponse(
+            id=str(uuid4()),
+            created=time.time(),
+            model=request.model,
+            choices=[dict(CompletionChoice(
+                text=output,
+                index=0,
+                logprobs=-1,
+                finish_reason='stop'
+            ))],
+            usage={
+                'prompt_tokens': 0, #TODO how to compute this?
+                'completion_tokens': 0,
+                'total_tokens': 0
+            }
+        )
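
For reference, each string yielded by stream_completion above is one frame of the text/event-stream body. For a single generated token, a frame would look roughly like this (the id, created, and token values are illustrative, not from the source):

data: {"id": "3f0e5c2a-...", "object": "text_completion", "created": 1688612250, "model": "gpt4all-j-v1.3-groovy", "choices": [{"text": " Michael", "index": 0, "logprobs": -1, "finish_reason": ""}]}

Note that stream_completion is itself a generator function, so the (response for response in ...) wrapper passed to StreamingResponse is equivalent to passing stream_completion(output, base_chunk) directly.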

View File

@@ -23,6 +23,25 @@ def test_completion():
    assert len(response['choices'][0]['text']) > len(prompt)
    print(response)

+def test_streaming_completion():
+    model = "gpt4all-j-v1.3-groovy"
+    prompt = "Who is Michael Jordan?"
+    tokens = []
+    for resp in openai.Completion.create(
+            model=model,
+            prompt=prompt,
+            max_tokens=50,
+            temperature=0.28,
+            top_p=0.95,
+            n=1,
+            echo=True,
+            stream=True):
+        tokens.append(resp.choices[0].text)
+
+    assert (len(tokens) > 0)
+    assert (len("".join(tokens)) > len(prompt))

# def test_chat_completions():
#     model = "gpt4all-j-v1.3-groovy"
#     prompt = "Who is Michael Jordan?"
@@ -30,6 +49,3 @@ def test_completion():
#         model=model,
#         messages=[]
#     )

View File

@@ -5,6 +5,6 @@ requests>=2.24.0
ujson>=2.0.2
fastapi>=0.95.0
Jinja2>=3.0
-gpt4all==0.2.3
+gpt4all==1.0.1
pytest
openai
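
The version bump is what makes the endpoint change possible: in gpt4all 1.0.x, GPT4All.generate() accepts a streaming flag and, when it is true, returns an iterator of tokens instead of a single string, which is exactly what stream_completion consumes. A minimal sketch of that usage, with the model name an assumption rather than a value from this commit:

from gpt4all import GPT4All

# Model name is an assumption; any model the gpt4all client can load works.
model = GPT4All(model_name="ggml-gpt4all-j-v1.3-groovy")

# streaming=True yields tokens one at a time; streaming=False returns one str.
for token in model.generate("Who is Michael Jordan?", streaming=True):
    print(token, end="", flush=True)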