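# NOTE: the following lines are repository web-UI residue captured with the file; they are not part of the module.
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 169 lines
# 6.6 KiB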

import logging
import sys
import docx2txt
from PyPDF2 import PdfReader
from numpy import array, average
from flask import current_app
from config import *
from utils import get_embeddings, get_pinecone_id_for_file_chunk
# Set up logging
# NOTE(review): only the ``format=`` argument of this call survived in the
# file; the INFO level is an assumption (the module emits info-level progress
# messages) and any original handlers are unknown -- confirm against history.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
# Handle a file by extracting its text, creating embeddings, and upserting them to Pinecone
def handle_file(file, session_id, pinecone_index, tokenizer):
    """Handle a file by extracting its text, creating embeddings, and upserting them to Pinecone.

    Args:
        file: Uploaded file object. Its ``filename`` attribute is read here and
            the object itself is handed to ``extract_text_from_file``.
        session_id: Passed through unchanged to ``handle_file_string``.
        pinecone_index: Passed through unchanged to ``handle_file_string``.
        tokenizer: Passed through unchanged to ``handle_file_string``.

    Returns:
        Whatever ``handle_file_string`` returns.

    Raises:
        ValueError: Propagated from ``extract_text_from_file`` after being
            logged.
    """
    filename = file.filename
    logging.info("[handle_file] Handling file: {}".format(filename))

    # Get the file text dict from the current app config
    file_text_dict = current_app.config["file_text_dict"]

    # Extract text from the file
    try:
        extracted_text = extract_text_from_file(file)
    except ValueError as e:
        logging.error(
            "[handle_file] Error extracting text from file: {}".format(e))
        # Bare raise keeps the original traceback intact.
        raise

    # Save extracted text to file text dict
    file_text_dict[filename] = extracted_text

    # Handle the extracted text as a string
    return handle_file_string(
        filename, session_id, extracted_text, pinecone_index, tokenizer,
        file_text_dict)
# Extract text from a file based on its mimetype
def extract_text_from_file(file):
    """Return the text content of a file.

    Args:
        file: File-like object exposing a ``mimetype`` attribute. For plain
            text it must also provide ``read()`` returning bytes
            (presumably a Werkzeug ``FileStorage`` -- verify against caller).

    Returns:
        The extracted text as a ``str``.

    Raises:
        ValueError: If ``file.mimetype`` is not PDF, plain text, or DOCX.
    """
    if file.mimetype == "application/pdf":
        # Extract text from pdf using PyPDF2, one page at a time
        reader = PdfReader(file)
        extracted_text = "".join(
            page.extract_text() for page in reader.pages)
    elif file.mimetype == "text/plain":
        # Read text from plain text file
        extracted_text = file.read().decode("utf-8")
    elif file.mimetype == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        # Extract text from docx using docx2txt
        extracted_text = docx2txt.process(file)
    else:
        # Unsupported file type
        raise ValueError("Unsupported file type: {}".format(file.mimetype))

    return extracted_text
# Handle a file string by creating embeddings and upserting them to Pinecone
def handle_file_string(filename, session_id, file_body_string, pinecone_index, tokenizer, file_text_dict):
    """Handle a file string by creating embeddings and upserting them to Pinecone.

    Args:
        filename: Name of the source file; prepended to the embedded text and
            stored in each vector's metadata.
        session_id: Used to build chunk ids and as the upsert namespace.
        file_body_string: Raw text of the file.
        pinecone_index: Object whose ``upsert(vectors=..., namespace=...)``
            method receives each batch.
        tokenizer: Passed to ``create_embeddings_for_text``.
        file_text_dict: Mutable mapping; each chunk's text is stored in it
            under the chunk's id.

    Raises:
        Exception: Any error from embedding creation or from the upsert is
            logged and re-raised unchanged.
    """
    logging.info("[handle_file_string] Starting...")

    # Clean up the file string by replacing newlines and double spaces
    clean_file_body_string = file_body_string.replace(
        "\n", "; ").replace("  ", " ")

    # Add the filename to the text to embed
    text_to_embed = "Filename is: {}; {}".format(
        filename, clean_file_body_string)

    # Create embeddings for the text (the average embedding is not used here)
    try:
        text_embeddings, _ = create_embeddings_for_text(
            text_to_embed, tokenizer)
        logging.info(
            "[handle_file_string] Created embedding for {}".format(filename))
    except Exception as e:
        logging.error(
            "[handle_file_string] Error creating embedding: {}".format(e))
        raise

    # Get the vectors array of triples: file_chunk_id, embedding, metadata for each embedding
    # Metadata is a dict with keys: filename, file_chunk_index
    vectors = []
    for i, (text_chunk, embedding) in enumerate(text_embeddings):
        chunk_id = get_pinecone_id_for_file_chunk(session_id, filename, i)
        file_text_dict[chunk_id] = text_chunk
        vectors.append(
            (chunk_id, embedding, {"filename": filename, "file_chunk_index": i}))

        logging.info(
            "[handle_file_string] Text chunk {}: {}".format(i, text_chunk))

    # Split the vectors array into smaller batches of max length 2000
    # NOTE(review): the definition of batch_size was missing from the file;
    # 2000 is taken from the comment above. If the project config exposes a
    # constant for this, prefer it -- confirm.
    batch_size = 2000
    batches = [vectors[start:start + batch_size]
               for start in range(0, len(vectors), batch_size)]

    # Upsert each batch to Pinecone
    for batch in batches:
        try:
            pinecone_index.upsert(
                vectors=batch, namespace=session_id)
            logging.info(
                "[handle_file_string] Upserted batch of embeddings for {}".format(filename))
        except Exception as e:
            logging.error(
                "[handle_file_string] Error upserting batch of embeddings to Pinecone: {}".format(e))
            raise
# Compute the column-wise average of a list of lists
def get_col_average_from_list_of_lists(list_of_lists):
    """Return the average of each column in a list of lists.

    Args:
        list_of_lists: Sequence of equal-length numeric rows.

    Returns:
        A list holding the mean of each column. A single row is returned
        as-is (the same object, not a copy). Empty input returns ``[]``
        instead of NaN.
    """
    if len(list_of_lists) == 0:
        # Averaging zero rows would produce NaN plus a runtime warning.
        return []
    if len(list_of_lists) == 1:
        return list_of_lists[0]

    list_of_lists_array = array(list_of_lists)
    average_embedding = average(list_of_lists_array, axis=0)
    return average_embedding.tolist()
# Create embeddings for a text using a tokenizer and an OpenAI engine
def create_embeddings_for_text(text, tokenizer):
    """Embed a text chunk by chunk.

    The text is split into token chunks with ``chunks``, each chunk is decoded
    back to a string, and the strings are sent to ``get_embeddings`` in groups
    of at most ``MAX_TEXTS_TO_EMBED_BATCH_SIZE``.

    Returns:
        A tuple ``(text_embeddings, average_embedding)``: a list of
        ``(text_chunk, embedding)`` pairs, and the column-wise average of
        all embeddings.
    """
    tokenized_pieces = list(chunks(text, TEXT_EMBEDDING_CHUNK_SIZE, tokenizer))
    decoded_pieces = [tokenizer.decode(piece) for piece in tokenized_pieces]

    # Request embeddings one group at a time and gather the vectors in order
    vectors = []
    for start in range(0, len(decoded_pieces), MAX_TEXTS_TO_EMBED_BATCH_SIZE):
        group = decoded_pieces[start:start + MAX_TEXTS_TO_EMBED_BATCH_SIZE]
        response = get_embeddings(group, EMBEDDINGS_MODEL)
        for item in response:
            vectors.append(item["embedding"])

    paired = list(zip(decoded_pieces, vectors))
    mean_vector = get_col_average_from_list_of_lists(vectors)
    return (paired, mean_vector)
# Split a text into smaller chunks of size n, preferably ending at the end of a sentence
def chunks(text, n, tokenizer):
    """Yield successive token chunks of roughly n tokens from text.

    Each chunk is cut at the latest position between 0.5 * n and 1.5 * n
    tokens whose decoded text ends with a full stop or newline. If no such
    position exists, the chunk is exactly n tokens (or whatever remains).

    Args:
        text: String to split.
        n: Target chunk size in tokens; must be at least 1.
        tokenizer: Object providing ``encode(str)`` returning a token
            sequence and ``decode(tokens)`` returning a ``str``.

    Yields:
        Slices of the encoded token sequence.

    Raises:
        ValueError: If ``n`` is less than 1 (the loop could never advance).
    """
    if n < 1:
        raise ValueError("n must be a positive integer, got {!r}".format(n))

    tokens = tokenizer.encode(text)
    i = 0
    while i < len(tokens):
        # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens
        lower_bound = i + int(0.5 * n)
        j = min(i + int(1.5 * n), len(tokens))
        while j > lower_bound:
            # Decode the tokens and check for full stop or newline
            chunk = tokenizer.decode(tokens[i:j])
            if chunk.endswith((".", "\n")):
                # Sentence boundary found: stop shrinking the window here.
                break
            j -= 1
        # If no end of sentence found, use n tokens as the chunk size
        if j == lower_bound:
            j = min(i + n, len(tokens))
        yield tokens[i:j]
        i = j