|
|
|
@ -2,7 +2,7 @@ import json
|
|
|
|
|
from typing import List, Dict, Iterable, AsyncIterable
|
|
|
|
|
import logging
|
|
|
|
|
import time
|
|
|
|
|
from typing import Dict, List, Union
|
|
|
|
|
from typing import Dict, List, Union, Optional
|
|
|
|
|
from uuid import uuid4
|
|
|
|
|
import aiohttp
|
|
|
|
|
import asyncio
|
|
|
|
@ -24,8 +24,8 @@ class CompletionRequest(BaseModel):
|
|
|
|
|
prompt: Union[List[str], str] = Field(..., description='The prompt to begin completing from.')
|
|
|
|
|
max_tokens: int = Field(None, description='Max tokens to generate')
|
|
|
|
|
temperature: float = Field(settings.temp, description='Model temperature')
|
|
|
|
|
top_p: float = Field(settings.top_p, description='top_p')
|
|
|
|
|
top_k: int = Field(settings.top_k, description='top_k')
|
|
|
|
|
top_p: Optional[float] = Field(settings.top_p, description='top_p')
|
|
|
|
|
top_k: Optional[int] = Field(settings.top_k, description='top_k')
|
|
|
|
|
n: int = Field(1, description='How many completions to generate for each prompt')
|
|
|
|
|
stream: bool = Field(False, description='Stream responses')
|
|
|
|
|
repeat_penalty: float = Field(settings.repeat_penalty, description='Repeat penalty')
|
|
|
|
|