Add chat endpoint support (#244)

* add chat endpoint support

* supplement comment
pull/278/head
Herobs 1 year ago committed by GitHub
parent 4ec8058ffc
commit afa9436334
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -98,6 +98,7 @@ import asyncio # for running API calls concurrently
import json # for saving results to a jsonl file
import logging # for logging rate limit warnings and other messages
import os # for reading API key
import re # for matching endpoint from request URL
import tiktoken # for counting tokens
import time # for sleeping after rate limit is hit
from dataclasses import dataclass # for storing API inputs, outputs, and metadata
@ -138,7 +139,7 @@ async def process_api_requests_from_file(
available_token_capacity = max_tokens_per_minute
last_update_time = time.time()
# intialize flags
# initialize flags
file_not_finished = True # after file is empty, we'll skip reading it
logging.debug(f"Initialization complete.")
@ -151,13 +152,13 @@ async def process_api_requests_from_file(
while True:
# get next request (if one is not already waiting for capacity)
if next_request is None:
if queue_of_requests_to_retry.empty() is False:
if not queue_of_requests_to_retry.empty():
next_request = queue_of_requests_to_retry.get_nowait()
logging.debug(f"Retrying request {next_request.task_id}: {next_request}")
elif file_not_finished:
try:
# get new request
request_json = eval(next(requests))
request_json = json.loads(next(requests))
next_request = APIRequest(
task_id=next(task_id_generator),
request_json=request_json,
@ -199,7 +200,7 @@ async def process_api_requests_from_file(
# call API
asyncio.create_task(
next_request.call_API(
next_request.call_api(
request_url=request_url,
request_header=request_header,
retry_queue=queue_of_requests_to_retry,
@ -259,7 +260,7 @@ class APIRequest:
attempts_left: int
result = []
async def call_API(
async def call_api(
self,
request_url: str,
request_header: dict,
@ -312,7 +313,8 @@ class APIRequest:
def api_endpoint_from_url(request_url):
"""Extract the API endpoint from the request URL."""
return request_url.split("/")[-1]
match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url)
return match[1]
def append_to_jsonl(data, filename: str) -> None:
@ -330,21 +332,35 @@ def num_tokens_consumed_from_request(
"""Count the number of tokens in the request. Only supports completion and embedding requests."""
encoding = tiktoken.get_encoding(token_encoding_name)
# if completions request, tokens = prompt + n * max_tokens
if api_endpoint == "completions":
prompt = request_json["prompt"]
if api_endpoint.endswith("completions"):
max_tokens = request_json.get("max_tokens", 15)
n = request_json.get("n", 1)
completion_tokens = n * max_tokens
if isinstance(prompt, str): # single prompt
prompt_tokens = len(encoding.encode(prompt))
num_tokens = prompt_tokens + completion_tokens
return num_tokens
elif isinstance(prompt, list): # multiple prompts
prompt_tokens = sum([len(encoding.encode(p)) for p in prompt])
num_tokens = prompt_tokens + completion_tokens * len(prompt)
return num_tokens
# chat completions
if api_endpoint.startswith("chat/"):
num_tokens = 0
for message in request_json["messages"]:
num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name": # if there's a name, the role is omitted
num_tokens -= 1 # role is always required and always 1 token
num_tokens += 2 # every reply is primed with <im_start>assistant
return num_tokens + completion_tokens
# normal completions
else:
raise TypeError('Expecting either string or list of strings for "prompt" field in completion request')
prompt = request_json["prompt"]
if isinstance(prompt, str): # single prompt
prompt_tokens = len(encoding.encode(prompt))
num_tokens = prompt_tokens + completion_tokens
return num_tokens
elif isinstance(prompt, list): # multiple prompts
prompt_tokens = sum([len(encoding.encode(p)) for p in prompt])
num_tokens = prompt_tokens + completion_tokens * len(prompt)
return num_tokens
else:
raise TypeError('Expecting either string or list of strings for "prompt" field in completion request')
# if embeddings request, tokens = input tokens
elif api_endpoint == "embeddings":
input = request_json["input"]

Loading…
Cancel
Save