Mirror of https://github.com/hwchase17/langchain (synced 2024-11-08 07:10:35 +00:00)

Commit ed58eeb9c5:
Moved the following modules to the new package langchain-community in a backwards compatible fashion:

```
mv langchain/langchain/adapters community/langchain_community
mv langchain/langchain/callbacks community/langchain_community/callbacks
mv langchain/langchain/chat_loaders community/langchain_community
mv langchain/langchain/chat_models community/langchain_community
mv langchain/langchain/document_loaders community/langchain_community
mv langchain/langchain/docstore community/langchain_community
mv langchain/langchain/document_transformers community/langchain_community
mv langchain/langchain/embeddings community/langchain_community
mv langchain/langchain/graphs community/langchain_community
mv langchain/langchain/llms community/langchain_community
mv langchain/langchain/memory/chat_message_histories community/langchain_community
mv langchain/langchain/retrievers community/langchain_community
mv langchain/langchain/storage community/langchain_community
mv langchain/langchain/tools community/langchain_community
mv langchain/langchain/utilities community/langchain_community
mv langchain/langchain/vectorstores community/langchain_community
mv langchain/langchain/agents/agent_toolkits community/langchain_community
mv langchain/langchain/cache.py community/langchain_community
```

Moved the following to core:

```
mv langchain/langchain/utils/json_schema.py core/langchain_core/utils
mv langchain/langchain/utils/html.py core/langchain_core/utils
mv langchain/langchain/utils/strings.py core/langchain_core/utils
cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py
rm langchain/langchain/utils/env.py
```

See .scripts/community_split/script_integrations.sh for all changes
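"Backwards compatible" here means the old `langchain.*` import paths keep resolving by re-exporting from the new package. A minimal sketch of the idea (the actual shims are generated by the script referenced above; the file path and `__all__` below are illustrative):

```python
# langchain/callbacks/sagemaker_callback.py (compatibility shim, sketch)
# Re-export the moved implementation so old imports keep working.
from langchain_community.callbacks.sagemaker_callback import (
    SageMakerCallbackHandler,
)

__all__ = ["SageMakerCallbackHandler"]
```

With a shim like this, existing code such as `from langchain.callbacks import SageMakerCallbackHandler` continues to work, while new code imports from `langchain_community` directly.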
277 lines · 8.6 KiB · Python
import json
import os
import shutil
import tempfile
from copy import deepcopy
from typing import Any, Dict, List, Optional

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

from langchain_community.callbacks.utils import (
    flatten_dict,
)


def save_json(data: dict, file_path: str) -> None:
    """Save dict to local file path.

    Parameters:
        data (dict): The dictionary to be saved.
        file_path (str): Local file path.
    """
    with open(file_path, "w") as outfile:
        json.dump(data, outfile)


class SageMakerCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments.

    Parameters:
        run (sagemaker.experiments.run.Run): Run object where the experiment is logged.
    """

    def __init__(self, run: Any) -> None:
        """Initialize callback handler."""
        super().__init__()

        self.run = run

        self.metrics = {
            "step": 0,
            "starts": 0,
            "ends": 0,
            "errors": 0,
            "text_ctr": 0,
            "chain_starts": 0,
            "chain_ends": 0,
            "llm_starts": 0,
            "llm_ends": 0,
            "llm_streams": 0,
            "tool_starts": 0,
            "tool_ends": 0,
            "agent_ends": 0,
        }

        # Create a temporary directory where JSON artifacts are staged
        # before being uploaded to the run
        self.temp_dir = tempfile.mkdtemp()

    def _reset(self) -> None:
        """Zero out all counters."""
        for k in self.metrics:
            self.metrics[k] = 0

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts."""
        self.metrics["step"] += 1
        self.metrics["llm_starts"] += 1
        self.metrics["starts"] += 1

        llm_starts = self.metrics["llm_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.metrics)

        for idx, prompt in enumerate(prompts):
            prompt_resp = deepcopy(resp)
            prompt_resp["prompt"] = prompt
            self.jsonf(
                prompt_resp,
                self.temp_dir,
                f"llm_start_{llm_starts}_prompt_{idx}",
            )

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.metrics["step"] += 1
        self.metrics["llm_streams"] += 1

        llm_streams = self.metrics["llm_streams"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.metrics["step"] += 1
        self.metrics["llm_ends"] += 1
        self.metrics["ends"] += 1

        llm_ends = self.metrics["llm_ends"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_end"})
        resp.update(flatten_dict(response.llm_output or {}))

        resp.update(self.metrics)

        for generations in response.generations:
            for idx, generation in enumerate(generations):
                generation_resp = deepcopy(resp)
                generation_resp.update(flatten_dict(generation.dict()))

                # Log the per-generation payload, including the generation
                # fields merged in above.
                self.jsonf(
                    generation_resp,
                    self.temp_dir,
                    f"llm_end_{llm_ends}_generation_{idx}",
                )

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.metrics["step"] += 1
        self.metrics["chain_starts"] += 1
        self.metrics["starts"] += 1

        chain_starts = self.metrics["chain_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.metrics)

        chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
        input_resp = deepcopy(resp)
        input_resp["inputs"] = chain_input

        self.jsonf(input_resp, self.temp_dir, f"chain_start_{chain_starts}")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.metrics["step"] += 1
        self.metrics["chain_ends"] += 1
        self.metrics["ends"] += 1

        chain_ends = self.metrics["chain_ends"]

        resp: Dict[str, Any] = {}
        chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
        resp.update({"action": "on_chain_end", "outputs": chain_output})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.metrics["step"] += 1
        self.metrics["tool_starts"] += 1
        self.metrics["starts"] += 1

        tool_starts = self.metrics["tool_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized))
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self.metrics["step"] += 1
        self.metrics["tool_ends"] += 1
        self.metrics["ends"] += 1

        tool_ends = self.metrics["tool_ends"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run when text is received."""
        self.metrics["step"] += 1
        self.metrics["text_ctr"] += 1

        text_ctr = self.metrics["text_ctr"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_text", "text": text})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}")

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.metrics["step"] += 1
        self.metrics["agent_ends"] += 1
        self.metrics["ends"] += 1

        agent_ends = self.metrics["agent_ends"]
        resp: Dict[str, Any] = {}
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values["output"],
                "log": finish.log,
            }
        )
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}")

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.metrics["step"] += 1
        self.metrics["tool_starts"] += 1
        self.metrics["starts"] += 1

        tool_starts = self.metrics["tool_starts"]
        resp: Dict[str, Any] = {}
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.metrics)
        self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}")

    def jsonf(
        self,
        data: Dict[str, Any],
        data_dir: str,
        filename: str,
        is_output: Optional[bool] = True,
    ) -> None:
        """Log the input data as a JSON file artifact on the SageMaker run."""
        file_path = os.path.join(data_dir, f"{filename}.json")
        save_json(data, file_path)
        self.run.log_file(file_path, name=filename, is_output=is_output)

    def flush_tracker(self) -> None:
        """Reset the steps and delete the temporary local directory."""
        self._reset()
        shutil.rmtree(self.temp_dir)
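
# --- Usage sketch (not part of the original module) --------------------------
# Minimal illustration of wiring this handler into a SageMaker Experiments run.
# Assumptions: the `sagemaker` SDK is installed and AWS credentials are
# configured; the experiment/run names and the OpenAI model choice are
# placeholders, not values from this file.
if __name__ == "__main__":
    from langchain_community.llms import OpenAI
    from sagemaker.experiments.run import Run

    with Run(experiment_name="langchain-demo", run_name="run-1") as run:
        handler = SageMakerCallbackHandler(run)
        llm = OpenAI(temperature=0.7, callbacks=[handler])
        llm.invoke("Tell me a joke about data scientists.")
        # Artifacts were already uploaded per callback via run.log_file;
        # this resets the counters and deletes the local temp directory.
        handler.flush_tracker()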