Callback handler for Amazon SageMaker Experiments (#8587)

## Description

This PR implements a callback handler for SageMaker Experiments which is
similar to that of mlflow.
* When creating the callback handler, it takes the experiment's run
object as an argument. All the callback outputs are then logged to the
run object.
* The output of each callback action (e.g., `on_llm_start`) is saved to
S3 bucket as json file.
* Optionally, you can also log additional information such as the LLM
hyper-parameters to the same run object.
* Once the callback object is no more needed, you will need to call the
`flush_tracker()` method. This makes sure that any intermediate files
are deleted.
* A separate notebook example is provided to show how the callback is
used.

@3coins  @agola11

---------

Co-authored-by: Tesfagabir Meharizghi <mehariz@amazon.com>
This commit is contained in:
Tesfagabir Meharizghi 2023-08-01 15:47:08 -05:00 committed by GitHub
parent 9c2b29a1cb
commit a7000ee89e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 1198 additions and 0 deletions

View File

@ -0,0 +1,916 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ef3909cf-72ca-4841-85c6-ef4e0eae3aaf",
"metadata": {},
"source": [
"# SageMaker Tracking\n",
"\n",
"This notebook shows how LangChain Callback can be used to log and track prompts and other LLM hyperparameters into SageMaker Experiments. Here, we use different scenarios to showcase the capability:\n",
"* **Scenario 1**: *Single LLM* - A case where a single LLM model is used to generate output based on a given prompt.\n",
"* **Scenario 2**: *Sequential Chain* - A case where a sequential chain of two LLM models is used.\n",
"* **Scenario 3**: *Agent with Tools (Chain of Thought)* - A case where multiple tools (search and math) are used in addition to an LLM.\n",
"\n",
"[Amazon SageMaker](https://aws.amazon.com/sagemaker/) is a fully managed service that is used to quickly and easily build, train and deploy machine learning (ML) models. \n",
"\n",
"[Amazon SageMaker Experiments](https://docs.aws.amazon.com/sagemaker/latest/dg/experiments.html) is a capability of Amazon SageMaker that lets you organize, track, compare and evaluate ML experiments and model versions.\n",
"\n",
"In this notebook, we will create a single experiment to log the prompts from each scenario."
]
},
{
"cell_type": "markdown",
"id": "94c22cb4-3b1c-432b-b3be-0235eec79c5c",
"metadata": {},
"source": [
"## Installation and Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2353436d-17fe-4f58-a2f9-c299d56393fd",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install sagemaker\n",
"!pip install openai\n",
"!pip install google-search-results"
]
},
{
"cell_type": "markdown",
"id": "65dcf62e-7a38-4119-adb9-d9e884e82499",
"metadata": {
"tags": []
},
"source": [
"First, setup the required API keys\n",
"* OpenAI: https://platform.openai.com/account/api-keys (For OpenAI LLM model)\n",
"* Google SERP API: https://serpapi.com/manage-api-key (For Google Search Tool)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ec2b898-0cfc-4308-8e86-569cd7b7cf41",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"## Add your API keys below\n",
"os.environ[\"OPENAI_API_KEY\"] = \"<ADD-KEY-HERE>\"\n",
"os.environ[\"SERPAPI_API_KEY\"] = \"<ADD-KEY-HERE>\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80968ebf-519f-46de-8703-97532ac39e3e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains import LLMChain, SimpleSequentialChain\n",
"from langchain.agents import initialize_agent, load_tools\n",
"from langchain.agents import Tool\n",
"from langchain.callbacks import SageMakerCallbackHandler\n",
"\n",
"from sagemaker.analytics import ExperimentAnalytics\n",
"from sagemaker.session import Session\n",
"from sagemaker.experiments.run import Run"
]
},
{
"cell_type": "markdown",
"id": "b67d031f-a01f-4009-ad29-c80ab8ad50ea",
"metadata": {},
"source": [
"## LLM Prompt Tracking"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da2d70ee-173b-469d-a718-54c33d862844",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"#LLM Hyperparameters\n",
"HPARAMS = {\n",
" \"temperature\": 0.1,\n",
" \"model_name\": \"text-davinci-003\",\n",
"}\n",
"\n",
"#Bucket used to save prompt logs (Use `None` is used to save the default bucket or otherwise change it)\n",
"BUCKET_NAME = None\n",
"\n",
"#Experiment name\n",
"EXPERIMENT_NAME = \"langchain-sagemaker-tracker\"\n",
"\n",
"#Create SageMaker Session with the given bucket\n",
"session = Session(default_bucket=BUCKET_NAME)"
]
},
{
"cell_type": "markdown",
"id": "7239a39a-08d8-43cb-8922-81abdd5d9ebf",
"metadata": {},
"source": [
"### Scenario 1 - LLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "abc00335-50c8-4119-adb8-4c4ab8522e23",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"RUN_NAME = \"run-scenario-1\"\n",
"PROMPT_TEMPLATE = \"tell me a joke about {topic}\"\n",
"INPUT_VARIABLES = {\"topic\": \"fish\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a3a3cbe-db85-4255-8d8b-eaafdca8c6e2",
"metadata": {},
"outputs": [],
"source": [
"with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:\n",
"\n",
" # Create SageMaker Callback\n",
" sagemaker_callback = SageMakerCallbackHandler(run)\n",
"\n",
" # Define LLM model with callback\n",
" llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)\n",
"\n",
" # Create prompt template\n",
" prompt = PromptTemplate.from_template(template=PROMPT_TEMPLATE)\n",
"\n",
" # Create LLM Chain\n",
" chain = LLMChain(llm=llm, prompt=prompt, callbacks=[sagemaker_callback])\n",
"\n",
" # Run chain\n",
" chain.run(**INPUT_VARIABLES)\n",
"\n",
" # Reset the callback\n",
" sagemaker_callback.flush_tracker()"
]
},
{
"cell_type": "markdown",
"id": "7dc69934-9f42-40b7-9931-36a3371a38da",
"metadata": {},
"source": [
"### Scenario 2 - Sequential Chain"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50b75ef9-9825-4ccc-8414-4cd7525a1b68",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"RUN_NAME = \"run-scenario-2\"\n",
"\n",
"PROMPT_TEMPLATE_1 = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
"Title: {title}\n",
"Playwright: This is a synopsis for the above play:\"\"\"\n",
"PROMPT_TEMPLATE_2 = \"\"\"You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.\n",
"Play Synopsis: {synopsis}\n",
"Review from a New York Times play critic of the above play:\"\"\"\n",
"\n",
"INPUT_VARIABLES = {\n",
" \"input\": \"documentary about good video games that push the boundary of game design\"\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb7fff5f-e89f-40e2-96b4-3641a0b6e9b4",
"metadata": {},
"outputs": [],
"source": [
"with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:\n",
"\n",
" # Create SageMaker Callback\n",
" sagemaker_callback = SageMakerCallbackHandler(run)\n",
"\n",
" # Create prompt templates for the chain\n",
" prompt_template1 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_1)\n",
" prompt_template2 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_2)\n",
"\n",
" # Define LLM model with callback\n",
" llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)\n",
"\n",
" # Create chain1\n",
" chain1 = LLMChain(llm=llm, prompt=prompt_template1, callbacks=[sagemaker_callback])\n",
"\n",
" # Create chain2\n",
" chain2 = LLMChain(llm=llm, prompt=prompt_template2, callbacks=[sagemaker_callback])\n",
"\n",
" # Create Sequential chain\n",
" overall_chain = SimpleSequentialChain(chains=[chain1, chain2], callbacks=[sagemaker_callback])\n",
"\n",
" # Run overall sequential chain\n",
" overall_chain.run(**INPUT_VARIABLES)\n",
"\n",
" # Reset the callback\n",
" sagemaker_callback.flush_tracker()"
]
},
{
"cell_type": "markdown",
"id": "6b82bd0e-c626-4797-bb06-c1983f176315",
"metadata": {},
"source": [
"### Scenario 3 - Agent with Tools"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5066f03-49dc-4868-be8e-d21ce22063fe",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"RUN_NAME = \"run-scenario-3\"\n",
"PROMPT_TEMPLATE = \"Who is the oldest person alive? And what is their current age raised to the power of 1.51?\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98385c42-9e44-4b03-b76d-007cb4797864",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:\n",
"\n",
" # Create SageMaker Callback\n",
" sagemaker_callback = SageMakerCallbackHandler(run)\n",
"\n",
" # Define LLM model with callback\n",
" llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)\n",
"\n",
" # Define tools\n",
" tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=[sagemaker_callback])\n",
"\n",
" # Initialize agent with all the tools\n",
" agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", callbacks=[sagemaker_callback])\n",
"\n",
" # Run agent\n",
" agent.run(input=PROMPT_TEMPLATE)\n",
"\n",
" # Reset the callback\n",
" sagemaker_callback.flush_tracker()"
]
},
{
"cell_type": "markdown",
"id": "c306a1c9-99f8-476d-96db-347746f5cfe0",
"metadata": {
"tags": []
},
"source": [
"## Load Log Data\n",
"\n",
"Once the prompts are logged, we can easily load and convert them to Pandas DataFrame as follows."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec7b4af2-e01d-4f6c-9de5-70d2b4acb9e6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"#Load\n",
"logs = ExperimentAnalytics(experiment_name=EXPERIMENT_NAME)\n",
"\n",
"#Convert as pandas dataframe\n",
"df = logs.dataframe(force_refresh=True)\n",
"\n",
"print(df.shape)\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"id": "29991c75-f9cf-4c36-abfd-903c09fb170d",
"metadata": {},
"source": [
"As can be seen above, there are three runs (rows) in the experiment corresponding to each scenario. Each run logs the prompts and related LLM settings/hyperparameters as json and are saved in s3 bucket. Feel free to load and explore the log data from each json path."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61a695d6-0aef-4284-9e12-eea8bc143dbd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"availableInstances": [
{
"_defaultOrder": 0,
"_isFastLaunch": true,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 4,
"name": "ml.t3.medium",
"vcpuNum": 2
},
{
"_defaultOrder": 1,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 8,
"name": "ml.t3.large",
"vcpuNum": 2
},
{
"_defaultOrder": 2,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.t3.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 3,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.t3.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 4,
"_isFastLaunch": true,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 8,
"name": "ml.m5.large",
"vcpuNum": 2
},
{
"_defaultOrder": 5,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.m5.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 6,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.m5.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 7,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 64,
"name": "ml.m5.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 8,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 128,
"name": "ml.m5.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 9,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 192,
"name": "ml.m5.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 10,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 256,
"name": "ml.m5.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 11,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 384,
"name": "ml.m5.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 12,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 8,
"name": "ml.m5d.large",
"vcpuNum": 2
},
{
"_defaultOrder": 13,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.m5d.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 14,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.m5d.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 15,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 64,
"name": "ml.m5d.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 16,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 128,
"name": "ml.m5d.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 17,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 192,
"name": "ml.m5d.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 18,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 256,
"name": "ml.m5d.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 19,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 384,
"name": "ml.m5d.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 20,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": true,
"memoryGiB": 0,
"name": "ml.geospatial.interactive",
"supportedImageNames": [
"sagemaker-geospatial-v1-0"
],
"vcpuNum": 0
},
{
"_defaultOrder": 21,
"_isFastLaunch": true,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 4,
"name": "ml.c5.large",
"vcpuNum": 2
},
{
"_defaultOrder": 22,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 8,
"name": "ml.c5.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 23,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.c5.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 24,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.c5.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 25,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 72,
"name": "ml.c5.9xlarge",
"vcpuNum": 36
},
{
"_defaultOrder": 26,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 96,
"name": "ml.c5.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 27,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 144,
"name": "ml.c5.18xlarge",
"vcpuNum": 72
},
{
"_defaultOrder": 28,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 192,
"name": "ml.c5.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 29,
"_isFastLaunch": true,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.g4dn.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 30,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.g4dn.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 31,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 64,
"name": "ml.g4dn.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 32,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 128,
"name": "ml.g4dn.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 33,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 4,
"hideHardwareSpecs": false,
"memoryGiB": 192,
"name": "ml.g4dn.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 34,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 256,
"name": "ml.g4dn.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 35,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 61,
"name": "ml.p3.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 36,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 4,
"hideHardwareSpecs": false,
"memoryGiB": 244,
"name": "ml.p3.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 37,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 8,
"hideHardwareSpecs": false,
"memoryGiB": 488,
"name": "ml.p3.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 38,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 8,
"hideHardwareSpecs": false,
"memoryGiB": 768,
"name": "ml.p3dn.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 39,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.r5.large",
"vcpuNum": 2
},
{
"_defaultOrder": 40,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.r5.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 41,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 64,
"name": "ml.r5.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 42,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 128,
"name": "ml.r5.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 43,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 256,
"name": "ml.r5.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 44,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 384,
"name": "ml.r5.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 45,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 512,
"name": "ml.r5.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 46,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 768,
"name": "ml.r5.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 47,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.g5.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 48,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.g5.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 49,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 64,
"name": "ml.g5.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 50,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 128,
"name": "ml.g5.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 51,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 256,
"name": "ml.g5.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 52,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 4,
"hideHardwareSpecs": false,
"memoryGiB": 192,
"name": "ml.g5.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 53,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 4,
"hideHardwareSpecs": false,
"memoryGiB": 384,
"name": "ml.g5.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 54,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 8,
"hideHardwareSpecs": false,
"memoryGiB": 768,
"name": "ml.g5.48xlarge",
"vcpuNum": 192
}
],
"instance_type": "ml.t3.large",
"kernelspec": {
"display_name": "conda_pytorch_p310",
"language": "python",
"name": "conda_pytorch_p310"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -27,6 +27,7 @@ from langchain.callbacks.manager import (
from langchain.callbacks.mlflow_callback import MlflowCallbackHandler
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.promptlayer_callback import PromptLayerCallbackHandler
from langchain.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@ -66,4 +67,5 @@ __all__ = [
"tracing_v2_enabled",
"wandb_tracing_enabled",
"FlyteCallbackHandler",
"SageMakerCallbackHandler",
]

View File

@ -0,0 +1,280 @@
import json
import os
import shutil
import tempfile
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
flatten_dict,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
def save_json(data: dict, file_path: str) -> None:
"""Save dict to local file path.
Parameters:
data (dict): The dictionary to be saved.
file_path (str): Local file path.
"""
with open(file_path, "w") as outfile:
json.dump(data, outfile)
class SageMakerCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments.
Parameters:
run (sagemaker.experiments.run.Run): Run object where the experiment is logged.
"""
def __init__(self, run: Any) -> None:
"""Initialize callback handler."""
super().__init__()
self.run = run
self.metrics = {
"step": 0,
"starts": 0,
"ends": 0,
"errors": 0,
"text_ctr": 0,
"chain_starts": 0,
"chain_ends": 0,
"llm_starts": 0,
"llm_ends": 0,
"llm_streams": 0,
"tool_starts": 0,
"tool_ends": 0,
"agent_ends": 0,
}
# Create a temporary directory
self.temp_dir = tempfile.mkdtemp()
def _reset(self) -> None:
for k, v in self.metrics.items():
self.metrics[k] = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.metrics["step"] += 1
self.metrics["llm_starts"] += 1
self.metrics["starts"] += 1
llm_starts = self.metrics["llm_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
for idx, prompt in enumerate(prompts):
prompt_resp = deepcopy(resp)
prompt_resp["prompt"] = prompt
self.jsonf(
prompt_resp,
self.temp_dir,
f"llm_start_{llm_starts}_prompt_{idx}",
)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.metrics["step"] += 1
self.metrics["llm_streams"] += 1
llm_streams = self.metrics["llm_streams"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}")
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.metrics["step"] += 1
self.metrics["llm_ends"] += 1
self.metrics["ends"] += 1
llm_ends = self.metrics["llm_ends"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.metrics)
for generations in response.generations:
for idx, generation in enumerate(generations):
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
self.jsonf(
resp,
self.temp_dir,
f"llm_end_{llm_ends}_generation_{idx}",
)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.metrics["step"] += 1
self.metrics["chain_starts"] += 1
self.metrics["starts"] += 1
chain_starts = self.metrics["chain_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
input_resp = deepcopy(resp)
input_resp["inputs"] = chain_input
self.jsonf(input_resp, self.temp_dir, f"chain_start_{chain_starts}")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.metrics["step"] += 1
self.metrics["chain_ends"] += 1
self.metrics["ends"] += 1
chain_ends = self.metrics["chain_ends"]
resp: Dict[str, Any] = {}
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
resp.update({"action": "on_chain_end", "outputs": chain_output})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.metrics["step"] += 1
self.metrics["tool_starts"] += 1
self.metrics["starts"] += 1
tool_starts = self.metrics["tool_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.metrics["step"] += 1
self.metrics["tool_ends"] += 1
self.metrics["ends"] += 1
tool_ends = self.metrics["tool_ends"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.metrics["step"] += 1
self.metrics["text_ctr"] += 1
text_ctr = self.metrics["text_ctr"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_text", "text": text})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}")
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.metrics["step"] += 1
self.metrics["agent_ends"] += 1
self.metrics["ends"] += 1
agent_ends = self.metrics["agent_ends"]
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}")
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.metrics["step"] += 1
self.metrics["tool_starts"] += 1
self.metrics["starts"] += 1
tool_starts = self.metrics["tool_starts"]
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}")
def jsonf(
self,
data: Dict[str, Any],
data_dir: str,
filename: str,
is_output: Optional[bool] = True,
) -> None:
"""To log the input data as json file artifact."""
file_path = os.path.join(data_dir, f"{filename}.json")
save_json(data, file_path)
self.run.log_file(file_path, name=filename, is_output=is_output)
def flush_tracker(self) -> None:
"""Reset the steps and delete the temporary local directory."""
self._reset()
shutil.rmtree(self.temp_dir)