Mirror of https://github.com/openai/openai-cookbook, synced 2024-11-04 06:00:33 +00:00
Add chat endpoint support (#244)

* add chat endpoint support
* supplement comment

parent 3f6c086e95
commit deed48a1e3
@@ -98,6 +98,7 @@ import asyncio  # for running API calls concurrently
 import json  # for saving results to a jsonl file
 import logging  # for logging rate limit warnings and other messages
 import os  # for reading API key
+import re  # for matching endpoint from request URL
 import tiktoken  # for counting tokens
 import time  # for sleeping after rate limit is hit
 from dataclasses import dataclass  # for storing API inputs, outputs, and metadata
@@ -138,7 +139,7 @@ async def process_api_requests_from_file(
     available_token_capacity = max_tokens_per_minute
     last_update_time = time.time()
 
-    # intialize flags
+    # initialize flags
     file_not_finished = True  # after file is empty, we'll skip reading it
     logging.debug(f"Initialization complete.")
 
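Note: available_token_capacity and last_update_time above implement a token bucket; elsewhere in the script, capacity refills in proportion to the time elapsed since the last update, capped at the per-minute limit. A minimal sketch of that replenishment step, reusing the same variable names (the refill code itself is outside this hunk, so treat this as an illustration, not the script verbatim):

    import time

    max_tokens_per_minute = 150_000  # example limit, not taken from this diff
    available_token_capacity = max_tokens_per_minute
    last_update_time = time.time()

    def replenish_token_capacity():
        """Refill the bucket proportionally to elapsed time, capped at the limit."""
        global available_token_capacity, last_update_time
        now = time.time()
        seconds_since_update = now - last_update_time
        available_token_capacity = min(
            available_token_capacity + max_tokens_per_minute * seconds_since_update / 60.0,
            max_tokens_per_minute,
        )
        last_update_time = now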
@@ -151,13 +152,13 @@ async def process_api_requests_from_file(
         while True:
             # get next request (if one is not already waiting for capacity)
             if next_request is None:
-                if queue_of_requests_to_retry.empty() is False:
+                if not queue_of_requests_to_retry.empty():
                     next_request = queue_of_requests_to_retry.get_nowait()
                     logging.debug(f"Retrying request {next_request.task_id}: {next_request}")
                 elif file_not_finished:
                     try:
                         # get new request
-                        request_json = eval(next(requests))
+                        request_json = json.loads(next(requests))
                         next_request = APIRequest(
                             task_id=next(task_id_generator),
                             request_json=request_json,
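The eval-to-json.loads switch here is a correctness and safety fix, not just style: eval executes whatever Python expression a line of the requests file contains, while json.loads only parses JSON data. A minimal sketch of the safe pattern (the filename is illustrative):

    import json

    # each line of the requests file is one JSON object, e.g.
    # {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hi"}]}
    with open("requests.jsonl") as f:
        for line in f:
            request_json = json.loads(line)  # parses data only; cannot run code
            print(request_json["model"])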
@@ -199,7 +200,7 @@ async def process_api_requests_from_file(
 
                     # call API
                     asyncio.create_task(
-                        next_request.call_API(
+                        next_request.call_api(
                             request_url=request_url,
                             request_header=request_header,
                             retry_queue=queue_of_requests_to_retry,
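The renamed coroutine is still dispatched with asyncio.create_task, which schedules the API call without awaiting it, so the loop can keep reading requests while earlier calls are in flight. A stripped-down sketch of that dispatch pattern (the worker body is a stand-in, not the script's real call_api):

    import asyncio

    async def call_api(task_id: int) -> None:
        # stand-in for the real HTTP request; simulates network latency
        await asyncio.sleep(0.1)
        print(f"task {task_id} done")

    async def main() -> None:
        # fire-and-forget: create_task returns immediately
        tasks = [asyncio.create_task(call_api(i)) for i in range(3)]
        await asyncio.gather(*tasks)  # keep the loop alive until all tasks finish

    asyncio.run(main())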
@@ -259,7 +260,7 @@ class APIRequest:
     attempts_left: int
     result = []
 
-    async def call_API(
+    async def call_api(
         self,
         request_url: str,
         request_header: dict,
@@ -312,7 +313,8 @@ class APIRequest:
 
 def api_endpoint_from_url(request_url):
     """Extract the API endpoint from the request URL."""
-    return request_url.split("/")[-1]
+    match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url)
+    return match[1]
 
 
 def append_to_jsonl(data, filename: str) -> None:
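This regex change is what makes chat support work downstream: the old request_url.split("/")[-1] maps both https://api.openai.com/v1/completions and https://api.openai.com/v1/chat/completions to "completions", losing the "chat/" prefix that the token counter branches on. A quick check of the new behavior, using the function exactly as added above:

    import re

    def api_endpoint_from_url(request_url):
        """Extract the API endpoint from the request URL."""
        match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url)
        return match[1]

    print(api_endpoint_from_url("https://api.openai.com/v1/completions"))       # completions
    print(api_endpoint_from_url("https://api.openai.com/v1/chat/completions"))  # chat/completions
    print(api_endpoint_from_url("https://api.openai.com/v1/embeddings"))        # embeddings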
@@ -330,21 +332,35 @@ def num_tokens_consumed_from_request(
     """Count the number of tokens in the request. Only supports completion and embedding requests."""
     encoding = tiktoken.get_encoding(token_encoding_name)
     # if completions request, tokens = prompt + n * max_tokens
-    if api_endpoint == "completions":
-        prompt = request_json["prompt"]
+    if api_endpoint.endswith("completions"):
         max_tokens = request_json.get("max_tokens", 15)
         n = request_json.get("n", 1)
         completion_tokens = n * max_tokens
-        if isinstance(prompt, str):  # single prompt
-            prompt_tokens = len(encoding.encode(prompt))
-            num_tokens = prompt_tokens + completion_tokens
-            return num_tokens
-        elif isinstance(prompt, list):  # multiple prompts
-            prompt_tokens = sum([len(encoding.encode(p)) for p in prompt])
-            num_tokens = prompt_tokens + completion_tokens * len(prompt)
-            return num_tokens
+
+        # chat completions
+        if api_endpoint.startswith("chat/"):
+            num_tokens = 0
+            for message in request_json["messages"]:
+                num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
+                for key, value in message.items():
+                    num_tokens += len(encoding.encode(value))
+                    if key == "name":  # if there's a name, the role is omitted
+                        num_tokens -= 1  # role is always required and always 1 token
+            num_tokens += 2  # every reply is primed with <im_start>assistant
+            return num_tokens + completion_tokens
+        # normal completions
         else:
-            raise TypeError('Expecting either string or list of strings for "prompt" field in completion request')
+            prompt = request_json["prompt"]
+            if isinstance(prompt, str):  # single prompt
+                prompt_tokens = len(encoding.encode(prompt))
+                num_tokens = prompt_tokens + completion_tokens
+                return num_tokens
+            elif isinstance(prompt, list):  # multiple prompts
+                prompt_tokens = sum([len(encoding.encode(p)) for p in prompt])
+                num_tokens = prompt_tokens + completion_tokens * len(prompt)
+                return num_tokens
+            else:
+                raise TypeError('Expecting either string or list of strings for "prompt" field in completion request')
     # if embeddings request, tokens = input tokens
     elif api_endpoint == "embeddings":
         input = request_json["input"]
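The chat-branch constants (4 tokens of overhead per message, 2 to prime the assistant's reply, minus 1 when a name field replaces the role) follow OpenAI's ChatML-era counting heuristic for gpt-3.5-turbo; they are approximations and can drift for other models. A self-contained sketch of just that branch, assuming the cl100k_base encoding:

    import tiktoken

    def num_tokens_from_chat_request(request_json: dict, encoding_name: str = "cl100k_base") -> int:
        """Estimate prompt + completion tokens for a chat request (heuristic)."""
        encoding = tiktoken.get_encoding(encoding_name)
        completion_tokens = request_json.get("n", 1) * request_json.get("max_tokens", 15)
        num_tokens = 0
        for message in request_json["messages"]:
            num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens -= 1  # role is always required and always 1 token
        num_tokens += 2  # every reply is primed with <im_start>assistant
        return num_tokens + completion_tokens

    example = {"messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 20}
    print(num_tokens_from_chat_request(example))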