mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
209 lines
5.4 KiB
Plaintext
209 lines
5.4 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "4460f924-1738-4dc5-999f-c26383aba0a4",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Custom String Evaluator\n",
|
||
|
"\n",
|
||
|
"You can make your own custom string evaluators by inheriting from the `StringEvaluator` class and implementing the `_evaluate_strings` (and `_aevaluate_strings` for async support) methods.\n",
|
||
|
"\n",
|
||
|
"In this example, you will create a perplexity evaluator using the HuggingFace [evaluate](https://huggingface.co/docs/evaluate/index) library.\n",
|
||
|
"[Perplexity](https://en.wikipedia.org/wiki/Perplexity) is a measure of how well the generated text would be predicted by the model used to compute the metric. A lower score means the model finds the text more predictable."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "90ec5942-4b14-47b1-baff-9dd2a9f17a4e",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# %pip install evaluate > /dev/null"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "54fdba68-0ae7-4102-a45b-dabab86c97ac",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from typing import Any, Optional\n",
|
||
|
"\n",
|
||
|
"from langchain.evaluation import StringEvaluator\n",
|
||
|
"from evaluate import load\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class PerplexityEvaluator(StringEvaluator):\n",
|
||
|
" \"\"\"Evaluate the perplexity of a predicted string.\"\"\"\n",
|
||
|
"\n",
|
||
|
" def __init__(self, model_id: str = \"gpt2\"):\n",
|
||
|
" self.model_id = model_id\n",
|
||
|
" self.metric_fn = load(\n",
|
||
|
" \"perplexity\", module_type=\"metric\", model_id=self.model_id, pad_token=0\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" def _evaluate_strings(\n",
|
||
|
" self,\n",
|
||
|
" *,\n",
|
||
|
" prediction: str,\n",
|
||
|
" reference: Optional[str] = None,\n",
|
||
|
" input: Optional[str] = None,\n",
|
||
|
" **kwargs: Any,\n",
|
||
|
" ) -> dict:\n",
|
||
|
" results = self.metric_fn.compute(\n",
|
||
|
" predictions=[prediction], model_id=self.model_id\n",
|
||
|
" )\n",
|
||
|
" ppl = results[\"perplexities\"][0]\n",
|
||
|
" return {\"score\": ppl}"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "52767568-8075-4f77-93c9-80e1a7e5cba3",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"evaluator = PerplexityEvaluator()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "697ee0c0-d1ae-4a55-a542-a0f8e602c28a",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Using pad_token, but it is not set yet.\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
|
||
|
"To disable this warning, you can either:\n",
|
||
|
"\t- Avoid using `tokenizers` before the fork if possible\n",
|
||
|
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"application/vnd.jupyter.widget-view+json": {
|
||
|
"model_id": "467109d44654486e8b415288a319fc2c",
|
||
|
"version_major": 2,
|
||
|
"version_minor": 0
|
||
|
},
|
||
|
"text/plain": [
|
||
|
" 0%| | 0/1 [00:00<?, ?it/s]"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"{'score': 190.3675537109375}"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"evaluator.evaluate_strings(prediction=\"The rains in Spain fall mainly on the plain.\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "5089d9d1-eae6-4d47-b4f6-479e5d887d74",
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Using pad_token, but it is not set yet.\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"application/vnd.jupyter.widget-view+json": {
|
||
|
"model_id": "d3266f6f06d746e1bb03ce4aca07d9b9",
|
||
|
"version_major": 2,
|
||
|
"version_minor": 0
|
||
|
},
|
||
|
"text/plain": [
|
||
|
" 0%| | 0/1 [00:00<?, ?it/s]"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"{'score': 1982.0709228515625}"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# The perplexity is much higher since LangChain was introduced after 'gpt-2' was released and because it never appears in a context like this one.\n",
|
||
|
"evaluator.evaluate_strings(prediction=\"The rains in Spain fall mainly on LangChain.\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "5eaa178f-6ba3-47ae-b3dc-1b196af6d213",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.11.2"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|