diff --git a/examples/api_request_parallel_processor.py b/examples/api_request_parallel_processor.py
index 7a156154..2ce6a135 100644
--- a/examples/api_request_parallel_processor.py
+++ b/examples/api_request_parallel_processor.py
@@ -101,7 +101,10 @@ import os  # for reading API key
 import re  # for matching endpoint from request URL
 import tiktoken  # for counting tokens
 import time  # for sleeping after rate limit is hit
-from dataclasses import dataclass, field  # for storing API inputs, outputs, and metadata
+from dataclasses import (
+    dataclass,
+    field,
+)  # for storing API inputs, outputs, and metadata
 
 
 async def process_api_requests_from_file(
@@ -118,7 +121,9 @@ async def process_api_requests_from_file(
     """Processes API requests in parallel, throttling to stay under rate limits."""
     # constants
     seconds_to_pause_after_rate_limit_error = 15
-    seconds_to_sleep_each_loop = 0.001  # 1 ms limits max throughput to 1,000 requests per second
+    seconds_to_sleep_each_loop = (
+        0.001  # 1 ms limits max throughput to 1,000 requests per second
+    )
 
     # initialize logging
     logging.basicConfig(level=logging_level)
@@ -130,8 +135,12 @@ async def process_api_requests_from_file(
 
     # initialize trackers
     queue_of_requests_to_retry = asyncio.Queue()
-    task_id_generator = task_id_generator_function()  # generates integer IDs of 1, 2, 3, ...
-    status_tracker = StatusTracker()  # single instance to track a collection of variables
+    task_id_generator = (
+        task_id_generator_function()
+    )  # generates integer IDs of 1, 2, 3, ...
+    status_tracker = (
+        StatusTracker()
+    )  # single instance to track a collection of variables
     next_request = None  # variable to hold the next request to call
 
     # initialize available capacity counts
@@ -148,90 +157,115 @@ async def process_api_requests_from_file(
         # `requests` will provide requests one at a time
         requests = file.__iter__()
         logging.debug(f"File opened. Entering main loop")
Entering main loop") - - while True: - # get next request (if one is not already waiting for capacity) - if next_request is None: - if not queue_of_requests_to_retry.empty(): - next_request = queue_of_requests_to_retry.get_nowait() - logging.debug(f"Retrying request {next_request.task_id}: {next_request}") - elif file_not_finished: - try: - # get new request - request_json = json.loads(next(requests)) - next_request = APIRequest( - task_id=next(task_id_generator), - request_json=request_json, - token_consumption=num_tokens_consumed_from_request(request_json, api_endpoint, token_encoding_name), - attempts_left=max_attempts, - metadata=request_json.pop("metadata", None) + async with aiohttp.ClientSession() as session: # Initialize ClientSession here + while True: + # get next request (if one is not already waiting for capacity) + if next_request is None: + if not queue_of_requests_to_retry.empty(): + next_request = queue_of_requests_to_retry.get_nowait() + logging.debug( + f"Retrying request {next_request.task_id}: {next_request}" ) - status_tracker.num_tasks_started += 1 - status_tracker.num_tasks_in_progress += 1 - logging.debug(f"Reading request {next_request.task_id}: {next_request}") - except StopIteration: - # if file runs out, set flag to stop reading it - logging.debug("Read file exhausted") - file_not_finished = False - - # update available capacity - current_time = time.time() - seconds_since_update = current_time - last_update_time - available_request_capacity = min( - available_request_capacity + max_requests_per_minute * seconds_since_update / 60.0, - max_requests_per_minute, - ) - available_token_capacity = min( - available_token_capacity + max_tokens_per_minute * seconds_since_update / 60.0, - max_tokens_per_minute, - ) - last_update_time = current_time - - # if enough capacity available, call API - if next_request: - next_request_tokens = next_request.token_consumption - if ( - available_request_capacity >= 1 - and available_token_capacity >= next_request_tokens - ): - # update counters - available_request_capacity -= 1 - available_token_capacity -= next_request_tokens - next_request.attempts_left -= 1 - - # call API - asyncio.create_task( - next_request.call_api( - request_url=request_url, - request_header=request_header, - retry_queue=queue_of_requests_to_retry, - save_filepath=save_filepath, - status_tracker=status_tracker, + elif file_not_finished: + try: + # get new request + request_json = json.loads(next(requests)) + next_request = APIRequest( + task_id=next(task_id_generator), + request_json=request_json, + token_consumption=num_tokens_consumed_from_request( + request_json, api_endpoint, token_encoding_name + ), + attempts_left=max_attempts, + metadata=request_json.pop("metadata", None), + ) + status_tracker.num_tasks_started += 1 + status_tracker.num_tasks_in_progress += 1 + logging.debug( + f"Reading request {next_request.task_id}: {next_request}" + ) + except StopIteration: + # if file runs out, set flag to stop reading it + logging.debug("Read file exhausted") + file_not_finished = False + + # update available capacity + current_time = time.time() + seconds_since_update = current_time - last_update_time + available_request_capacity = min( + available_request_capacity + + max_requests_per_minute * seconds_since_update / 60.0, + max_requests_per_minute, + ) + available_token_capacity = min( + available_token_capacity + + max_tokens_per_minute * seconds_since_update / 60.0, + max_tokens_per_minute, + ) + last_update_time = current_time + + # if enough capacity 
+                if next_request:
+                    next_request_tokens = next_request.token_consumption
+                    if (
+                        available_request_capacity >= 1
+                        and available_token_capacity >= next_request_tokens
+                    ):
+                        # update counters
+                        available_request_capacity -= 1
+                        available_token_capacity -= next_request_tokens
+                        next_request.attempts_left -= 1
+
+                        # call API
+                        asyncio.create_task(
+                            next_request.call_api(
+                                session=session,
+                                request_url=request_url,
+                                request_header=request_header,
+                                retry_queue=queue_of_requests_to_retry,
+                                save_filepath=save_filepath,
+                                status_tracker=status_tracker,
+                            )
                         )
-                    )
-                    next_request = None  # reset next_request to empty
+                        next_request = None  # reset next_request to empty
 
-            # if all tasks are finished, break
-            if status_tracker.num_tasks_in_progress == 0:
-                break
+                # if all tasks are finished, break
+                if status_tracker.num_tasks_in_progress == 0:
+                    break
 
-            # main loop sleeps briefly so concurrent tasks can run
-            await asyncio.sleep(seconds_to_sleep_each_loop)
+                # main loop sleeps briefly so concurrent tasks can run
+                await asyncio.sleep(seconds_to_sleep_each_loop)
 
-            # if a rate limit error was hit recently, pause to cool down
-            seconds_since_rate_limit_error = (time.time() - status_tracker.time_of_last_rate_limit_error)
-            if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
-                remaining_seconds_to_pause = (seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error)
-                await asyncio.sleep(remaining_seconds_to_pause)
-                # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
-                logging.warn(f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}")
+                # if a rate limit error was hit recently, pause to cool down
+                seconds_since_rate_limit_error = (
+                    time.time() - status_tracker.time_of_last_rate_limit_error
+                )
+                if (
+                    seconds_since_rate_limit_error
+                    < seconds_to_pause_after_rate_limit_error
+                ):
+                    remaining_seconds_to_pause = (
+                        seconds_to_pause_after_rate_limit_error
+                        - seconds_since_rate_limit_error
+                    )
+                    await asyncio.sleep(remaining_seconds_to_pause)
+                    # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
+                    logging.warn(
+                        f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
+                    )
 
     # after finishing, log final status
-    logging.info(f"""Parallel processing complete. Results saved to {save_filepath}""")
+    logging.info(
+        f"""Parallel processing complete. Results saved to {save_filepath}"""
+    )
     if status_tracker.num_tasks_failed > 0:
-        logging.warning(f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}.")
+        logging.warning(
+            f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
+        )
     if status_tracker.num_rate_limit_errors > 0:
-        logging.warning(f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate.")
+        logging.warning(
+            f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
+        )
 
 
 # dataclasses
@@ -264,6 +298,7 @@ class APIRequest:
 
     async def call_api(
         self,
+        session: aiohttp.ClientSession,
         request_url: str,
         request_header: dict,
         retry_queue: asyncio.Queue,
@@ -274,11 +309,10 @@ class APIRequest:
         logging.info(f"Starting request #{self.task_id}")
         error = None
         try:
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    url=request_url, headers=request_header, json=self.request_json
-                ) as response:
-                    response = await response.json()
+            async with session.post(
+                url=request_url, headers=request_header, json=self.request_json
+            ) as response:
+                response = await response.json()
             if "error" in response:
                 logging.warning(
                     f"Request {self.task_id} failed with error {response['error']}"
@@ -288,9 +322,13 @@ class APIRequest:
                 if "Rate limit" in response["error"].get("message", ""):
                     status_tracker.time_of_last_rate_limit_error = time.time()
                     status_tracker.num_rate_limit_errors += 1
-                    status_tracker.num_api_errors -= 1  # rate limit errors are counted separately
+                    status_tracker.num_api_errors -= (
+                        1  # rate limit errors are counted separately
+                    )
 
-        except Exception as e:  # catching naked exceptions is bad practice, but in this case we'll log & save them
+        except (
+            Exception
+        ) as e:  # catching naked exceptions is bad practice, but in this case we'll log & save them
             logging.warning(f"Request {self.task_id} failed with Exception {e}")
             status_tracker.num_other_errors += 1
             error = e
@@ -299,7 +337,9 @@
             if self.attempts_left:
                 retry_queue.put_nowait(self)
             else:
-                logging.error(f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}")
+                logging.error(
+                    f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
+                )
                 data = (
                     [self.request_json, [str(e) for e in self.result], self.metadata]
                     if self.metadata
@@ -325,7 +365,7 @@ class APIRequest:
 
 def api_endpoint_from_url(request_url):
     """Extract the API endpoint from the request URL."""
-    match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url)
+    match = re.search("^https://[^/]+/v\\d+/(.+)$", request_url)
     return match[1]
 
 
@@ -372,7 +412,9 @@ def num_tokens_consumed_from_request(
             num_tokens = prompt_tokens + completion_tokens * len(prompt)
             return num_tokens
         else:
-            raise TypeError('Expecting either string or list of strings for "prompt" field in completion request')
+            raise TypeError(
+                'Expecting either string or list of strings for "prompt" field in completion request'
+            )
     # if embeddings request, tokens = input tokens
     elif api_endpoint == "embeddings":
         input = request_json["input"]
@@ -383,10 +425,14 @@ def num_tokens_consumed_from_request(
             num_tokens = sum([len(encoding.encode(i)) for i in input])
             return num_tokens
         else:
-            raise TypeError('Expecting either string or list of strings for "inputs" field in embedding request')
+            raise TypeError(
+                'Expecting either string or list of strings for "inputs" field in embedding request'
+            )
     # more logic needed to support other API calls (e.g., edits, inserts, DALL-E)
     else:
-        raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script')
+        raise NotImplementedError(
+            f'API endpoint "{api_endpoint}" not implemented in this script'
+        )
 
 
 def task_id_generator_function():