Add OpenAI-compatible API

pull/535/head
Roman Ageev 7 months ago
parent ae19b65095
commit 113e57b11d

@@ -0,0 +1,85 @@
import argparse
import dataclasses
import uvicorn
from typing import Any
from fastapi import FastAPI, Request
from transformers import GenerationConfig, AutoTokenizer
from hivemind import get_logger
from petals.utils.auto_config import AutoDistributedModelForCausalLM
from petals.data_structures import FromPretrainedInfo
TIMEOUT_KEEP_ALIVE = 10 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI()
model = None
tokenizer = None
logger = get_logger(__name__)
@app.post("/generate")
async def generate(request: Request) -> dict[str, Any]:
"""Generate completion for the request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the generation parameters (See `GenerationConfig` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
tokens = tokenizer(prompt, return_tensors="pt")["input_ids"]
sampling_params = GenerationConfig(**request_dict)
results_generator = model.generate(tokens, **sampling_params.to_diff_dict())
if stream:
# handle streaming case
raise NotImplementedError("Streaming is not supported yet")
# Non-streaming case
assert results_generator is not None
logger.debug(f"Generated {len(results_generator)} results")
return {"response": tokenizer.decode(results_generator[0])}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
# arguments for from_pretrained method
parser.add_argument("--model_name_or_path", type=str)
parser.add_argument("--low_cpu_mem_usage", type=bool, default=True)
parser.add_argument("--torch_dtype", type=str, default="auto")
parser.add_argument("--args", type=tuple, default=tuple())
args = parser.parse_args()
pretrained_info = FromPretrainedInfo(
model_name_or_path=args.model_name_or_path,
low_cpu_mem_usage=args.low_cpu_mem_usage,
torch_dtype=args.torch_dtype,
)
model = AutoDistributedModelForCausalLM.from_pretrained(*args.args, **dataclasses.asdict(pretrained_info))
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False, add_bos_token=False)
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE
)
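For reference, the /generate endpoint above can be exercised with a few lines of client code. A minimal sketch, assuming the server is already running on the default localhost:8000; max_new_tokens is just one example of a GenerationConfig field that can be forwarded:

import requests

# Minimal client sketch for the /generate endpoint (assumes localhost:8000, the default --port).
payload = {
    "prompt": "A cat sat on a mat",
    "stream": False,       # streaming is not implemented yet and raises NotImplementedError
    "max_new_tokens": 8,   # any other GenerationConfig field can be passed the same way
}
resp = requests.post("http://localhost:8000/generate", json=payload)
resp.raise_for_status()
print(resp.json()["response"])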

@@ -0,0 +1,669 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
import argparse
import asyncio
import dataclasses
import json
import time
from http import HTTPStatus
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
import fastapi
import torch
import uvicorn
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response
from packaging import version
from hivemind import get_logger
from transformers import GenerationConfig, AutoTokenizer
from petals import AutoDistributedModelForCausalLM
from petals.api.openai.protocol import (
CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo, random_uuid
)
from petals.data_structures import FromPretrainedInfo
try:
import fastchat
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
_fastchat_available = True
except ImportError:
_fastchat_available = False
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = get_logger(__name__)
app = fastapi.FastAPI()
model = None
tokenizer = None
def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(ErrorResponse(message=message,
type="invalid_request_error").dict(),
status_code=status_code.value)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
async def check_model(request) -> Optional[JSONResponse]:
if request.model == served_model:
return
ret = create_error_response(
HTTPStatus.NOT_FOUND,
f"The model `{request.model}` does not exist.",
)
return ret
async def get_gen_prompt(request) -> str:
if not _fastchat_available:
raise ModuleNotFoundError(
"fastchat is not installed. Please install fastchat to use "
"the chat completion and conversation APIs: `$ pip install fschat`"
)
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
raise ImportError(
f"fastchat version is low. Current version: {fastchat.__version__} "
"Please upgrade fastchat to use: `$ pip install -U fschat`")
conv = get_conversation_template(request.model)
conv = Conversation(
name=conv.name,
system_template=conv.system_template,
system_message=conv.system_message,
roles=conv.roles,
messages=list(conv.messages),
offset=conv.offset,
sep_style=SeparatorStyle(conv.sep_style),
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
stop_token_ids=conv.stop_token_ids,
)
if isinstance(request.messages, str):
prompt = request.messages
else:
for message in request.messages:
msg_role = message["role"]
if msg_role == "system":
conv.system_message = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
else:
raise ValueError(f"Unknown role: {msg_role}")
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt
async def check_length(
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[JSONResponse]]:
assert (not (prompt is None and prompt_ids is None)
and not (prompt is not None and prompt_ids is not None)
), "Either prompt or prompt_ids should be provided."
if prompt_ids is not None:
input_ids = prompt_ids
else:
input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids)
if request.max_tokens is None:
request.max_tokens = max_model_len - token_num
if token_num + request.max_tokens > max_model_len:
return input_ids, create_error_response(
HTTPStatus.BAD_REQUEST,
f"This model's maximum context length is {max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.",
)
else:
return input_ids, None
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.get("/v1/models")
async def show_available_models():
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=served_model,
root=served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards)
def create_logprobs(token_ids: List[int],
id_logprobs: List[Dict[int, float]],
initial_text_offset: int = 0) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
for token_id, id_logprob in zip(token_ids, id_logprobs):
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(id_logprob[token_id])
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
logprobs.top_logprobs.append({
tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()
})
return logprobs
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by Petals engine)
"""
logger.info(f"Received chat completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in Petals engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
prompt = await get_gen_prompt(request)
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())  # the OpenAI "created" field is a Unix timestamp
try:
skip_special_tokens = request.skip_special_tokens
spaces_between_special_tokens = request.spaces_between_special_tokens
if request.ignore_eos:
logger.warning("ignore_eos is not supported yet")
if request.presence_penalty > 0.0:
logger.warning("presence_penalty is not supported yet")
generation_config = GenerationConfig(
num_return_sequences=request.n,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
do_sample=request.temperature is not None or request.top_p is not None or request.top_k is not None,
num_beams=request.best_of if request.use_beam_search else 1,
early_stopping=request.stop,
eos_token_id=request.stop_token_ids,
repetition_penalty=1.0 + request.presence_penalty if request.presence_penalty else 1.0,
# output_scores=True, # Example: set to True if you want to output prediction scores
# return_dict_in_generate is left at its default (False): the decoding loop below expects a plain tensor of sequences
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
results_generator = model.generate(torch.tensor([token_ids]), **generation_config.to_diff_dict())  # generate() expects a batched tensor of token ids
# def create_stream_response_json(
# index: int,
# text: str,
# finish_reason: Optional[str] = None,
# ) -> str:
# choice_data = ChatCompletionResponseStreamChoice(
# index=index,
# delta=DeltaMessage(content=text),
# finish_reason=finish_reason,
# )
# response = ChatCompletionStreamResponse(
# id=request_id,
# created=created_time,
# model=model_name,
# choices=[choice_data],
# )
# response_json = response.json(ensure_ascii=False)
#
# return response_json
#
# async def completion_stream_generator() -> AsyncGenerator[str, None]:
# # First chunk with role
# for i in range(request.n):
# choice_data = ChatCompletionResponseStreamChoice(
# index=i,
# delta=DeltaMessage(role="assistant"),
# finish_reason=None,
# )
# chunk = ChatCompletionStreamResponse(id=request_id,
# choices=[choice_data],
# model=model_name)
# data = chunk.json(exclude_unset=True, ensure_ascii=False)
# yield f"data: {data}\n\n"
#
# previous_texts = [""] * request.n
# previous_num_tokens = [0] * request.n
# async for res in result_generator:
# res: RequestOutput
# for output in res.outputs:
# i = output.index
# delta_text = output.text[len(previous_texts[i]):]
# previous_texts[i] = output.text
# previous_num_tokens[i] = len(output.token_ids)
# response_json = create_stream_response_json(
# index=i,
# text=delta_text,
# )
# yield f"data: {response_json}\n\n"
# if output.finish_reason is not None:
# response_json = create_stream_response_json(
# index=i,
# text="",
# finish_reason=output.finish_reason,
# )
# yield f"data: {response_json}\n\n"
# yield "data: [DONE]\n\n"
# Streaming response
if request.stream:
logger.warning("Streaming is not supported yet")
# return StreamingResponse(completion_stream_generator(),
# media_type="text/event-stream")
# Non-streaming response
# async for res in result_generator:
# if await raw_request.is_disconnected():
# # Abort the request if the client disconnects.
# # TODO: Abort the request if the client disconnects.
# return create_error_response(HTTPStatus.BAD_REQUEST,
# "Client disconnected")
# final_res = res
choices = []
for index, output_tokens in enumerate(results_generator):
text = tokenizer.decode(
output_tokens,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
choice_data = ChatCompletionResponseChoice(
index=index,
message=ChatMessage(role="assistant", content=text),
finish_reason=None,
)
choices.append(choice_data)
num_prompt_tokens = len(token_ids)
num_generated_tokens = sum(
len(output) for output in results_generator)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When the user requests streaming but we do not actually stream, we still
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
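As a client-side illustration (not part of the server module above), a chat completion request against this route might look like the sketch below; it assumes the server is reachable at localhost:8000 and that the "model" field matches the served model name, since check_model() returns 404 otherwise. The model name shown is only an example.

import requests

# Minimal /v1/chat/completions client sketch (assumes localhost:8000; replace the model name with the served one).
payload = {
    "model": "petals-team/StableBeluga2",  # example; must equal the server's served_model
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    "max_tokens": 16,
    "temperature": 0.7,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])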
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
- echo (since the Petals engine does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported by Petals engine)
- stream (to be supported by Petals engine)
- logprobs (to be supported by Petals engine)
"""
logger.info(f"Received completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.echo:
# We do not support echo since the Petals engine does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in Petals.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
use_token_ids = False
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
return create_error_response(HTTPStatus.BAD_REQUEST,
"please provide at least one prompt")
first_element = request.prompt[0]
if isinstance(first_element, int):
use_token_ids = True
prompt = request.prompt
elif isinstance(first_element, (str, list)):
# TODO: handle the multiple-prompt case (List[List[int]])
if len(request.prompt) > 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
"multiple prompts in a batch is not currently supported")
use_token_ids = not isinstance(first_element, str)
prompt = request.prompt[0]
else:
prompt = request.prompt
if use_token_ids:
token_ids, error_check_ret = await check_length(request, prompt_ids=prompt)
else:
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
created_time = int(time.time())  # the OpenAI "created" field is a Unix timestamp
try:
skip_special_tokens = request.skip_special_tokens
spaces_between_special_tokens = request.spaces_between_special_tokens
if request.ignore_eos:
logger.warning("ignore_eos is not supported yet")
if request.presence_penalty > 0.0:
logger.warning("presence_penalty is not supported yet")
generation_config = GenerationConfig(
num_return_sequences=request.n,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
do_sample=request.temperature is not None or request.top_p is not None or request.top_k is not None,
num_beams=request.best_of if request.use_beam_search else 1,
early_stopping=request.stop,
eos_token_id=request.stop_token_ids,
repetition_penalty=1.0 + request.presence_penalty if request.presence_penalty else 1.0,
# output_scores=True, # Example: set to True if you want to output prediction scores
# return_dict_in_generate is left at its default (False): the decoding loop below expects a plain tensor of sequences
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
results_generator = model.generate(torch.tensor([token_ids]), **generation_config.to_diff_dict())  # generate() expects a batched tensor of token ids
if request.stream:
# handle streaming case
logger.warning("Streaming is not supported yet")
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
# stream = (request.stream
# and (request.best_of is None or request.n == request.best_of)
# and not request.use_beam_search)
#
# def create_stream_response_json(
# index: int,
# text: str,
# logprobs: Optional[LogProbs] = None,
# finish_reason: Optional[str] = None,
# ) -> str:
# choice_data = CompletionResponseStreamChoice(
# index=index,
# text=text,
# logprobs=logprobs,
# finish_reason=finish_reason,
# )
# response = CompletionStreamResponse(
# id=request_id,
# created=created_time,
# model=model_name,
# choices=[choice_data],
# )
# response_json = response.json(ensure_ascii=False)
#
# return response_json
#
# async def completion_stream_generator() -> AsyncGenerator[str, None]:
# previous_texts = [""] * request.n
# previous_num_tokens = [0] * request.n
# async for res in result_generator:
# res: RequestOutput
# for output in res.outputs:
# i = output.index
# delta_text = output.text[len(previous_texts[i]):]
# if request.logprobs is not None:
# logprobs = create_logprobs(
# output.token_ids[previous_num_tokens[i]:],
# output.logprobs[previous_num_tokens[i]:],
# len(previous_texts[i]))
# else:
# logprobs = None
# previous_texts[i] = output.text
# previous_num_tokens[i] = len(output.token_ids)
# response_json = create_stream_response_json(
# index=i,
# text=delta_text,
# logprobs=logprobs,
# )
# yield f"data: {response_json}\n\n"
# if output.finish_reason is not None:
# logprobs = (LogProbs()
# if request.logprobs is not None else None)
# response_json = create_stream_response_json(
# index=i,
# text="",
# logprobs=logprobs,
# finish_reason=output.finish_reason,
# )
# yield f"data: {response_json}\n\n"
# yield "data: [DONE]\n\n"
#
# # Streaming response
# if stream:
# return StreamingResponse(completion_stream_generator(),
# media_type="text/event-stream")
# Non-streaming response
# async for res in result_generator:
# if await raw_request.is_disconnected():
# # TODO: Abort the request if the client disconnects.
# return create_error_response(HTTPStatus.BAD_REQUEST,
# "Client disconnected")
# final_res = res
choices = []
for index, output_tokens in enumerate(results_generator):
if request.logprobs is not None:
# logprobs = create_logprobs(output.token_ids, output.logprobs)
logger.warning("logprobs is not supported yet")
text = tokenizer.decode(
output_tokens,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
logprobs = None
choice_data = CompletionResponseChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=None,
)
choices.append(choice_data)
num_prompt_tokens = len(token_ids)
num_generated_tokens = sum(
len(output) for output in results_generator)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When the user requests streaming but we do not actually stream, we still
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
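A matching client-side sketch for the plain completions route, under the same assumptions (local server, model name equal to the served model):

import requests

# Minimal /v1/completions client sketch (assumes localhost:8000; replace the model name with the served one).
payload = {
    "model": "petals-team/StableBeluga2",  # example; must equal the server's served_model
    "prompt": "A cat sat on a mat",
    "max_tokens": 16,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])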
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Petals OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser.add_argument("--model_name_or_path", type=str)
parser.add_argument("--low_cpu_mem_usage", type=bool, default=True)
parser.add_argument("--torch_dtype", type=str, default="auto")
parser.add_argument("--args", type=tuple, default=tuple())
args = parser.parse_args()
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
logger.info(f"args: {args}")
if args.served_model_name is not None:
served_model = args.served_model_name
else:
served_model = args.model_name_or_path
pretrained_info = FromPretrainedInfo(
model_name_or_path=args.model_name_or_path,  # load from the actual checkpoint path; served_model is only the API-facing name
low_cpu_mem_usage=args.low_cpu_mem_usage,
torch_dtype=args.torch_dtype,
)
model = AutoDistributedModelForCausalLM.from_pretrained(*args.args, **dataclasses.asdict(pretrained_info))
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False, add_bos_token=False)
max_model_len = model.config.max_position_embeddings
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE
)
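Because the routes and payloads mirror the OpenAI REST API, the official openai Python package can also be pointed at this server once it is started via the __main__ block above (e.g. python <this script> --model_name_or_path <model>). A sketch using the openai>=1.0-style client; the base URL, API key value, and model name below are assumptions for illustration:

from openai import OpenAI

# Point the standard OpenAI client at the local Petals server; the api_key is not checked by this server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="petals-team/StableBeluga2",  # example; must equal the served model name
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=16,
)
print(chat.choices[0].message.content)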

@@ -0,0 +1,182 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py
import time
import uuid
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
def random_uuid() -> str:
return uuid.uuid4().hex
class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: Optional[str] = None
class ModelPermission(BaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
allow_create_engine: bool = False
allow_sampling: bool = True
allow_logprobs: bool = True
allow_search_indices: bool = False
allow_view: bool = True
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
is_blocking: bool = False
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "vllm"
root: Optional[str] = None
parent: Optional[str] = None
permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by vLLM
best_of: Optional[int] = None
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
class CompletionRequest(BaseModel):
model: str
prompt: Union[List[int], List[List[int]], str, List[str]]
suffix: Optional[str] = None
max_tokens: Optional[int] = 16
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
n: Optional[int] = 1
stream: Optional[bool] = False
logprobs: Optional[int] = None
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by vLLM
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
class CompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
class CompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
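These models are used by the server with pydantic v1-style calls (.dict(), .json(ensure_ascii=False)), so a request body can be validated and re-serialized as in the minimal sketch below (assumes pydantic v1; the model name is an example value):

# Validate an incoming completion payload and inspect the defaults.
req = CompletionRequest.parse_obj({
    "model": "petals-team/StableBeluga2",  # example value
    "prompt": "A cat sat on a mat",
    "max_tokens": 8,
})
print(req.temperature)                                    # -> 1.0 (the default)
print(req.json(exclude_unset=True, ensure_ascii=False))   # only the fields that were set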

@@ -1,8 +1,10 @@
import dataclasses
import os
from enum import Enum
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import pydantic
import torch
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
@@ -115,3 +117,10 @@ class InferenceMetadata:
prefix_length: int
cache_handles: Tuple[Handle, ...]
active_adapter: Optional[str]
@dataclasses.dataclass(frozen=True)
class FromPretrainedInfo:
model_name_or_path: Union[str, os.PathLike, None]
low_cpu_mem_usage: Optional[bool] = None
torch_dtype: Optional[Union[str, torch.dtype]] = None
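The new dataclass is consumed by both API servers in this PR: the from_pretrained keyword arguments are collected in one frozen object and then unpacked with dataclasses.asdict. A minimal sketch of that pattern (the model name is an example, not prescribed by this change):

import dataclasses
from petals import AutoDistributedModelForCausalLM
from petals.data_structures import FromPretrainedInfo

info = FromPretrainedInfo(
    model_name_or_path="petals-team/StableBeluga2",  # example model
    low_cpu_mem_usage=True,
    torch_dtype="auto",
)
model = AutoDistributedModelForCausalLM.from_pretrained(**dataclasses.asdict(info))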

@@ -0,0 +1,15 @@
import requests
def test_api():
request = {
"prompt": "A cat sat on a mat",
"stream": False,
"max_new_tokens": 4,
}
response = requests.post("http://localhost:8000/generate", json=request)
assert response.status_code == 200
response.raise_for_status()
assert "response" in response.json(), "Response should contain a 'response' field"
assert isinstance(response.json()["response"], str), "Response should be a string"
assert len(response.json()["response"]) > len(request["prompt"]), "Response should be longer than the prompt"