Adding embedding support using huggingface to app.py

pull/920/head
chatgpt-tricks 11 months ago committed by GitHub
parent c5691c5993
commit 596e1d899f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,7 @@ import random
import string import string
import time import time
from typing import Any from typing import Any
import requests
from flask import Flask, request from flask import Flask, request
from flask_cors import CORS from flask_cors import CORS
@ -88,6 +88,70 @@ def chat_completions():
return app.response_class(streaming(), mimetype="text/event-stream") return app.response_class(streaming(), mimetype="text/event-stream")
def get_embedding(input_text, token):
    """Return a single embedding for *input_text* from the Hugging Face Inference API.

    The text is tokenized, split into chunks of at most ``max_token_length``
    tokens, each chunk is sent to the hosted feature-extraction pipeline, and
    the per-chunk vectors are averaged element-wise.

    Args:
        input_text: The text to embed.
        token: Hugging Face API token, sent as a ``Bearer`` credential.

    Returns:
        The element-wise mean of the chunk embeddings as a list, or an empty
        list when *input_text* produces no tokens.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status.
        requests.Timeout: If the API does not answer within the timeout.
        ValueError: If the API answers with something other than a JSON list.
    """
    # NOTE(review): AutoTokenizer must already be in module scope (presumably
    # `from transformers import AutoTokenizer`); that import is not visible in
    # this part of the file -- confirm it exists.
    embedding_model = "sentence-transformers/all-mpnet-base-v2"
    max_token_length = 500
    # Generous because "wait_for_model" lets the API block while the model loads.
    request_timeout_seconds = 120

    # Identical for every chunk, so build them once instead of per iteration.
    api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{embedding_model}"
    headers = {"Authorization": f"Bearer {token}"}

    # Tokenize the text and split the tokens into fixed-size chunks.
    tokenizer = AutoTokenizer.from_pretrained(embedding_model)
    tokens = tokenizer.tokenize(input_text)
    token_chunks = [
        tokens[i:i + max_token_length]
        for i in range(0, len(tokens), max_token_length)
    ]

    embeddings = []
    for chunk in token_chunks:
        # Convert the chunk tokens back to text; newlines are flattened to spaces.
        chunk_text = tokenizer.convert_tokens_to_string(chunk).replace("\n", " ")
        response = requests.post(
            api_url,
            headers=headers,
            json={"inputs": chunk_text, "options": {"wait_for_model": True}},
            timeout=request_timeout_seconds,
        )
        # Fail here with a clear HTTP error instead of letting an error payload
        # reach the averaging step and blow up with an unrelated TypeError.
        response.raise_for_status()
        chunk_embedding = response.json()
        if not isinstance(chunk_embedding, list):
            raise ValueError(
                f"Unexpected response from embedding API: {chunk_embedding!r}"
            )
        embeddings.append(chunk_embedding)

    # No tokens means no chunks and therefore nothing to average.
    if not embeddings:
        return []

    # Plain unweighted mean over the chunk embeddings. A short trailing chunk
    # counts as much as a full one, so this is only a rough approximation.
    # TODO: consider weighting each chunk by its token count.
    num_embeddings = len(embeddings)
    return [sum(values) / num_embeddings for values in zip(*embeddings)]
@app.route("/embeddings", methods=["POST"])
def embeddings():
    """Handle ``POST /embeddings`` and answer in the OpenAI embeddings response shape.

    The JSON body must contain ``input``, either a single string or a list
    whose items are joined with spaces into one text. The ``Authorization``
    header carries the token that is passed on to ``get_embedding``.

    Returns:
        A dict with one embedding entry on success, or an ``(error dict,
        status)`` tuple with status 400 (bad or missing ``input``) or 401
        (missing ``Authorization`` header).
    """
    payload = request.get_json(silent=True)
    input_value = payload.get("input") if isinstance(payload, dict) else None

    if isinstance(input_value, str):
        # A bare string must be used as-is; joining it would iterate over its
        # characters and turn "hello" into "h e l l o".
        input_text = input_value
    elif isinstance(input_value, (list, tuple)):
        input_text = " ".join(map(str, input_value))
    else:
        return {
            "error": {
                "message": "'input' must be a string or a list of strings.",
                "type": "invalid_request_error",
            }
        }, 400

    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return {
            "error": {
                "message": "Missing Authorization header.",
                "type": "invalid_request_error",
            }
        }, 401

    # Strip the scheme only when it is a prefix, never from inside the token.
    bearer_prefix = "Bearer "
    if auth_header.startswith(bearer_prefix):
        token = auth_header[len(bearer_prefix):]
    else:
        token = auth_header

    embedding = get_embedding(input_text, token)
    return {
        "data": [
            {
                "embedding": embedding,
                "index": 0,
                "object": "embedding"
            }
        ],
        "model": "text-embedding-ada-002",
        "object": "list",
        "usage": {
            "prompt_tokens": None,
            "total_tokens": None
        }
    }
def main():
    """Start the Flask server on all interfaces, port 1337, in debug mode."""
    # NOTE(review): debug=True turns on the interactive debugger, which allows
    # arbitrary code execution from the browser; together with host="0.0.0.0"
    # it is reachable from the network. Confirm this is only run locally.
    app.run(host="0.0.0.0", port=1337, debug=True)

Loading…
Cancel
Save