Mirror of https://github.com/openai/openai-cookbook, synced 2024-11-15 18:13:18 +00:00
Add chat endpoint support (#244)
* add chat endpoint support
* supplement comment
This commit is contained in:
parent 4ec8058ffc
commit afa9436334
@@ -98,6 +98,7 @@ import asyncio  # for running API calls concurrently
 import json  # for saving results to a jsonl file
 import logging  # for logging rate limit warnings and other messages
 import os  # for reading API key
+import re  # for matching endpoint from request URL
 import tiktoken  # for counting tokens
 import time  # for sleeping after rate limit is hit
 from dataclasses import dataclass  # for storing API inputs, outputs, and metadata
@@ -138,7 +139,7 @@ async def process_api_requests_from_file(
     available_token_capacity = max_tokens_per_minute
     last_update_time = time.time()

-    # intialize flags
+    # initialize flags
     file_not_finished = True  # after file is empty, we'll skip reading it
     logging.debug(f"Initialization complete.")

@@ -151,13 +152,13 @@ async def process_api_requests_from_file(
         while True:
             # get next request (if one is not already waiting for capacity)
             if next_request is None:
-                if queue_of_requests_to_retry.empty() is False:
+                if not queue_of_requests_to_retry.empty():
                     next_request = queue_of_requests_to_retry.get_nowait()
                     logging.debug(f"Retrying request {next_request.task_id}: {next_request}")
                 elif file_not_finished:
                     try:
                         # get new request
-                        request_json = eval(next(requests))
+                        request_json = json.loads(next(requests))
                         next_request = APIRequest(
                             task_id=next(task_id_generator),
                             request_json=request_json,
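Aside from the retry-queue check reading more idiomatically, the eval-to-json.loads switch is a real fix: eval executes whatever Python expression appears in the file, and it rejects JSON-only literals such as true, false, and null. A minimal sketch of parsing one line of the requests file (the payload shown is hypothetical):

import json

# one line of the requests .jsonl file (hypothetical payload)
line = '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}'

# json.loads parses data without executing it; eval(next(requests)) would run
# arbitrary Python embedded in the file and fail on literals like null
request_json = json.loads(line)
assert request_json["model"] == "gpt-3.5-turbo"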
@@ -199,7 +200,7 @@ async def process_api_requests_from_file(

                     # call API
                     asyncio.create_task(
-                        next_request.call_API(
+                        next_request.call_api(
                             request_url=request_url,
                             request_header=request_header,
                             retry_queue=queue_of_requests_to_retry,
@@ -259,7 +260,7 @@ class APIRequest:
     attempts_left: int
     result = []

-    async def call_API(
+    async def call_api(
         self,
         request_url: str,
         request_header: dict,
@@ -312,7 +313,8 @@ class APIRequest:

 def api_endpoint_from_url(request_url):
     """Extract the API endpoint from the request URL."""
-    return request_url.split("/")[-1]
+    match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url)
+    return match[1]


 def append_to_jsonl(data, filename: str) -> None:
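This regex change is what lets the processor tell chat requests apart: the old split("/")[-1] maps both .../v1/completions and .../v1/chat/completions to "completions", while capturing everything after the version segment keeps the "chat/" prefix. A standalone copy of the new function, checked against illustrative URLs:

import re

def api_endpoint_from_url(request_url):
    """Extract the API endpoint from the request URL."""
    match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url)
    return match[1]

# the capture group preserves the "chat/" prefix that split("/")[-1] dropped
assert api_endpoint_from_url("https://api.openai.com/v1/completions") == "completions"
assert api_endpoint_from_url("https://api.openai.com/v1/chat/completions") == "chat/completions"
assert api_endpoint_from_url("https://api.openai.com/v1/embeddings") == "embeddings"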
@@ -330,11 +332,25 @@ def num_tokens_consumed_from_request(
     """Count the number of tokens in the request. Only supports completion and embedding requests."""
     encoding = tiktoken.get_encoding(token_encoding_name)
     # if completions request, tokens = prompt + n * max_tokens
-    if api_endpoint == "completions":
-        prompt = request_json["prompt"]
+    if api_endpoint.endswith("completions"):
         max_tokens = request_json.get("max_tokens", 15)
         n = request_json.get("n", 1)
         completion_tokens = n * max_tokens
+
+        # chat completions
+        if api_endpoint.startswith("chat/"):
+            num_tokens = 0
+            for message in request_json["messages"]:
+                num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
+                for key, value in message.items():
+                    num_tokens += len(encoding.encode(value))
+                    if key == "name":  # if there's a name, the role is omitted
+                        num_tokens -= 1  # role is always required and always 1 token
+            num_tokens += 2  # every reply is primed with <im_start>assistant
+            return num_tokens + completion_tokens
+        # normal completions
+        else:
+            prompt = request_json["prompt"]
         if isinstance(prompt, str):  # single prompt
             prompt_tokens = len(encoding.encode(prompt))
             num_tokens = prompt_tokens + completion_tokens
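Two details of this hunk are easy to miss: endswith("completions") matches both "completions" and "chat/completions", and the constants (4 tokens of overhead per message, 2 tokens priming the reply, minus 1 when a name replaces the role) are the heuristic documented in the code's own comments. A self-contained sketch of the same tally for a hypothetical chat request, assuming the cl100k_base encoding used by gpt-3.5-turbo:

import tiktoken

# hypothetical chat request, as one line of the requests file might specify
request_json = {
    "model": "gpt-3.5-turbo",
    "max_tokens": 100,
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
    ],
}

encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for message in request_json["messages"]:
    num_tokens += 4  # per-message overhead: <im_start>{role/name}\n{content}<im_end>\n
    for key, value in message.items():
        num_tokens += len(encoding.encode(value))
        if key == "name":  # a name replaces the role in the count
            num_tokens -= 1
num_tokens += 2  # every reply is primed with <im_start>assistant

# total budgeted tokens = prompt tokens + n * max_tokens
completion_tokens = request_json.get("n", 1) * request_json["max_tokens"]
print(num_tokens + completion_tokens)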