diff --git a/examples/api_request_parallel_processor.py b/examples/api_request_parallel_processor.py index 8b206a48..cc6cd47d 100644 --- a/examples/api_request_parallel_processor.py +++ b/examples/api_request_parallel_processor.py @@ -98,6 +98,7 @@ import asyncio # for running API calls concurrently import json # for saving results to a jsonl file import logging # for logging rate limit warnings and other messages import os # for reading API key +import re # for matching endpoint from request URL import tiktoken # for counting tokens import time # for sleeping after rate limit is hit from dataclasses import dataclass # for storing API inputs, outputs, and metadata @@ -138,7 +139,7 @@ async def process_api_requests_from_file( available_token_capacity = max_tokens_per_minute last_update_time = time.time() - # intialize flags + # initialize flags file_not_finished = True # after file is empty, we'll skip reading it logging.debug(f"Initialization complete.") @@ -151,13 +152,13 @@ async def process_api_requests_from_file( while True: # get next request (if one is not already waiting for capacity) if next_request is None: - if queue_of_requests_to_retry.empty() is False: + if not queue_of_requests_to_retry.empty(): next_request = queue_of_requests_to_retry.get_nowait() logging.debug(f"Retrying request {next_request.task_id}: {next_request}") elif file_not_finished: try: # get new request - request_json = eval(next(requests)) + request_json = json.loads(next(requests)) next_request = APIRequest( task_id=next(task_id_generator), request_json=request_json, @@ -199,7 +200,7 @@ async def process_api_requests_from_file( # call API asyncio.create_task( - next_request.call_API( + next_request.call_api( request_url=request_url, request_header=request_header, retry_queue=queue_of_requests_to_retry, @@ -259,7 +260,7 @@ class APIRequest: attempts_left: int result = [] - async def call_API( + async def call_api( self, request_url: str, request_header: dict, @@ -312,7 +313,8 @@ class APIRequest: def api_endpoint_from_url(request_url): """Extract the API endpoint from the request URL.""" - return request_url.split("/")[-1] + match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url) + return match[1] def append_to_jsonl(data, filename: str) -> None: @@ -330,21 +332,35 @@ def num_tokens_consumed_from_request( """Count the number of tokens in the request. Only supports completion and embedding requests.""" encoding = tiktoken.get_encoding(token_encoding_name) # if completions request, tokens = prompt + n * max_tokens - if api_endpoint == "completions": - prompt = request_json["prompt"] + if api_endpoint.endswith("completions"): max_tokens = request_json.get("max_tokens", 15) n = request_json.get("n", 1) completion_tokens = n * max_tokens - if isinstance(prompt, str): # single prompt - prompt_tokens = len(encoding.encode(prompt)) - num_tokens = prompt_tokens + completion_tokens - return num_tokens - elif isinstance(prompt, list): # multiple prompts - prompt_tokens = sum([len(encoding.encode(p)) for p in prompt]) - num_tokens = prompt_tokens + completion_tokens * len(prompt) - return num_tokens + + # chat completions + if api_endpoint.startswith("chat/"): + num_tokens = 0 + for message in request_json["messages"]: + num_tokens += 4 # every message follows {role/name}\n{content}\n + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": # if there's a name, the role is omitted + num_tokens -= 1 # role is always required and always 1 token + num_tokens += 2 # every reply is primed with assistant + return num_tokens + completion_tokens + # normal completions else: - raise TypeError('Expecting either string or list of strings for "prompt" field in completion request') + prompt = request_json["prompt"] + if isinstance(prompt, str): # single prompt + prompt_tokens = len(encoding.encode(prompt)) + num_tokens = prompt_tokens + completion_tokens + return num_tokens + elif isinstance(prompt, list): # multiple prompts + prompt_tokens = sum([len(encoding.encode(p)) for p in prompt]) + num_tokens = prompt_tokens + completion_tokens * len(prompt) + return num_tokens + else: + raise TypeError('Expecting either string or list of strings for "prompt" field in completion request') # if embeddings request, tokens = input tokens elif api_endpoint == "embeddings": input = request_json["input"]