import itertools
from collections import defaultdict

from transformers import GPT2TokenizerFast

import openai

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

MAX_TOKENS_LIMIT = 2048


def create_instruction(labels) -> str:
    """
    Construct an instruction for a classification task.
    """
    instruction = f"Please classify a piece of text into the following categories: {', '.join(labels)}."

    return f"{instruction.strip()}\n\n"


def semantic_search(
    search_model, query_for_search, file_id=None, max_documents=None, examples=None
):
    """
    :param examples: A list of {"text": ...} or {"text": ..., "label": ...}.
    :return:
        a list of semantic search result dicts of documents sorted by "score":
        [
            {
                "document": ...,
                "object": "search_result",
                "score": ...,
                "text": ...,
            },
            ...
        ]
    """
    assert (examples is None) ^ (file_id is None)  # xor

    if file_id is not None:
        # This is where you'd do an elastic search call. Since there isn't an
        # example of this we can query, we'll raise an error.
        # The return value from this would be a list of examples.
        raise NotImplementedError()

    # This isn't quite accurate since Search is also being deprecated. See our
    # search guide for more information.
    search_result = openai.Search.create(
        model=search_model,
        documents=[x["text"] for x in examples],
        query=query_for_search,
    )

    info_dict = {d["document"]: d for d in search_result["data"]}
    sorted_doc_ids = sorted(
        info_dict.keys(), key=lambda x: info_dict[x]["score"], reverse=True
    )
    if max_documents:
        sorted_doc_ids = sorted_doc_ids[:max_documents]
    return [info_dict[i] for i in sorted_doc_ids]
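
# Since `openai.Search` is deprecated, the call above can be approximated with
# the Embeddings API. The sketch below is an illustration under assumptions,
# not part of the original script: it assumes the same pre-v1 `openai` client,
# an embeddings model such as "text-embedding-ada-002", and that numpy is
# installed. Scores here are cosine similarities, not Search API scores, so
# they are only comparable to each other.
def semantic_search_with_embeddings(
    embedding_model, query_for_search, examples, max_documents=None
):
    import numpy as np  # local import to keep this sketch self-contained

    # Embed the query and all documents in one request.
    resp = openai.Embedding.create(
        model=embedding_model,
        input=[query_for_search] + [x["text"] for x in examples],
    )
    # Sort by "index" so embeddings line up with the inputs.
    data = sorted(resp["data"], key=lambda d: d["index"])
    vectors = np.array([d["embedding"] for d in data])
    query_vec, doc_vecs = vectors[0], vectors[1:]

    # Cosine similarity between the query and each document.
    scores = (doc_vecs @ query_vec) / (
        np.linalg.norm(doc_vecs, axis=1) * np.linalg.norm(query_vec)
    )
    results = [
        {
            "document": i,
            "object": "search_result",
            "score": float(score),
            "text": examples[i]["text"],
        }
        for i, score in enumerate(scores)
    ]
    results.sort(key=lambda d: d["score"], reverse=True)
    return results[:max_documents] if max_documents else results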
def select_by_length(
    sorted_doc_infos,
    max_token_len,
    lambda_fn=None,
):
    """
    Given a list of (document ID, document content) entries, select as many
    documents as possible as long as the total token length does not go above
    `max_token_len`.

    :param sorted_doc_infos: A list of semantic search result dicts of documents sorted by "score".
    :param max_token_len: The maximum token length for selected documents.
    :param lambda_fn: A function that takes in a search result dict and outputs a formatted
        example for context stuffing.
    :return: A tuple of (
        a concatenation of selected documents used as context,
        a list of selected document info dicts
    )
    """
    if not sorted_doc_infos:
        return "", []

    selected_indices = []
    total_doc_tokens = 0
    doc_dict = {}
    for i, doc_info in enumerate(sorted_doc_infos):
        doc = lambda_fn(doc_info) if lambda_fn else doc_info["text"]
        n_doc_tokens = len(tokenizer.encode(doc))
        if total_doc_tokens + n_doc_tokens < max_token_len:
            total_doc_tokens += n_doc_tokens
            selected_indices.append(i)
            doc_dict[i] = doc

    # The top-ranked documents should go at the end.
    selected_indices = selected_indices[::-1]

    context = "".join([doc_dict[i] for i in selected_indices])
    selected_doc_infos = [sorted_doc_infos[i] for i in selected_indices]
    return context, selected_doc_infos


def format_example_fn(x: dict) -> str:
    return "Text: {text}\nCategory: {label}\n---\n".format(
        text=x["text"].replace("\n", " ").strip(),
        label=x["label"].replace("\n", " ").strip(),
    )


def classifications(
    query,
    model,
    search_model="ada",
    examples=None,
    file=None,
    labels=None,
    temperature=0.0,
    logprobs=None,
    max_examples=200,
    logit_bias=None,
    alternative_query=None,
    max_tokens=16,
) -> dict:
    """
    Given a query and a list of examples containing (text, label) pairs,
    select the top relevant examples to construct a prompt for few-shot
    classification.

    The constructed prompt for the final completion call:
    ```
    {{ an optional instruction }}

    Text: example 1 text
    Category: example 1 label
    ---
    Text: example 2 text
    Category: example 2 label
    ---
    Text: question
    Category:
    ```

    The returned object has a structure like:
    {
        "label": "Happy",
        "model": "ada",
        "object": "classification",
        "selected_examples": [
            {
                "document": ...,  # document index, same as in /search results.
                "text": ...,
                "label": ...,
            },
            ...
        ],
    }
    """
    query = query.replace("\n", " ").strip()
    logit_bias = logit_bias if logit_bias else {}
    labels = labels if labels else []

    if file is None and examples is None:
        raise Exception("Please submit at least one of `examples` or `file`.")
    if file is not None and examples is not None:
        raise Exception("Please submit only one of `examples` or `file`.")

    instruction = create_instruction(labels)

    query_for_search = alternative_query if alternative_query is not None else query

    # Extract examples and example labels first.
    if file is not None:
        sorted_doc_infos = semantic_search(
            search_model,
            query_for_search,
            file_id=file,
            max_documents=max_examples,
        )
    else:
        example_prompts = [
            format_example_fn(dict(text=x, label=y)) for x, y in examples
        ]
        n_examples_tokens = [len(tokenizer.encode(x)) for x in example_prompts]

    query_prompt = f"Text: {query}\nCategory:"
    n_instruction_tokens = len(tokenizer.encode(instruction))
    n_query_tokens = len(tokenizer.encode(query_prompt))

    # Apart from the required content, how many tokens are left for context stuffing.
    leftover_token_len = MAX_TOKENS_LIMIT - (
        n_instruction_tokens + n_query_tokens + max_tokens
    )

    # Process when `examples` are provided but no `file` is provided.
    if examples:
        if (max_examples is None or max_examples >= len(examples)) and sum(
            n_examples_tokens
        ) < leftover_token_len:
            # If the total length of docs is short enough that we can add all
            # examples, no search call.
            selected_indices = list(range(len(examples)))
            sorted_doc_infos = [
                {"document": i, "text": examples[i][0], "label": examples[i][1]}
                for i in selected_indices
            ]

        elif max(n_examples_tokens) + n_query_tokens >= MAX_TOKENS_LIMIT:
            # If the prompt and the longest example together go above the limit:
            total_tokens = max(n_examples_tokens) + n_query_tokens
            raise Exception(
                f"The longest classification example, query and prompt together contain "
                f"{total_tokens} tokens, above the limit {MAX_TOKENS_LIMIT} for semantic search. "
                f"Please consider shortening your instruction, query or the longest example."
            )

        else:
            # If we can add some context documents but not all of them, we should
            # query the search endpoint to rank docs by score.
            sorted_doc_infos = semantic_search(
                search_model,
                query_for_search,
                examples=[{"text": x, "label": y} for x, y in examples],
                max_documents=max_examples,
            )
            # Search results only carry the document text, so re-attach each
            # example's label via its document index for the round robin below.
            for d in sorted_doc_infos:
                d["label"] = examples[d["document"]][1]

    # Per label, we have a list of doc ids sorted by relevancy to the query.
    label_to_indices = defaultdict(list)
    for idx, d in enumerate(sorted_doc_infos):
        label_to_indices[d["label"]].append(idx)

    # Do a round robin over the different labels, taking the best match for each label.
    label_indices = [label_to_indices[label] for label in labels]
    mixed_indices = [
        i for x in itertools.zip_longest(*label_indices) for i in x if i is not None
    ]
    sorted_doc_infos = [sorted_doc_infos[i] for i in mixed_indices]
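
    # An illustrative aside (not part of the pipeline): with
    # label_indices = [[0, 2], [1]], itertools.zip_longest([0, 2], [1])
    # yields (0, 1), (2, None), which flattens (dropping None) to [0, 1, 2],
    # so the best match for every label sits near the front even if the
    # context is truncated later.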
    # Try to select as many examples as needed to fit into the context.
    context, sorted_doc_infos = select_by_length(
        sorted_doc_infos,
        leftover_token_len,
        lambda_fn=format_example_fn,
    )

    prompt = instruction + context + query_prompt

    completion_params = {
        "engine": model,
        "prompt": prompt,
        "temperature": temperature,
        "logprobs": logprobs,
        "logit_bias": logit_bias,
        "max_tokens": max_tokens,
        "stop": "\n",
        "n": 1,
    }

    completion_resp = openai.Completion.create(
        **completion_params,
    )

    label = completion_resp["choices"][0]["text"]
    label = label.split("\n")[0].strip().lower().capitalize()
    if label not in labels:
        label = "Unknown"

    result = dict(
        # TODO: Add id for object persistence.
        object="classification",
        model=completion_resp["model"],
        label=label,
        completion=completion_resp["id"],
    )

    result["selected_examples"] = sorted_doc_infos

    return result


print(
    classifications(
        query="this is my test",
        model="davinci",
        search_model="ada",
        examples=[
            ["this is my test", "davinci"],
            ["this is other test", "blahblah"],
        ],
        file=None,
        labels=["davinci", "blahblah"],
        temperature=0.1,
        logprobs=0,
        max_examples=200,
        logit_bias=None,
        alternative_query="different test",
        max_tokens=16,
    )
)
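
# Note on the demo above: with only two short examples and max_examples=200,
# the total example length fits within the leftover token budget, so both
# examples are stuffed into the prompt directly and no Search call is made.
# Also, the completion text is normalized with .lower().capitalize() before
# the membership check against `labels`, so all-lowercase label lists like
# this one can only come back as "Unknown"; capitalized labels avoid that.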