Add chat endpoint support (#244)

* add chat endpoint support

* supplement comment
Herobs 2023-03-23 04:21:36 +08:00 committed by GitHub
parent 4ec8058ffc
commit afa9436334

@@ -98,6 +98,7 @@ import asyncio  # for running API calls concurrently
 import json  # for saving results to a jsonl file
 import logging  # for logging rate limit warnings and other messages
 import os  # for reading API key
+import re  # for matching endpoint from request URL
 import tiktoken  # for counting tokens
 import time  # for sleeping after rate limit is hit
 from dataclasses import dataclass  # for storing API inputs, outputs, and metadata
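Note: with chat support, a line in the input .jsonl file can now hold a chat-format request body. A minimal sketch of producing such a line (the model name, message values, and file name here are illustrative assumptions, not part of the commit):

    import json

    # hypothetical example: append one chat-format request per line to the
    # input file read by process_api_requests_from_file
    chat_request = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 50,
    }
    with open("requests.jsonl", "a") as f:
        f.write(json.dumps(chat_request) + "\n")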
@@ -138,7 +139,7 @@ async def process_api_requests_from_file(
     available_token_capacity = max_tokens_per_minute
     last_update_time = time.time()
 
-    # intialize flags
+    # initialize flags
     file_not_finished = True  # after file is empty, we'll skip reading it
     logging.debug(f"Initialization complete.")
@@ -151,13 +152,13 @@ async def process_api_requests_from_file(
         while True:
             # get next request (if one is not already waiting for capacity)
             if next_request is None:
-                if queue_of_requests_to_retry.empty() is False:
+                if not queue_of_requests_to_retry.empty():
                     next_request = queue_of_requests_to_retry.get_nowait()
                     logging.debug(f"Retrying request {next_request.task_id}: {next_request}")
                 elif file_not_finished:
                     try:
                         # get new request
-                        request_json = eval(next(requests))
+                        request_json = json.loads(next(requests))
                         next_request = APIRequest(
                             task_id=next(task_id_generator),
                             request_json=request_json,
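Swapping `eval` for `json.loads` closes an arbitrary-code-execution hole (any Python expression in the input file would have been executed) and also accepts strict JSON that `eval` rejects. A small sketch of the difference (the sample line is illustrative):

    import json

    line = '{"model": "gpt-3.5-turbo", "stream": false}'
    json.loads(line)  # parses fine; JSON false becomes Python False
    # eval(line)      # raises NameError: name 'false' is not defined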
@@ -199,7 +200,7 @@ async def process_api_requests_from_file(
                 # call API
                 asyncio.create_task(
-                    next_request.call_API(
+                    next_request.call_api(
                         request_url=request_url,
                         request_header=request_header,
                         retry_queue=queue_of_requests_to_retry,
@@ -259,7 +260,7 @@ class APIRequest:
     attempts_left: int
     result = []
 
-    async def call_API(
+    async def call_api(
         self,
         request_url: str,
         request_header: dict,
@@ -312,7 +313,8 @@
 def api_endpoint_from_url(request_url):
     """Extract the API endpoint from the request URL."""
-    return request_url.split("/")[-1]
+    match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url)
+    return match[1]
 
 
 def append_to_jsonl(data, filename: str) -> None:
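The previous `split("/")[-1]` kept only the last path segment, so a chat URL mapped to "completions" and lost the "chat/" prefix; the new regex captures everything after the version segment instead. A quick check of the new helper's behavior (the assertion inputs are illustrative):

    import re

    def api_endpoint_from_url(request_url):
        match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url)
        return match[1]

    assert api_endpoint_from_url("https://api.openai.com/v1/completions") == "completions"
    assert api_endpoint_from_url("https://api.openai.com/v1/chat/completions") == "chat/completions"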
@@ -330,11 +332,25 @@ def num_tokens_consumed_from_request(
     """Count the number of tokens in the request. Only supports completion and embedding requests."""
     encoding = tiktoken.get_encoding(token_encoding_name)
     # if completions request, tokens = prompt + n * max_tokens
-    if api_endpoint == "completions":
-        prompt = request_json["prompt"]
+    if api_endpoint.endswith("completions"):
         max_tokens = request_json.get("max_tokens", 15)
         n = request_json.get("n", 1)
         completion_tokens = n * max_tokens
-        if isinstance(prompt, str):  # single prompt
-            prompt_tokens = len(encoding.encode(prompt))
-            num_tokens = prompt_tokens + completion_tokens
+
+        # chat completions
+        if api_endpoint.startswith("chat/"):
+            num_tokens = 0
+            for message in request_json["messages"]:
+                num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
+                for key, value in message.items():
+                    num_tokens += len(encoding.encode(value))
+                    if key == "name":  # if there's a name, the role is omitted
+                        num_tokens -= 1  # role is always required and always 1 token
+            num_tokens += 2  # every reply is primed with <im_start>assistant
+            return num_tokens + completion_tokens
+        # normal completions
+        else:
+            prompt = request_json["prompt"]
+            if isinstance(prompt, str):  # single prompt
+                prompt_tokens = len(encoding.encode(prompt))
+                num_tokens = prompt_tokens + completion_tokens
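The chat-token arithmetic (4 wrapper tokens per message, 2 to prime the assistant's reply, minus 1 when a name field replaces the role) mirrors the gpt-3.5-turbo-0301 message format, and the result is an estimate for rate limiting rather than an exact billed count. A worked example under those assumptions (the encoding choice and request values are illustrative):

    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")
    request = {"messages": [{"role": "user", "content": "Hello world"}],
               "max_tokens": 50, "n": 1}

    num_tokens = 0
    for message in request["messages"]:
        num_tokens += 4  # per-message wrapper tokens
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
    num_tokens += 2  # reply primed with <im_start>assistant
    # "user" encodes to 1 token and "Hello world" to 2, so message tokens
    # total 4 + 1 + 2 + 2 = 9, plus n * max_tokens = 50 reserved for the reply
    print(num_tokens + request["n"] * request["max_tokens"])  # -> 59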