From a7000ee89eaecce1340801c98f23afcc02ae8d83 Mon Sep 17 00:00:00 2001 From: Tesfagabir Meharizghi Date: Tue, 1 Aug 2023 15:47:08 -0500 Subject: [PATCH] Callback handler for Amazon SageMaker Experiments (#8587) ## Description This PR implements a callback handler for SageMaker Experiments which is similar to that of mlflow. * When creating the callback handler, it takes the experiment's run object as an argument. All the callback outputs are then logged to the run object. * The output of each callback action (e.g., `on_llm_start`) is saved to S3 bucket as json file. * Optionally, you can also log additional information such as the LLM hyper-parameters to the same run object. * Once the callback object is no more needed, you will need to call the `flush_tracker()` method. This makes sure that any intermediate files are deleted. * A separate notebook example is provided to show how the callback is used. @3coins @agola11 --------- Co-authored-by: Tesfagabir Meharizghi --- .../providers/sagemaker_tracking.ipynb | 916 ++++++++++++++++++ .../langchain/langchain/callbacks/__init__.py | 2 + .../langchain/callbacks/sagemaker_callback.py | 280 ++++++ 3 files changed, 1198 insertions(+) create mode 100644 docs/extras/integrations/providers/sagemaker_tracking.ipynb create mode 100644 libs/langchain/langchain/callbacks/sagemaker_callback.py diff --git a/docs/extras/integrations/providers/sagemaker_tracking.ipynb b/docs/extras/integrations/providers/sagemaker_tracking.ipynb new file mode 100644 index 0000000000..a7678f16a3 --- /dev/null +++ b/docs/extras/integrations/providers/sagemaker_tracking.ipynb @@ -0,0 +1,916 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ef3909cf-72ca-4841-85c6-ef4e0eae3aaf", + "metadata": {}, + "source": [ + "# SageMaker Tracking\n", + "\n", + "This notebook shows how LangChain Callback can be used to log and track prompts and other LLM hyperparameters into SageMaker Experiments. Here, we use different scenarios to showcase the capability:\n", + "* **Scenario 1**: *Single LLM* - A case where a single LLM model is used to generate output based on a given prompt.\n", + "* **Scenario 2**: *Sequential Chain* - A case where a sequential chain of two LLM models is used.\n", + "* **Scenario 3**: *Agent with Tools (Chain of Thought)* - A case where multiple tools (search and math) are used in addition to an LLM.\n", + "\n", + "[Amazon SageMaker](https://aws.amazon.com/sagemaker/) is a fully managed service that is used to quickly and easily build, train and deploy machine learning (ML) models. \n", + "\n", + "[Amazon SageMaker Experiments](https://docs.aws.amazon.com/sagemaker/latest/dg/experiments.html) is a capability of Amazon SageMaker that lets you organize, track, compare and evaluate ML experiments and model versions.\n", + "\n", + "In this notebook, we will create a single experiment to log the prompts from each scenario." + ] + }, + { + "cell_type": "markdown", + "id": "94c22cb4-3b1c-432b-b3be-0235eec79c5c", + "metadata": {}, + "source": [ + "## Installation and Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2353436d-17fe-4f58-a2f9-c299d56393fd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install sagemaker\n", + "!pip install openai\n", + "!pip install google-search-results" + ] + }, + { + "cell_type": "markdown", + "id": "65dcf62e-7a38-4119-adb9-d9e884e82499", + "metadata": { + "tags": [] + }, + "source": [ + "First, setup the required API keys\n", + "* OpenAI: https://platform.openai.com/account/api-keys (For OpenAI LLM model)\n", + "* Google SERP API: https://serpapi.com/manage-api-key (For Google Search Tool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ec2b898-0cfc-4308-8e86-569cd7b7cf41", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "## Add your API keys below\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", + "os.environ[\"SERPAPI_API_KEY\"] = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80968ebf-519f-46de-8703-97532ac39e3e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.llms import OpenAI\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.chains import LLMChain, SimpleSequentialChain\n", + "from langchain.agents import initialize_agent, load_tools\n", + "from langchain.agents import Tool\n", + "from langchain.callbacks import SageMakerCallbackHandler\n", + "\n", + "from sagemaker.analytics import ExperimentAnalytics\n", + "from sagemaker.session import Session\n", + "from sagemaker.experiments.run import Run" + ] + }, + { + "cell_type": "markdown", + "id": "b67d031f-a01f-4009-ad29-c80ab8ad50ea", + "metadata": {}, + "source": [ + "## LLM Prompt Tracking" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da2d70ee-173b-469d-a718-54c33d862844", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#LLM Hyperparameters\n", + "HPARAMS = {\n", + " \"temperature\": 0.1,\n", + " \"model_name\": \"text-davinci-003\",\n", + "}\n", + "\n", + "#Bucket used to save prompt logs (Use `None` is used to save the default bucket or otherwise change it)\n", + "BUCKET_NAME = None\n", + "\n", + "#Experiment name\n", + "EXPERIMENT_NAME = \"langchain-sagemaker-tracker\"\n", + "\n", + "#Create SageMaker Session with the given bucket\n", + "session = Session(default_bucket=BUCKET_NAME)" + ] + }, + { + "cell_type": "markdown", + "id": "7239a39a-08d8-43cb-8922-81abdd5d9ebf", + "metadata": {}, + "source": [ + "### Scenario 1 - LLM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "abc00335-50c8-4119-adb8-4c4ab8522e23", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "RUN_NAME = \"run-scenario-1\"\n", + "PROMPT_TEMPLATE = \"tell me a joke about {topic}\"\n", + "INPUT_VARIABLES = {\"topic\": \"fish\"}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a3a3cbe-db85-4255-8d8b-eaafdca8c6e2", + "metadata": {}, + "outputs": [], + "source": [ + "with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:\n", + "\n", + " # Create SageMaker Callback\n", + " sagemaker_callback = SageMakerCallbackHandler(run)\n", + "\n", + " # Define LLM model with callback\n", + " llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)\n", + "\n", + " # Create prompt template\n", + " prompt = PromptTemplate.from_template(template=PROMPT_TEMPLATE)\n", + "\n", + " # Create LLM Chain\n", + " chain = LLMChain(llm=llm, prompt=prompt, callbacks=[sagemaker_callback])\n", + "\n", + " # Run chain\n", + " chain.run(**INPUT_VARIABLES)\n", + "\n", + " # Reset the callback\n", + " sagemaker_callback.flush_tracker()" + ] + }, + { + "cell_type": "markdown", + "id": "7dc69934-9f42-40b7-9931-36a3371a38da", + "metadata": {}, + "source": [ + "### Scenario 2 - Sequential Chain" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50b75ef9-9825-4ccc-8414-4cd7525a1b68", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "RUN_NAME = \"run-scenario-2\"\n", + "\n", + "PROMPT_TEMPLATE_1 = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n", + "Title: {title}\n", + "Playwright: This is a synopsis for the above play:\"\"\"\n", + "PROMPT_TEMPLATE_2 = \"\"\"You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.\n", + "Play Synopsis: {synopsis}\n", + "Review from a New York Times play critic of the above play:\"\"\"\n", + "\n", + "INPUT_VARIABLES = {\n", + " \"input\": \"documentary about good video games that push the boundary of game design\"\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb7fff5f-e89f-40e2-96b4-3641a0b6e9b4", + "metadata": {}, + "outputs": [], + "source": [ + "with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:\n", + "\n", + " # Create SageMaker Callback\n", + " sagemaker_callback = SageMakerCallbackHandler(run)\n", + "\n", + " # Create prompt templates for the chain\n", + " prompt_template1 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_1)\n", + " prompt_template2 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_2)\n", + "\n", + " # Define LLM model with callback\n", + " llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)\n", + "\n", + " # Create chain1\n", + " chain1 = LLMChain(llm=llm, prompt=prompt_template1, callbacks=[sagemaker_callback])\n", + "\n", + " # Create chain2\n", + " chain2 = LLMChain(llm=llm, prompt=prompt_template2, callbacks=[sagemaker_callback])\n", + "\n", + " # Create Sequential chain\n", + " overall_chain = SimpleSequentialChain(chains=[chain1, chain2], callbacks=[sagemaker_callback])\n", + "\n", + " # Run overall sequential chain\n", + " overall_chain.run(**INPUT_VARIABLES)\n", + "\n", + " # Reset the callback\n", + " sagemaker_callback.flush_tracker()" + ] + }, + { + "cell_type": "markdown", + "id": "6b82bd0e-c626-4797-bb06-c1983f176315", + "metadata": {}, + "source": [ + "### Scenario 3 - Agent with Tools" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5066f03-49dc-4868-be8e-d21ce22063fe", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "RUN_NAME = \"run-scenario-3\"\n", + "PROMPT_TEMPLATE = \"Who is the oldest person alive? And what is their current age raised to the power of 1.51?\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98385c42-9e44-4b03-b76d-007cb4797864", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:\n", + "\n", + " # Create SageMaker Callback\n", + " sagemaker_callback = SageMakerCallbackHandler(run)\n", + "\n", + " # Define LLM model with callback\n", + " llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)\n", + "\n", + " # Define tools\n", + " tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=[sagemaker_callback])\n", + "\n", + " # Initialize agent with all the tools\n", + " agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", callbacks=[sagemaker_callback])\n", + "\n", + " # Run agent\n", + " agent.run(input=PROMPT_TEMPLATE)\n", + "\n", + " # Reset the callback\n", + " sagemaker_callback.flush_tracker()" + ] + }, + { + "cell_type": "markdown", + "id": "c306a1c9-99f8-476d-96db-347746f5cfe0", + "metadata": { + "tags": [] + }, + "source": [ + "## Load Log Data\n", + "\n", + "Once the prompts are logged, we can easily load and convert them to Pandas DataFrame as follows." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec7b4af2-e01d-4f6c-9de5-70d2b4acb9e6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#Load\n", + "logs = ExperimentAnalytics(experiment_name=EXPERIMENT_NAME)\n", + "\n", + "#Convert as pandas dataframe\n", + "df = logs.dataframe(force_refresh=True)\n", + "\n", + "print(df.shape)\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "29991c75-f9cf-4c36-abfd-903c09fb170d", + "metadata": {}, + "source": [ + "As can be seen above, there are three runs (rows) in the experiment corresponding to each scenario. Each run logs the prompts and related LLM settings/hyperparameters as json and are saved in s3 bucket. Feel free to load and explore the log data from each json path." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61a695d6-0aef-4284-9e12-eea8bc143dbd", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "availableInstances": [ + { + "_defaultOrder": 0, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.t3.medium", + "vcpuNum": 2 + }, + { + "_defaultOrder": 1, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.t3.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 2, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.t3.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 3, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.t3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 4, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 5, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 6, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 7, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 8, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 9, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 10, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 11, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 12, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5d.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 13, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5d.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 14, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5d.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 15, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5d.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 16, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5d.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 17, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5d.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 18, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5d.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 19, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 20, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": true, + "memoryGiB": 0, + "name": "ml.geospatial.interactive", + "supportedImageNames": [ + "sagemaker-geospatial-v1-0" + ], + "vcpuNum": 0 + }, + { + "_defaultOrder": 21, + "_isFastLaunch": true, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.c5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 22, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.c5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 23, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.c5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 24, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.c5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 25, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 72, + "name": "ml.c5.9xlarge", + "vcpuNum": 36 + }, + { + "_defaultOrder": 26, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 96, + "name": "ml.c5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 27, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 144, + "name": "ml.c5.18xlarge", + "vcpuNum": 72 + }, + { + "_defaultOrder": 28, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.c5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 29, + "_isFastLaunch": true, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g4dn.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 30, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g4dn.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 31, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g4dn.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 32, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g4dn.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 33, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g4dn.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 34, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g4dn.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 35, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 61, + "name": "ml.p3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 36, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 244, + "name": "ml.p3.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 37, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 488, + "name": "ml.p3.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 38, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.p3dn.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 39, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.r5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 40, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.r5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 41, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.r5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 42, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.r5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 43, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.r5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 44, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.r5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 45, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.r5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 46, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.r5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 47, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 48, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 49, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 50, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 51, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 52, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 53, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.g5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 54, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.g5.48xlarge", + "vcpuNum": 192 + } + ], + "instance_type": "ml.t3.large", + "kernelspec": { + "display_name": "conda_pytorch_p310", + "language": "python", + "name": "conda_pytorch_p310" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/callbacks/__init__.py b/libs/langchain/langchain/callbacks/__init__.py index ca70b0bd38..1262222239 100644 --- a/libs/langchain/langchain/callbacks/__init__.py +++ b/libs/langchain/langchain/callbacks/__init__.py @@ -27,6 +27,7 @@ from langchain.callbacks.manager import ( from langchain.callbacks.mlflow_callback import MlflowCallbackHandler from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.promptlayer_callback import PromptLayerCallbackHandler +from langchain.callbacks.sagemaker_callback import SageMakerCallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler @@ -66,4 +67,5 @@ __all__ = [ "tracing_v2_enabled", "wandb_tracing_enabled", "FlyteCallbackHandler", + "SageMakerCallbackHandler", ] diff --git a/libs/langchain/langchain/callbacks/sagemaker_callback.py b/libs/langchain/langchain/callbacks/sagemaker_callback.py new file mode 100644 index 0000000000..c97461c330 --- /dev/null +++ b/libs/langchain/langchain/callbacks/sagemaker_callback.py @@ -0,0 +1,280 @@ +import json +import os +import shutil +import tempfile +from copy import deepcopy +from typing import Any, Dict, List, Optional, Union + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.callbacks.utils import ( + flatten_dict, +) +from langchain.schema import AgentAction, AgentFinish, LLMResult + + +def save_json(data: dict, file_path: str) -> None: + """Save dict to local file path. + + Parameters: + data (dict): The dictionary to be saved. + file_path (str): Local file path. + """ + with open(file_path, "w") as outfile: + json.dump(data, outfile) + + +class SageMakerCallbackHandler(BaseCallbackHandler): + """Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments. + + Parameters: + run (sagemaker.experiments.run.Run): Run object where the experiment is logged. + """ + + def __init__(self, run: Any) -> None: + """Initialize callback handler.""" + super().__init__() + + self.run = run + + self.metrics = { + "step": 0, + "starts": 0, + "ends": 0, + "errors": 0, + "text_ctr": 0, + "chain_starts": 0, + "chain_ends": 0, + "llm_starts": 0, + "llm_ends": 0, + "llm_streams": 0, + "tool_starts": 0, + "tool_ends": 0, + "agent_ends": 0, + } + + # Create a temporary directory + self.temp_dir = tempfile.mkdtemp() + + def _reset(self) -> None: + for k, v in self.metrics.items(): + self.metrics[k] = 0 + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts.""" + self.metrics["step"] += 1 + self.metrics["llm_starts"] += 1 + self.metrics["starts"] += 1 + + llm_starts = self.metrics["llm_starts"] + + resp: Dict[str, Any] = {} + resp.update({"action": "on_llm_start"}) + resp.update(flatten_dict(serialized)) + resp.update(self.metrics) + + for idx, prompt in enumerate(prompts): + prompt_resp = deepcopy(resp) + prompt_resp["prompt"] = prompt + self.jsonf( + prompt_resp, + self.temp_dir, + f"llm_start_{llm_starts}_prompt_{idx}", + ) + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run when LLM generates a new token.""" + self.metrics["step"] += 1 + self.metrics["llm_streams"] += 1 + + llm_streams = self.metrics["llm_streams"] + + resp: Dict[str, Any] = {} + resp.update({"action": "on_llm_new_token", "token": token}) + resp.update(self.metrics) + + self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}") + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + self.metrics["step"] += 1 + self.metrics["llm_ends"] += 1 + self.metrics["ends"] += 1 + + llm_ends = self.metrics["llm_ends"] + + resp: Dict[str, Any] = {} + resp.update({"action": "on_llm_end"}) + resp.update(flatten_dict(response.llm_output or {})) + + resp.update(self.metrics) + + for generations in response.generations: + for idx, generation in enumerate(generations): + generation_resp = deepcopy(resp) + generation_resp.update(flatten_dict(generation.dict())) + + self.jsonf( + resp, + self.temp_dir, + f"llm_end_{llm_ends}_generation_{idx}", + ) + + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when LLM errors.""" + self.metrics["step"] += 1 + self.metrics["errors"] += 1 + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + self.metrics["step"] += 1 + self.metrics["chain_starts"] += 1 + self.metrics["starts"] += 1 + + chain_starts = self.metrics["chain_starts"] + + resp: Dict[str, Any] = {} + resp.update({"action": "on_chain_start"}) + resp.update(flatten_dict(serialized)) + resp.update(self.metrics) + + chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()]) + input_resp = deepcopy(resp) + input_resp["inputs"] = chain_input + + self.jsonf(input_resp, self.temp_dir, f"chain_start_{chain_starts}") + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + self.metrics["step"] += 1 + self.metrics["chain_ends"] += 1 + self.metrics["ends"] += 1 + + chain_ends = self.metrics["chain_ends"] + + resp: Dict[str, Any] = {} + chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()]) + resp.update({"action": "on_chain_end", "outputs": chain_output}) + resp.update(self.metrics) + + self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}") + + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when chain errors.""" + self.metrics["step"] += 1 + self.metrics["errors"] += 1 + + def on_tool_start( + self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + self.metrics["step"] += 1 + self.metrics["tool_starts"] += 1 + self.metrics["starts"] += 1 + + tool_starts = self.metrics["tool_starts"] + + resp: Dict[str, Any] = {} + resp.update({"action": "on_tool_start", "input_str": input_str}) + resp.update(flatten_dict(serialized)) + resp.update(self.metrics) + + self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}") + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + self.metrics["step"] += 1 + self.metrics["tool_ends"] += 1 + self.metrics["ends"] += 1 + + tool_ends = self.metrics["tool_ends"] + + resp: Dict[str, Any] = {} + resp.update({"action": "on_tool_end", "output": output}) + resp.update(self.metrics) + + self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}") + + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when tool errors.""" + self.metrics["step"] += 1 + self.metrics["errors"] += 1 + + def on_text(self, text: str, **kwargs: Any) -> None: + """ + Run when agent is ending. + """ + self.metrics["step"] += 1 + self.metrics["text_ctr"] += 1 + + text_ctr = self.metrics["text_ctr"] + + resp: Dict[str, Any] = {} + resp.update({"action": "on_text", "text": text}) + resp.update(self.metrics) + + self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}") + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run when agent ends running.""" + self.metrics["step"] += 1 + self.metrics["agent_ends"] += 1 + self.metrics["ends"] += 1 + + agent_ends = self.metrics["agent_ends"] + resp: Dict[str, Any] = {} + resp.update( + { + "action": "on_agent_finish", + "output": finish.return_values["output"], + "log": finish.log, + } + ) + resp.update(self.metrics) + + self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}") + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run on agent action.""" + self.metrics["step"] += 1 + self.metrics["tool_starts"] += 1 + self.metrics["starts"] += 1 + + tool_starts = self.metrics["tool_starts"] + resp: Dict[str, Any] = {} + resp.update( + { + "action": "on_agent_action", + "tool": action.tool, + "tool_input": action.tool_input, + "log": action.log, + } + ) + resp.update(self.metrics) + self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}") + + def jsonf( + self, + data: Dict[str, Any], + data_dir: str, + filename: str, + is_output: Optional[bool] = True, + ) -> None: + """To log the input data as json file artifact.""" + file_path = os.path.join(data_dir, f"{filename}.json") + save_json(data, file_path) + self.run.log_file(file_path, name=filename, is_output=is_output) + + def flush_tracker(self) -> None: + """Reset the steps and delete the temporary local directory.""" + self._reset() + shutil.rmtree(self.temp_dir)