diff --git a/gpt4all-api/gpt4all_api/app/api_v1/routes/completions.py b/gpt4all-api/gpt4all_api/app/api_v1/routes/completions.py index a482b263..ba0dba6a 100644 --- a/gpt4all-api/gpt4all_api/app/api_v1/routes/completions.py +++ b/gpt4all-api/gpt4all_api/app/api_v1/routes/completions.py @@ -1,6 +1,9 @@ +import json + from fastapi import APIRouter, Depends, Response, Security, status +from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field -from typing import List, Dict +from typing import List, Dict, Iterable, AsyncIterable import logging from uuid import uuid4 from api_v1.settings import settings @@ -10,6 +13,7 @@ import time logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) + ### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml class CompletionRequest(BaseModel): @@ -28,10 +32,13 @@ class CompletionChoice(BaseModel): logprobs: float finish_reason: str + class CompletionUsage(BaseModel): prompt_tokens: int completion_tokens: int total_tokens: int + + class CompletionResponse(BaseModel): id: str object: str = 'text_completion' @@ -41,46 +48,81 @@ class CompletionResponse(BaseModel): usage: CompletionUsage +class CompletionStreamResponse(BaseModel): + id: str + object: str = 'text_completion' + created: int + model: str + choices: List[CompletionChoice] + + router = APIRouter(prefix="/completions", tags=["Completion Endpoints"]) + +def stream_completion(output: Iterable, base_response: CompletionStreamResponse): + """ + Streams a GPT4All output to the client. + + Args: + output: The output of GPT4All.generate(), which is an iterable of tokens. + base_response: The base response object, which is cloned and modified for each token. + + Returns: + A Generator of CompletionStreamResponse objects, which are serialized to JSON Event Stream format. + """ + for token in output: + chunk = base_response.copy() + chunk.choices = [dict(CompletionChoice( + text=token, + index=0, + logprobs=-1, + finish_reason='' + ))] + yield f"data: {json.dumps(dict(chunk))}\n\n" + + @router.post("/", response_model=CompletionResponse) async def completions(request: CompletionRequest): ''' Completes a GPT4All model response. ''' - # global model - if request.stream: - raise NotImplementedError("Streaming is not yet implements") - model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path) output = model.generate(prompt=request.prompt, - n_predict = request.max_tokens, - top_k = 20, - top_p = request.top_p, - temp=request.temperature, - n_batch = 1024, - repeat_penalty = 1.2, - repeat_last_n = 10, - context_erase = 0) - - - return CompletionResponse( - id=str(uuid4()), - created=time.time(), - model=request.model, - choices=[dict(CompletionChoice( - text=output, - index=0, - logprobs=-1, - finish_reason='stop' - ))], - usage={ - 'prompt_tokens': 0, #TODO how to compute this? - 'completion_tokens': 0, - 'total_tokens': 0 - } - ) - + n_predict=request.max_tokens, + streaming=request.stream, + top_k=20, + top_p=request.top_p, + temp=request.temperature, + n_batch=1024, + repeat_penalty=1.2, + repeat_last_n=10) + # If streaming, we need to return a StreamingResponse + if request.stream: + base_chunk = CompletionStreamResponse( + id=str(uuid4()), + created=time.time(), + model=request.model, + choices=[] + ) + return StreamingResponse((response for response in stream_completion(output, base_chunk)), + media_type="text/event-stream") + else: + return CompletionResponse( + id=str(uuid4()), + created=time.time(), + model=request.model, + choices=[dict(CompletionChoice( + text=output, + index=0, + logprobs=-1, + finish_reason='stop' + ))], + usage={ + 'prompt_tokens': 0, #TODO how to compute this? + 'completion_tokens': 0, + 'total_tokens': 0 + } + ) diff --git a/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py b/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py index b207aa17..fad9bd24 100644 --- a/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py +++ b/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py @@ -23,6 +23,25 @@ def test_completion(): assert len(response['choices'][0]['text']) > len(prompt) print(response) + +def test_streaming_completion(): + model = "gpt4all-j-v1.3-groovy" + prompt = "Who is Michael Jordan?" + tokens = [] + for resp in openai.Completion.create( + model=model, + prompt=prompt, + max_tokens=50, + temperature=0.28, + top_p=0.95, + n=1, + echo=True, + stream=True): + tokens.append(resp.choices[0].text) + + assert (len(tokens) > 0) + assert (len("".join(tokens)) > len(prompt)) + # def test_chat_completions(): # model = "gpt4all-j-v1.3-groovy" # prompt = "Who is Michael Jordan?" @@ -30,6 +49,3 @@ def test_completion(): # model=model, # messages=[] # ) - - - diff --git a/gpt4all-api/gpt4all_api/requirements.txt b/gpt4all-api/gpt4all_api/requirements.txt index 2d75043b..af33bdd8 100644 --- a/gpt4all-api/gpt4all_api/requirements.txt +++ b/gpt4all-api/gpt4all_api/requirements.txt @@ -5,6 +5,6 @@ requests>=2.24.0 ujson>=2.0.2 fastapi>=0.95.0 Jinja2>=3.0 -gpt4all==0.2.3 +gpt4all==1.0.1 pytest openai \ No newline at end of file