From 889c8d175892f55966397e3d970277d9b6544198 Mon Sep 17 00:00:00 2001
From: David Okpare
Date: Thu, 10 Aug 2023 15:43:07 +0100
Subject: [PATCH] Add embeddings endpoint for gpt4all-api (#1314)

* Add embeddings endpoint

* Add test for embedding endpoint
---
 .../app/api_v1/routes/embeddings.py           | 65 +++++++++++++++++++
 .../gpt4all_api/app/tests/test_endpoints.py   | 14 ++++
 2 files changed, 79 insertions(+)
 create mode 100644 gpt4all-api/gpt4all_api/app/api_v1/routes/embeddings.py

diff --git a/gpt4all-api/gpt4all_api/app/api_v1/routes/embeddings.py b/gpt4all-api/gpt4all_api/app/api_v1/routes/embeddings.py
new file mode 100644
index 00000000..50a5590f
--- /dev/null
+++ b/gpt4all-api/gpt4all_api/app/api_v1/routes/embeddings.py
@@ -0,0 +1,65 @@
+from typing import List, Union
+from fastapi import APIRouter
+from api_v1.settings import settings
+from gpt4all import Embed4All
+from pydantic import BaseModel, Field
+
+### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
+
+
+class EmbeddingRequest(BaseModel):
+    model: str = Field(
+        settings.model, description="The model to generate an embedding from."
+    )
+    input: Union[str, List[str], List[int], List[List[int]]] = Field(
+        ..., description="Input text to embed, encoded as a string or array of tokens."
+    )
+
+
+class EmbeddingUsage(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+
+
+class Embedding(BaseModel):
+    index: int = 0
+    object: str = "embedding"
+    embedding: List[float]
+
+
+class EmbeddingResponse(BaseModel):
+    object: str = "list"
+    model: str
+    data: List[Embedding]
+    usage: EmbeddingUsage
+
+
+router = APIRouter(prefix="/embeddings", tags=["Embedding Endpoints"])
+
+embedder = Embed4All()
+
+
+def get_embedding(data: EmbeddingRequest) -> EmbeddingResponse:
+    """
+    Calculates the embedding for the given input using a specified model.
+
+    Args:
+        data (EmbeddingRequest): An EmbeddingRequest object containing the input data
+        and model name.
+
+    Returns:
+        EmbeddingResponse: An EmbeddingResponse object encapsulating the calculated embedding,
+        usage info, and the model name.
+    """
+    embedding = embedder.embed(data.input)
+    return EmbeddingResponse(
+        data=[Embedding(embedding=embedding)], usage=EmbeddingUsage(), model=data.model
+    )
+
+
+@router.post("/", response_model=EmbeddingResponse)
+def embeddings(data: EmbeddingRequest):
+    """
+    Creates a GPT4All embedding
+    """
+    return get_embedding(data)
diff --git a/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py b/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py
index f9315cb3..a7f3f13c 100644
--- a/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py
+++ b/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py
@@ -1,6 +1,8 @@
 """
 Use the OpenAI python API to test gpt4all models.
 """
+from typing import List, get_args
+
 import openai
 
 openai.api_base = "http://localhost:4891/v1"
@@ -43,3 +45,15 @@ def test_batched_completion():
     )
     assert len(response['choices'][0]['text']) > len(prompt)
     assert len(response['choices']) == 3
+
+
+def test_embedding():
+    model = "ggml-all-MiniLM-L6-v2-f16.bin"
+    prompt = "Who is Michael Jordan?"
+    response = openai.Embedding.create(model=model, input=prompt)
+    output = response["data"][0]["embedding"]
+    args = get_args(List[float])
+
+    assert response["model"] == model
+    assert isinstance(output, list)
+    assert all(isinstance(x, args) for x in output)