mirror of
https://github.com/openai/openai-cookbook
synced 2024-11-17 15:29:46 +00:00
252 lines
8.0 KiB
Python
252 lines
8.0 KiB
Python
import textwrap as tr
|
|
from typing import List, Optional
|
|
|
|
import matplotlib.pyplot as plt
|
|
import plotly.express as px
|
|
from scipy import spatial
|
|
from sklearn.decomposition import PCA
|
|
from sklearn.manifold import TSNE
|
|
from sklearn.metrics import average_precision_score, precision_recall_curve
|
|
|
|
from openai import OpenAI
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
# Shared synchronous OpenAI client used by the embedding helpers in this file;
# it is constructed with max_retries=5.
client = OpenAI(max_retries=5)
|
|
|
|
|
|
def get_embedding(text: str, model="text-embedding-3-small", **kwargs) -> List[float]:
    """Return the embedding vector of a single text.

    Newlines in ``text`` are replaced by spaces before the request is sent,
    because they can negatively affect performance. Any extra keyword
    arguments are forwarded unchanged to ``client.embeddings.create``.
    """
    flattened = text.replace("\n", " ")
    first_item = client.embeddings.create(
        input=[flattened], model=model, **kwargs
    ).data[0]
    return first_item.embedding
|
|
|
|
|
|
async def aget_embedding(
    text: str, model="text-embedding-3-small", **kwargs
) -> List[float]:
    """Asynchronously return the embedding vector of a single text.

    Newlines in ``text`` are replaced by spaces before the request is sent,
    because they can negatively affect performance. Any extra keyword
    arguments are forwarded unchanged to the embeddings ``create`` call.
    """
    # Function-scope import: the module itself only imports the sync client.
    from openai import AsyncOpenAI

    text = text.replace("\n", " ")

    # The module-level ``client`` is the synchronous ``OpenAI`` client, whose
    # ``create`` call returns a response object rather than an awaitable, so
    # awaiting it fails. An async client is used instead, with the same
    # ``max_retries`` setting; ``async with`` closes its HTTP resources when
    # the request is done. A client per call trades connection reuse for not
    # sharing an async client across event loops.
    async with AsyncOpenAI(max_retries=5) as async_client:
        response = await async_client.embeddings.create(
            input=[text], model=model, **kwargs
        )

    # The response is read through attributes, exactly as get_embedding does.
    return response.data[0].embedding
|
|
|
|
|
|
def get_embeddings(
    list_of_text: List[str], model="text-embedding-3-small", **kwargs
) -> List[List[float]]:
    """Return one embedding vector per text, for a batch of texts.

    The batch may hold at most 2048 texts; a larger batch fails the assertion
    before any request is made. Newlines in each text are replaced by spaces
    (they can negatively affect performance), and extra keyword arguments are
    forwarded unchanged to ``client.embeddings.create``.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    cleaned_texts = [item.replace("\n", " ") for item in list_of_text]
    response = client.embeddings.create(input=cleaned_texts, model=model, **kwargs)
    return [entry.embedding for entry in response.data]
|
|
|
|
|
|
async def aget_embeddings(
    list_of_text: List[str], model="text-embedding-3-small", **kwargs
) -> List[List[float]]:
    """Asynchronously return one embedding vector per text, for a batch of texts.

    The batch may hold at most 2048 texts; a larger batch fails the assertion
    before any request is made. Newlines in each text are replaced by spaces
    (they can negatively affect performance), and extra keyword arguments are
    forwarded unchanged to the embeddings ``create`` call.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # Function-scope import: the module itself only imports the sync client.
    # It sits after the size check so an oversized batch is rejected first.
    from openai import AsyncOpenAI

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    # The module-level ``client`` is the synchronous ``OpenAI`` client, whose
    # ``create`` call returns a response object rather than an awaitable, so
    # awaiting it fails. An async client is used instead, with the same
    # ``max_retries`` setting; ``async with`` closes its HTTP resources when
    # the request is done.
    async with AsyncOpenAI(max_retries=5) as async_client:
        response = await async_client.embeddings.create(
            input=list_of_text, model=model, **kwargs
        )

    return [d.embedding for d in response.data]
