You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
32 lines
1.0 KiB
Python
32 lines
1.0 KiB
Python
import os
|
|
import openai
|
|
from tenacity import (
|
|
retry,
|
|
stop_after_attempt, # type: ignore
|
|
wait_random_exponential, # type: ignore
|
|
)
|
|
|
|
from typing import Optional, List, Union
|
|
|
|
# Configure the OpenAI client once at import time; the key stays None when
# the OPENAI_API_KEY environment variable is not set.
openai.api_key = os.environ.get('OPENAI_API_KEY')
|
|
|
|
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def _request_completion(prompt: Union[str, List[str]], max_tokens: int, stop_strs: Optional[List[str]]):
    """Issue a single Completion request, retried with random exponential backoff."""
    return openai.Completion.create(
        model='text-davinci-003',
        prompt=prompt,
        temperature=0.0,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop_strs,
    )


def get_completion(prompt: Union[str, List[str]], max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> Union[str, List[str]]:
    """Return the completion text for one prompt or for a batch of prompts.

    The request goes to the ``text-davinci-003`` model with temperature 0.0.
    Only the network call is retried (up to 6 attempts, random exponential
    wait bounded by min=1 / max=60); argument errors are raised immediately
    instead of being retried.

    Args:
        prompt: A single prompt string, or a list of prompt strings when
            ``is_batched`` is true.
        max_tokens: Value forwarded as ``max_tokens`` to the API.
        stop_strs: Optional stop sequences forwarded as ``stop``.
        is_batched: Whether ``prompt`` is a list of prompts.

    Returns:
        The text of the first choice for a single prompt, or a list with one
        text per prompt (placed by ``choice.index``) for a batch.

    Raises:
        TypeError: If the type of ``prompt`` does not match ``is_batched``.
        Whatever the retry wrapper raises once all attempts have failed
        (tenacity's ``RetryError`` by default, since ``reraise`` is not set).
    """
    # Validate outside the retried call: a caller mistake can never succeed on
    # a later attempt, and an explicit raise also survives ``python -O``.
    if is_batched:
        if not isinstance(prompt, list):
            raise TypeError("prompt must be a list of strings when is_batched is True")
    elif not isinstance(prompt, str):
        raise TypeError("prompt must be a string when is_batched is False")

    response = _request_completion(prompt, max_tokens, stop_strs)

    if not is_batched:
        return response.choices[0].text

    # Place each choice by its reported index rather than relying on the
    # order in which the choices are listed in the response.
    res: List[str] = [""] * len(prompt)
    for choice in response.choices:
        res[choice.index] = choice.text
    return res
|