mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Callback handler for Amazon SageMaker Experiments (#8587)
## Description This PR implements a callback handler for SageMaker Experiments which is similar to that of mlflow. * When creating the callback handler, it takes the experiment's run object as an argument. All the callback outputs are then logged to the run object. * The output of each callback action (e.g., `on_llm_start`) is saved to S3 bucket as json file. * Optionally, you can also log additional information such as the LLM hyper-parameters to the same run object. * Once the callback object is no more needed, you will need to call the `flush_tracker()` method. This makes sure that any intermediate files are deleted. * A separate notebook example is provided to show how the callback is used. @3coins @agola11 --------- Co-authored-by: Tesfagabir Meharizghi <mehariz@amazon.com>
This commit is contained in:
parent
9c2b29a1cb
commit
a7000ee89e
916
docs/extras/integrations/providers/sagemaker_tracking.ipynb
Normal file
916
docs/extras/integrations/providers/sagemaker_tracking.ipynb
Normal file
@ -0,0 +1,916 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ef3909cf-72ca-4841-85c6-ef4e0eae3aaf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SageMaker Tracking\n",
|
||||
"\n",
|
||||
"This notebook shows how LangChain Callback can be used to log and track prompts and other LLM hyperparameters into SageMaker Experiments. Here, we use different scenarios to showcase the capability:\n",
|
||||
"* **Scenario 1**: *Single LLM* - A case where a single LLM model is used to generate output based on a given prompt.\n",
|
||||
"* **Scenario 2**: *Sequential Chain* - A case where a sequential chain of two LLM models is used.\n",
|
||||
"* **Scenario 3**: *Agent with Tools (Chain of Thought)* - A case where multiple tools (search and math) are used in addition to an LLM.\n",
|
||||
"\n",
|
||||
"[Amazon SageMaker](https://aws.amazon.com/sagemaker/) is a fully managed service that is used to quickly and easily build, train and deploy machine learning (ML) models. \n",
|
||||
"\n",
|
||||
"[Amazon SageMaker Experiments](https://docs.aws.amazon.com/sagemaker/latest/dg/experiments.html) is a capability of Amazon SageMaker that lets you organize, track, compare and evaluate ML experiments and model versions.\n",
|
||||
"\n",
|
||||
"In this notebook, we will create a single experiment to log the prompts from each scenario."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "94c22cb4-3b1c-432b-b3be-0235eec79c5c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Installation and Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2353436d-17fe-4f58-a2f9-c299d56393fd",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install sagemaker\n",
|
||||
"!pip install openai\n",
|
||||
"!pip install google-search-results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "65dcf62e-7a38-4119-adb9-d9e884e82499",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"First, setup the required API keys\n",
|
||||
"* OpenAI: https://platform.openai.com/account/api-keys (For OpenAI LLM model)\n",
|
||||
"* Google SERP API: https://serpapi.com/manage-api-key (For Google Search Tool)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5ec2b898-0cfc-4308-8e86-569cd7b7cf41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"## Add your API keys below\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"<ADD-KEY-HERE>\"\n",
|
||||
"os.environ[\"SERPAPI_API_KEY\"] = \"<ADD-KEY-HERE>\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "80968ebf-519f-46de-8703-97532ac39e3e",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain.chains import LLMChain, SimpleSequentialChain\n",
|
||||
"from langchain.agents import initialize_agent, load_tools\n",
|
||||
"from langchain.agents import Tool\n",
|
||||
"from langchain.callbacks import SageMakerCallbackHandler\n",
|
||||
"\n",
|
||||
"from sagemaker.analytics import ExperimentAnalytics\n",
|
||||
"from sagemaker.session import Session\n",
|
||||
"from sagemaker.experiments.run import Run"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b67d031f-a01f-4009-ad29-c80ab8ad50ea",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## LLM Prompt Tracking"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "da2d70ee-173b-469d-a718-54c33d862844",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#LLM Hyperparameters\n",
|
||||
"HPARAMS = {\n",
|
||||
" \"temperature\": 0.1,\n",
|
||||
" \"model_name\": \"text-davinci-003\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"#Bucket used to save prompt logs (Use `None` is used to save the default bucket or otherwise change it)\n",
|
||||
"BUCKET_NAME = None\n",
|
||||
"\n",
|
||||
"#Experiment name\n",
|
||||
"EXPERIMENT_NAME = \"langchain-sagemaker-tracker\"\n",
|
||||
"\n",
|
||||
"#Create SageMaker Session with the given bucket\n",
|
||||
"session = Session(default_bucket=BUCKET_NAME)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7239a39a-08d8-43cb-8922-81abdd5d9ebf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Scenario 1 - LLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "abc00335-50c8-4119-adb8-4c4ab8522e23",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"RUN_NAME = \"run-scenario-1\"\n",
|
||||
"PROMPT_TEMPLATE = \"tell me a joke about {topic}\"\n",
|
||||
"INPUT_VARIABLES = {\"topic\": \"fish\"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4a3a3cbe-db85-4255-8d8b-eaafdca8c6e2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:\n",
|
||||
"\n",
|
||||
" # Create SageMaker Callback\n",
|
||||
" sagemaker_callback = SageMakerCallbackHandler(run)\n",
|
||||
"\n",
|
||||
" # Define LLM model with callback\n",
|
||||
" llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)\n",
|
||||
"\n",
|
||||
" # Create prompt template\n",
|
||||
" prompt = PromptTemplate.from_template(template=PROMPT_TEMPLATE)\n",
|
||||
"\n",
|
||||
" # Create LLM Chain\n",
|
||||
" chain = LLMChain(llm=llm, prompt=prompt, callbacks=[sagemaker_callback])\n",
|
||||
"\n",
|
||||
" # Run chain\n",
|
||||
" chain.run(**INPUT_VARIABLES)\n",
|
||||
"\n",
|
||||
" # Reset the callback\n",
|
||||
" sagemaker_callback.flush_tracker()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7dc69934-9f42-40b7-9931-36a3371a38da",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Scenario 2 - Sequential Chain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "50b75ef9-9825-4ccc-8414-4cd7525a1b68",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"RUN_NAME = \"run-scenario-2\"\n",
|
||||
"\n",
|
||||
"PROMPT_TEMPLATE_1 = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
|
||||
"Title: {title}\n",
|
||||
"Playwright: This is a synopsis for the above play:\"\"\"\n",
|
||||
"PROMPT_TEMPLATE_2 = \"\"\"You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.\n",
|
||||
"Play Synopsis: {synopsis}\n",
|
||||
"Review from a New York Times play critic of the above play:\"\"\"\n",
|
||||
"\n",
|
||||
"INPUT_VARIABLES = {\n",
|
||||
" \"input\": \"documentary about good video games that push the boundary of game design\"\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fb7fff5f-e89f-40e2-96b4-3641a0b6e9b4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:\n",
|
||||
"\n",
|
||||
" # Create SageMaker Callback\n",
|
||||
" sagemaker_callback = SageMakerCallbackHandler(run)\n",
|
||||
"\n",
|
||||
" # Create prompt templates for the chain\n",
|
||||
" prompt_template1 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_1)\n",
|
||||
" prompt_template2 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_2)\n",
|
||||
"\n",
|
||||
" # Define LLM model with callback\n",
|
||||
" llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)\n",
|
||||
"\n",
|
||||
" # Create chain1\n",
|
||||
" chain1 = LLMChain(llm=llm, prompt=prompt_template1, callbacks=[sagemaker_callback])\n",
|
||||
"\n",
|
||||
" # Create chain2\n",
|
||||
" chain2 = LLMChain(llm=llm, prompt=prompt_template2, callbacks=[sagemaker_callback])\n",
|
||||
"\n",
|
||||
" # Create Sequential chain\n",
|
||||
" overall_chain = SimpleSequentialChain(chains=[chain1, chain2], callbacks=[sagemaker_callback])\n",
|
||||
"\n",
|
||||
" # Run overall sequential chain\n",
|
||||
" overall_chain.run(**INPUT_VARIABLES)\n",
|
||||
"\n",
|
||||
" # Reset the callback\n",
|
||||
" sagemaker_callback.flush_tracker()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6b82bd0e-c626-4797-bb06-c1983f176315",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Scenario 3 - Agent with Tools"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b5066f03-49dc-4868-be8e-d21ce22063fe",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"RUN_NAME = \"run-scenario-3\"\n",
|
||||
"PROMPT_TEMPLATE = \"Who is the oldest person alive? And what is their current age raised to the power of 1.51?\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "98385c42-9e44-4b03-b76d-007cb4797864",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:\n",
|
||||
"\n",
|
||||
" # Create SageMaker Callback\n",
|
||||
" sagemaker_callback = SageMakerCallbackHandler(run)\n",
|
||||
"\n",
|
||||
" # Define LLM model with callback\n",
|
||||
" llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)\n",
|
||||
"\n",
|
||||
" # Define tools\n",
|
||||
" tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=[sagemaker_callback])\n",
|
||||
"\n",
|
||||
" # Initialize agent with all the tools\n",
|
||||
" agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", callbacks=[sagemaker_callback])\n",
|
||||
"\n",
|
||||
" # Run agent\n",
|
||||
" agent.run(input=PROMPT_TEMPLATE)\n",
|
||||
"\n",
|
||||
" # Reset the callback\n",
|
||||
" sagemaker_callback.flush_tracker()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c306a1c9-99f8-476d-96db-347746f5cfe0",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"## Load Log Data\n",
|
||||
"\n",
|
||||
"Once the prompts are logged, we can easily load and convert them to Pandas DataFrame as follows."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec7b4af2-e01d-4f6c-9de5-70d2b4acb9e6",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#Load\n",
|
||||
"logs = ExperimentAnalytics(experiment_name=EXPERIMENT_NAME)\n",
|
||||
"\n",
|
||||
"#Convert as pandas dataframe\n",
|
||||
"df = logs.dataframe(force_refresh=True)\n",
|
||||
"\n",
|
||||
"print(df.shape)\n",
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "29991c75-f9cf-4c36-abfd-903c09fb170d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As can be seen above, there are three runs (rows) in the experiment corresponding to each scenario. Each run logs the prompts and related LLM settings/hyperparameters as json and are saved in s3 bucket. Feel free to load and explore the log data from each json path."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "61a695d6-0aef-4284-9e12-eea8bc143dbd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.t3.large",
|
||||
"kernelspec": {
|
||||
"display_name": "conda_pytorch_p310",
|
||||
"language": "python",
|
||||
"name": "conda_pytorch_p310"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -27,6 +27,7 @@ from langchain.callbacks.manager import (
|
||||
from langchain.callbacks.mlflow_callback import MlflowCallbackHandler
|
||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.callbacks.promptlayer_callback import PromptLayerCallbackHandler
|
||||
from langchain.callbacks.sagemaker_callback import SageMakerCallbackHandler
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
@ -66,4 +67,5 @@ __all__ = [
|
||||
"tracing_v2_enabled",
|
||||
"wandb_tracing_enabled",
|
||||
"FlyteCallbackHandler",
|
||||
"SageMakerCallbackHandler",
|
||||
]
|
||||
|
280
libs/langchain/langchain/callbacks/sagemaker_callback.py
Normal file
280
libs/langchain/langchain/callbacks/sagemaker_callback.py
Normal file
@ -0,0 +1,280 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
flatten_dict,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
def save_json(data: dict, file_path: str) -> None:
|
||||
"""Save dict to local file path.
|
||||
|
||||
Parameters:
|
||||
data (dict): The dictionary to be saved.
|
||||
file_path (str): Local file path.
|
||||
"""
|
||||
with open(file_path, "w") as outfile:
|
||||
json.dump(data, outfile)
|
||||
|
||||
|
||||
class SageMakerCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments.
|
||||
|
||||
Parameters:
|
||||
run (sagemaker.experiments.run.Run): Run object where the experiment is logged.
|
||||
"""
|
||||
|
||||
def __init__(self, run: Any) -> None:
|
||||
"""Initialize callback handler."""
|
||||
super().__init__()
|
||||
|
||||
self.run = run
|
||||
|
||||
self.metrics = {
|
||||
"step": 0,
|
||||
"starts": 0,
|
||||
"ends": 0,
|
||||
"errors": 0,
|
||||
"text_ctr": 0,
|
||||
"chain_starts": 0,
|
||||
"chain_ends": 0,
|
||||
"llm_starts": 0,
|
||||
"llm_ends": 0,
|
||||
"llm_streams": 0,
|
||||
"tool_starts": 0,
|
||||
"tool_ends": 0,
|
||||
"agent_ends": 0,
|
||||
}
|
||||
|
||||
# Create a temporary directory
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
|
||||
def _reset(self) -> None:
|
||||
for k, v in self.metrics.items():
|
||||
self.metrics[k] = 0
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["llm_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
llm_starts = self.metrics["llm_starts"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_llm_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.metrics)
|
||||
|
||||
for idx, prompt in enumerate(prompts):
|
||||
prompt_resp = deepcopy(resp)
|
||||
prompt_resp["prompt"] = prompt
|
||||
self.jsonf(
|
||||
prompt_resp,
|
||||
self.temp_dir,
|
||||
f"llm_start_{llm_starts}_prompt_{idx}",
|
||||
)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["llm_streams"] += 1
|
||||
|
||||
llm_streams = self.metrics["llm_streams"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_llm_new_token", "token": token})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}")
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["llm_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
llm_ends = self.metrics["llm_ends"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_llm_end"})
|
||||
resp.update(flatten_dict(response.llm_output or {}))
|
||||
|
||||
resp.update(self.metrics)
|
||||
|
||||
for generations in response.generations:
|
||||
for idx, generation in enumerate(generations):
|
||||
generation_resp = deepcopy(resp)
|
||||
generation_resp.update(flatten_dict(generation.dict()))
|
||||
|
||||
self.jsonf(
|
||||
resp,
|
||||
self.temp_dir,
|
||||
f"llm_end_{llm_ends}_generation_{idx}",
|
||||
)
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["chain_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
chain_starts = self.metrics["chain_starts"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_chain_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.metrics)
|
||||
|
||||
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
|
||||
input_resp = deepcopy(resp)
|
||||
input_resp["inputs"] = chain_input
|
||||
|
||||
self.jsonf(input_resp, self.temp_dir, f"chain_start_{chain_starts}")
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["chain_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
chain_ends = self.metrics["chain_ends"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
|
||||
resp.update({"action": "on_chain_end", "outputs": chain_output})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["tool_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
tool_starts = self.metrics["tool_starts"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_tool_start", "input_str": input_str})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["tool_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
tool_ends = self.metrics["tool_ends"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_tool_end", "output": output})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Run when agent is ending.
|
||||
"""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["text_ctr"] += 1
|
||||
|
||||
text_ctr = self.metrics["text_ctr"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_text", "text": text})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}")
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["agent_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
agent_ends = self.metrics["agent_ends"]
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update(
|
||||
{
|
||||
"action": "on_agent_finish",
|
||||
"output": finish.return_values["output"],
|
||||
"log": finish.log,
|
||||
}
|
||||
)
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}")
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["tool_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
tool_starts = self.metrics["tool_starts"]
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update(
|
||||
{
|
||||
"action": "on_agent_action",
|
||||
"tool": action.tool,
|
||||
"tool_input": action.tool_input,
|
||||
"log": action.log,
|
||||
}
|
||||
)
|
||||
resp.update(self.metrics)
|
||||
self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}")
|
||||
|
||||
def jsonf(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
data_dir: str,
|
||||
filename: str,
|
||||
is_output: Optional[bool] = True,
|
||||
) -> None:
|
||||
"""To log the input data as json file artifact."""
|
||||
file_path = os.path.join(data_dir, f"{filename}.json")
|
||||
save_json(data, file_path)
|
||||
self.run.log_file(file_path, name=filename, is_output=is_output)
|
||||
|
||||
def flush_tracker(self) -> None:
|
||||
"""Reset the steps and delete the temporary local directory."""
|
||||
self._reset()
|
||||
shutil.rmtree(self.temp_dir)
|
Loading…
Reference in New Issue
Block a user