|
|
|
|
|
|
def cosine_similarity(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b``.

    Computed as their dot product divided by the product of their norms.
    """
    dot_product = np.dot(a, b)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return dot_product / norm_product
|
|
|
|
|
|
def plot_multiclass_precision_recall(
    y_score, y_true_untransformed, class_list, classifier_name
):
    """
    Plot precision-recall curves for a multiclass problem.

    On a new matplotlib figure this draws the micro-averaged precision-recall
    curve, one precision-recall curve per class and reference iso-f1 contours.
    It also prints the micro-averaged average precision score.

    ``y_score`` is read column by column (``y_score[:, i]`` belongs to
    ``class_list[i]``). ``y_true_untransformed`` is compared against each
    class and the results are joined with ``pd.concat``, so it is presumably
    a pandas Series of true labels (confirm against callers).

    Code slightly modified, but heavily based on https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html
    """
    n_classes = len(class_list)

    # One boolean indicator column per class, in class_list order.
    indicator_columns = [
        (y_true_untransformed == class_list[k]) for k in range(n_classes)
    ]
    y_true = pd.concat(indicator_columns, axis=1).values

    # Per-class curves and average precision, keyed by class position.
    precision, recall, average_precision = {}, {}, {}
    for k in range(n_classes):
        class_truth, class_score = y_true[:, k], y_score[:, k]
        precision[k], recall[k], _ = precision_recall_curve(class_truth, class_score)
        average_precision[k] = average_precision_score(class_truth, class_score)

    # A "micro-average": quantifying score on all classes jointly
    precision_micro, recall_micro, _ = precision_recall_curve(
        y_true.ravel(), y_score.ravel()
    )
    average_precision_micro = average_precision_score(y_true, y_score, average="micro")
    summary = " - Average precision score over all classes: {0:0.2f}".format(
        average_precision_micro
    )
    print(str(classifier_name) + summary)

    plt.figure(figsize=(9, 10))

    # Reference iso-f1 contours; the x grid is the same for every contour.
    grid = np.linspace(0.01, 1)
    for f_score in np.linspace(0.2, 0.8, num=4):
        contour = f_score * grid / (2 * grid - f_score)
        visible = contour >= 0
        (iso_handle,) = plt.plot(
            grid[visible], contour[visible], color="gray", alpha=0.2
        )
        plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, contour[45] + 0.02))

    # Only the last contour's handle is registered, giving one legend entry
    # for all iso-f1 curves.
    handles = [iso_handle]
    legend_labels = ["iso-f1 curves"]

    (micro_handle,) = plt.plot(recall_micro, precision_micro, color="gold", lw=2)
    handles.append(micro_handle)
    legend_labels.append(
        "average Precision-recall (auprc = {0:0.2f})".format(average_precision_micro)
    )

    for k in range(n_classes):
        (class_handle,) = plt.plot(recall[k], precision[k], lw=2)
        handles.append(class_handle)
        legend_labels.append(
            "Precision-recall for class `{0}` (auprc = {1:0.2f})".format(
                class_list[k], average_precision[k]
            )
        )

    plt.gcf().subplots_adjust(bottom=0.25)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"{classifier_name}: Precision-Recall curve for each class")
    plt.legend(handles, legend_labels)
|
|
|
|
|
|
def distances_from_embeddings(
    query_embedding: List[float],
    embeddings: List[List[float]],
    distance_metric="cosine",
) -> List[float]:
    """Return the distances between a query embedding and a list of embeddings.

    ``distance_metric`` is one of ``"cosine"``, ``"L1"``, ``"L2"`` or
    ``"Linf"``. The result holds one distance per entry of ``embeddings``,
    in the same order.

    Raises:
        KeyError: if ``distance_metric`` is not one of the supported names.
    """
    distance_metrics = {
        "cosine": spatial.distance.cosine,
        "L1": spatial.distance.cityblock,
        "L2": spatial.distance.euclidean,
        "Linf": spatial.distance.chebyshev,
    }
    # Resolve the metric once, up front, instead of once per embedding. The
    # exception type stays KeyError; only the message becomes descriptive.
    try:
        metric = distance_metrics[distance_metric]
    except KeyError:
        raise KeyError(
            f"Unknown distance_metric {distance_metric!r}; "
            f"expected one of {sorted(distance_metrics)}"
        ) from None

    return [metric(query_embedding, embedding) for embedding in embeddings]
|
|
|
|
|
|
def indices_of_nearest_neighbors_from_distances(distances) -> np.ndarray:
    """Return the indices that order ``distances`` from smallest to largest."""
    ranking = np.argsort(distances)
    return ranking
|
|
|
|
|
|
def pca_components_from_embeddings(
    embeddings: List[List[float]], n_components=2
) -> np.ndarray:
    """Fit a PCA with ``n_components`` on the embeddings and return them transformed."""
    return PCA(n_components=n_components).fit_transform(np.array(embeddings))
|
|
|
|
|
|
def tsne_components_from_embeddings(
    embeddings: List[List[float]], n_components=2, **kwargs
) -> np.ndarray:
    """Fit a t-SNE on the embeddings and return them transformed.

    Extra keyword arguments are passed to ``TSNE``. When the caller does not
    supply them, ``init`` defaults to ``"pca"`` and ``learning_rate`` to
    ``"auto"``.
    """
    # Fill in the better defaults only where the caller gave no value.
    kwargs.setdefault("init", "pca")
    kwargs.setdefault("learning_rate", "auto")
    return TSNE(n_components=n_components, **kwargs).fit_transform(
        np.array(embeddings)
    )
|
|
|
|
|
|
def chart_from_components(
    components: np.ndarray,
    labels: Optional[List[str]] = None,
    strings: Optional[List[str]] = None,
    x_title="Component 0",
    y_title="Component 1",
    mark_size=5,
    **kwargs,
):
    """Return an interactive 2D scatter chart of embedding components.

    Columns 0 and 1 of ``components`` become the x and y coordinates. When
    ``labels`` is given it drives both the colour and the symbol of each
    point; when ``strings`` is given, each string is wrapped at 30 characters
    (lines joined with ``<br>``) and shown as hover data. Extra keyword
    arguments are passed to ``px.scatter``.
    """
    # Placeholder column used when labels or strings are not supplied.
    blanks = [""] * len(components)

    columns = {
        x_title: components[:, 0],
        y_title: components[:, 1],
    }
    columns["label"] = labels if labels else blanks
    if strings:
        columns["string"] = ["<br>".join(tr.wrap(s, width=30)) for s in strings]
    else:
        columns["string"] = blanks
    data = pd.DataFrame(columns)

    label_column = "label" if labels else None
    hover_columns = ["string"] if strings else None
    figure = px.scatter(
        data,
        x=x_title,
        y=y_title,
        color=label_column,
        symbol=label_column,
        hover_data=hover_columns,
        **kwargs,
    )
    return figure.update_traces(marker={"size": mark_size})
|
|
|
|
|
|
def chart_from_components_3D(
    components: np.ndarray,
    labels: Optional[List[str]] = None,
    strings: Optional[List[str]] = None,
    x_title: str = "Component 0",
    y_title: str = "Component 1",
    z_title: str = "Component 2",
    mark_size: int = 5,
    **kwargs,
):
    """Return an interactive 3D scatter chart of embedding components.

    Columns 0, 1 and 2 of ``components`` become the x, y and z coordinates.
    When ``labels`` is given it drives both the colour and the symbol of each
    point; when ``strings`` is given, each string is wrapped at 30 characters
    (lines joined with ``<br>``) and shown as hover data. Extra keyword
    arguments are passed to ``px.scatter_3d``.

    The default ``z_title`` was previously misspelled "Compontent 2"; it is
    now "Component 2", matching the x and y defaults.
    """
    # Placeholder column used when labels or strings are not supplied.
    empty_list = ["" for _ in components]
    data = pd.DataFrame(
        {
            x_title: components[:, 0],
            y_title: components[:, 1],
            z_title: components[:, 2],
            "label": labels if labels else empty_list,
            "string": ["<br>".join(tr.wrap(string, width=30)) for string in strings]
            if strings
            else empty_list,
        }
    )
    chart = px.scatter_3d(
        data,
        x=x_title,
        y=y_title,
        z=z_title,
        color="label" if labels else None,
        symbol="label" if labels else None,
        hover_data=["string"] if strings else None,
        **kwargs,
    ).update_traces(marker=dict(size=mark_size))
    return chart
|