From be29a6287d3d3adfa1b8e8652b51a37e4a2bb0c5 Mon Sep 17 00:00:00 2001 From: Bharat Ramanathan Date: Thu, 13 Jul 2023 00:30:18 +0530 Subject: [PATCH] feat: add model architecture back to wandb tracer (#6806) # Description This PR adds model architecture to the `WandbTracer` from the Serialized Run kwargs. This allows visualization of the calling parameters of an Agent, LLM and Tool in Weights & Biases. 1. Safely serialize the run objects to WBTraceTree model_dict 2. Refactors the run processing logic to be more organized. - Twitter handle: @parambharat --------- Co-authored-by: Bharat Ramanathan Co-authored-by: Bagatur --- langchain/callbacks/tracers/wandb.py | 448 ++++++++++++++++++++------- 1 file changed, 337 insertions(+), 111 deletions(-) diff --git a/langchain/callbacks/tracers/wandb.py b/langchain/callbacks/tracers/wandb.py index 8f744e28b2..49d810ae97 100644 --- a/langchain/callbacks/tracers/wandb.py +++ b/langchain/callbacks/tracers/wandb.py @@ -1,6 +1,7 @@ """A Tracer Implementation that records activity to Weights & Biases.""" from __future__ import annotations +import json from typing import ( TYPE_CHECKING, Any, @@ -8,6 +9,7 @@ from typing import ( List, Optional, Sequence, + Tuple, TypedDict, Union, ) @@ -17,7 +19,7 @@ from langchain.callbacks.tracers.schemas import Run, RunTypeEnum if TYPE_CHECKING: from wandb import Settings as WBSettings - from wandb.sdk.data_types import trace_tree + from wandb.sdk.data_types.trace_tree import Span from wandb.sdk.lib.paths import StrPath from wandb.wandb_run import Run as WBRun @@ -25,115 +27,350 @@ if TYPE_CHECKING: PRINT_WARNINGS = True -def _convert_lc_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span: - if run.run_type == RunTypeEnum.llm: - return _convert_llm_run_to_wb_span(trace_tree, run) - elif run.run_type == RunTypeEnum.chain: - return _convert_chain_run_to_wb_span(trace_tree, run) - elif run.run_type == RunTypeEnum.tool: - return _convert_tool_run_to_wb_span(trace_tree, run) - else: - return _convert_run_to_wb_span(trace_tree, run) - - -def _convert_llm_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span: - base_span = _convert_run_to_wb_span(trace_tree, run) - - base_span.results = [ - trace_tree.Result( - inputs={"prompt": prompt}, - outputs={ - f"gen_{g_i}": gen["text"] - for g_i, gen in enumerate(run.outputs["generations"][ndx]) - } - if ( - run.outputs is not None - and len(run.outputs["generations"]) > ndx - and len(run.outputs["generations"][ndx]) > 0 - ) - else None, - ) - for ndx, prompt in enumerate(run.inputs["prompts"] or []) - ] - base_span.span_kind = trace_tree.SpanKind.LLM - - return base_span - - -def _serialize_inputs(run_inputs: dict) -> Union[dict, list]: +def _serialize_inputs(run_inputs: dict) -> dict: if "input_documents" in run_inputs: docs = run_inputs["input_documents"] - return [doc.json() for doc in docs] + return {f"input_document_{i}": doc.json() for i, doc in enumerate(docs)} else: return run_inputs -def _convert_chain_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span: - base_span = _convert_run_to_wb_span(trace_tree, run) +class RunProcessor: + """Handles the conversion of a LangChain Runs into a WBTraceTree.""" - base_span.results = [ - trace_tree.Result(inputs=_serialize_inputs(run.inputs), outputs=run.outputs) - ] - base_span.child_spans = [ - _convert_lc_run_to_wb_span(trace_tree, child_run) - for child_run in run.child_runs - ] - base_span.span_kind = ( - trace_tree.SpanKind.AGENT - if "agent" in run.serialized.get("name", "").lower() - else trace_tree.SpanKind.CHAIN - ) + def __init__(self, wandb_module: Any, trace_module: Any): + self.wandb = wandb_module + self.trace_tree = trace_module - return base_span + def process_span(self, run: Run) -> Optional["Span"]: + """Converts a LangChain Run into a W&B Trace Span. + :param run: The LangChain Run to convert. + :return: The converted W&B Trace Span. + """ + try: + span = self._convert_lc_run_to_wb_span(run) + return span + except Exception as e: + if PRINT_WARNINGS: + self.wandb.termwarn( + f"Skipping trace saving - unable to safely convert LangChain Run " + f"into W&B Trace due to: {e}" + ) + return None + def _convert_run_to_wb_span(self, run: Run) -> "Span": + """Base utility to create a span from a run. + :param run: The run to convert. + :return: The converted Span. + """ + attributes = {**run.extra} if run.extra else {} + attributes["execution_order"] = run.execution_order -def _convert_tool_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span: - base_span = _convert_run_to_wb_span(trace_tree, run) - base_span.results = [ - trace_tree.Result(inputs=_serialize_inputs(run.inputs), outputs=run.outputs) - ] - base_span.child_spans = [ - _convert_lc_run_to_wb_span(trace_tree, child_run) - for child_run in run.child_runs - ] - base_span.span_kind = trace_tree.SpanKind.TOOL + return self.trace_tree.Span( + span_id=str(run.id) if run.id is not None else None, + name=run.name, + start_time_ms=int(run.start_time.timestamp() * 1000), + end_time_ms=int(run.end_time.timestamp() * 1000), + status_code=self.trace_tree.StatusCode.SUCCESS + if run.error is None + else self.trace_tree.StatusCode.ERROR, + status_message=run.error, + attributes=attributes, + ) - return base_span + def _convert_llm_run_to_wb_span(self, run: Run) -> "Span": + """Converts a LangChain LLM Run into a W&B Trace Span. + :param run: The LangChain LLM Run to convert. + :return: The converted W&B Trace Span. + """ + base_span = self._convert_run_to_wb_span(run) + if base_span.attributes is None: + base_span.attributes = {} + base_span.attributes["llm_output"] = run.outputs.get("llm_output", {}) + base_span.results = [ + self.trace_tree.Result( + inputs={"prompt": prompt}, + outputs={ + f"gen_{g_i}": gen["text"] + for g_i, gen in enumerate(run.outputs["generations"][ndx]) + } + if ( + run.outputs is not None + and len(run.outputs["generations"]) > ndx + and len(run.outputs["generations"][ndx]) > 0 + ) + else None, + ) + for ndx, prompt in enumerate(run.inputs["prompts"] or []) + ] + base_span.span_kind = self.trace_tree.SpanKind.LLM -def _convert_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span: - attributes = {**run.extra} if run.extra else {} - attributes["execution_order"] = run.execution_order + return base_span - return trace_tree.Span( - span_id=str(run.id) if run.id is not None else None, - name=run.serialized.get("name"), - start_time_ms=int(run.start_time.timestamp() * 1000), - end_time_ms=int(run.end_time.timestamp() * 1000), - status_code=trace_tree.StatusCode.SUCCESS - if run.error is None - else trace_tree.StatusCode.ERROR, - status_message=run.error, - attributes=attributes, - ) + def _convert_chain_run_to_wb_span(self, run: Run) -> "Span": + """Converts a LangChain Chain Run into a W&B Trace Span. + :param run: The LangChain Chain Run to convert. + :return: The converted W&B Trace Span. + """ + base_span = self._convert_run_to_wb_span(run) + base_span.results = [ + self.trace_tree.Result( + inputs=_serialize_inputs(run.inputs), outputs=run.outputs + ) + ] + base_span.child_spans = [ + self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs + ] + base_span.span_kind = ( + self.trace_tree.SpanKind.AGENT + if "agent" in run.name.lower() + else self.trace_tree.SpanKind.CHAIN + ) -def _replace_type_with_kind(data: Any) -> Any: - if isinstance(data, dict): - # W&B TraceTree expects "_kind" instead of "_type" since `_type` is special - # in W&B. - if "_type" in data: - _type = data.pop("_type") - data["_kind"] = _type - return {k: _replace_type_with_kind(v) for k, v in data.items()} - elif isinstance(data, list): - return [_replace_type_with_kind(v) for v in data] - elif isinstance(data, tuple): - return tuple(_replace_type_with_kind(v) for v in data) - elif isinstance(data, set): - return {_replace_type_with_kind(v) for v in data} - else: - return data + return base_span + + def _convert_tool_run_to_wb_span(self, run: Run) -> "Span": + """Converts a LangChain Tool Run into a W&B Trace Span. + :param run: The LangChain Tool Run to convert. + :return: The converted W&B Trace Span. + """ + base_span = self._convert_run_to_wb_span(run) + base_span.results = [ + self.trace_tree.Result( + inputs=_serialize_inputs(run.inputs), outputs=run.outputs + ) + ] + base_span.child_spans = [ + self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs + ] + base_span.span_kind = self.trace_tree.SpanKind.TOOL + + return base_span + + def _convert_lc_run_to_wb_span(self, run: Run) -> "Span": + """Utility to convert any generic LangChain Run into a W&B Trace Span. + :param run: The LangChain Run to convert. + :return: The converted W&B Trace Span. + """ + if run.run_type == RunTypeEnum.llm: + return self._convert_llm_run_to_wb_span(run) + elif run.run_type == RunTypeEnum.chain: + return self._convert_chain_run_to_wb_span(run) + elif run.run_type == RunTypeEnum.tool: + return self._convert_tool_run_to_wb_span(run) + else: + return self._convert_run_to_wb_span(run) + + def process_model(self, run: Run) -> Optional[Dict[str, Any]]: + """Utility to process a run for wandb model_dict serialization. + :param run: The run to process. + :return: The convert model_dict to pass to WBTraceTree. + """ + try: + data = json.loads(run.json()) + processed = self.flatten_run(data) + keep_keys = ( + "id", + "name", + "serialized", + "inputs", + "outputs", + "parent_run_id", + "execution_order", + ) + processed = self.truncate_run_iterative(processed, keep_keys=keep_keys) + exact_keys, partial_keys = ("lc", "type"), ("api_key",) + processed = self.modify_serialized_iterative( + processed, exact_keys=exact_keys, partial_keys=partial_keys + ) + output = self.build_tree(processed) + return output + except Exception as e: + if PRINT_WARNINGS: + self.wandb.termwarn(f"WARNING: Failed to serialize model: {e}") + return None + + def flatten_run(self, run: Dict[str, Any]) -> List[Dict[str, Any]]: + """Utility to flatten a nest run object into a list of runs. + :param run: The base run to flatten. + :return: The flattened list of runs. + """ + + def flatten(child_runs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Utility to recursively flatten a list of child runs in a run. + :param child_runs: The list of child runs to flatten. + :return: The flattened list of runs. + """ + if child_runs is None: + return [] + + result = [] + for item in child_runs: + child_runs = item.pop("child_runs", []) + result.append(item) + result.extend(flatten(child_runs)) + + return result + + return flatten([run]) + + def truncate_run_iterative( + self, runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = () + ) -> List[Dict[str, Any]]: + """Utility to truncate a list of runs dictionaries to only keep the specified + keys in each run. + :param runs: The list of runs to truncate. + :param keep_keys: The keys to keep in each run. + :return: The truncated list of runs. + """ + + def truncate_single(run: Dict[str, Any]) -> Dict[str, Any]: + """Utility to truncate a single run dictionary to only keep the specified + keys. + :param run: The run dictionary to truncate. + :return: The truncated run dictionary + """ + new_dict = {} + for key in run: + if key in keep_keys: + new_dict[key] = run.get(key) + return new_dict + + return list(map(truncate_single, runs)) + + def modify_serialized_iterative( + self, + runs: List[Dict[str, Any]], + exact_keys: Tuple[str, ...] = (), + partial_keys: Tuple[str, ...] = (), + ) -> List[Dict[str, Any]]: + """Utility to modify the serialized field of a list of runs dictionaries. + removes any keys that match the exact_keys and any keys that contain any of the + partial_keys. + recursively moves the dictionaries under the kwargs key to the top level. + changes the "id" field to a string "_kind" field that tells WBTraceTree how to + visualize the run. promotes the "serialized" field to the top level. + + :param runs: The list of runs to modify. + :param exact_keys: A tuple of keys to remove from the serialized field. + :param partial_keys: A tuple of partial keys to remove from the serialized + field. + :return: The modified list of runs. + """ + + def remove_exact_and_partial_keys(obj: Dict[str, Any]) -> Dict[str, Any]: + """Recursively removes exact and partial keys from a dictionary. + :param obj: The dictionary to remove keys from. + :return: The modified dictionary. + """ + if isinstance(obj, dict): + obj = { + k: v + for k, v in obj.items() + if k not in exact_keys + and not any(partial in k for partial in partial_keys) + } + for k, v in obj.items(): + obj[k] = remove_exact_and_partial_keys(v) + elif isinstance(obj, list): + obj = [remove_exact_and_partial_keys(x) for x in obj] + return obj + + def handle_id_and_kwargs( + obj: Dict[str, Any], root: bool = False + ) -> Dict[str, Any]: + """Recursively handles the id and kwargs fields of a dictionary. + changes the id field to a string "_kind" field that tells WBTraceTree how + to visualize the run. recursively moves the dictionaries under the kwargs + key to the top level. + :param obj: a run dictionary with id and kwargs fields. + :param root: whether this is the root dictionary or the serialized + dictionary. + :return: The modified dictionary. + """ + if isinstance(obj, dict): + if ("id" in obj or "name" in obj) and not root: + _kind = obj.get("id") + if not _kind: + _kind = [obj.get("name")] + obj["_kind"] = _kind[-1] + obj.pop("id", None) + obj.pop("name", None) + if "kwargs" in obj: + kwargs = obj.pop("kwargs") + for k, v in kwargs.items(): + obj[k] = v + for k, v in obj.items(): + obj[k] = handle_id_and_kwargs(v) + elif isinstance(obj, list): + obj = [handle_id_and_kwargs(x) for x in obj] + return obj + + def transform_serialized(serialized: Dict[str, Any]) -> Dict[str, Any]: + """Transforms the serialized field of a run dictionary to be compatible + with WBTraceTree. + :param serialized: The serialized field of a run dictionary. + :return: The transformed serialized field. + """ + serialized = handle_id_and_kwargs(serialized, root=True) + serialized = remove_exact_and_partial_keys(serialized) + return serialized + + def transform_run(run: Dict[str, Any]) -> Dict[str, Any]: + """Transforms a run dictionary to be compatible with WBTraceTree. + :param run: The run dictionary to transform. + :return: The transformed run dictionary. + """ + transformed_dict = transform_serialized(run) + + serialized = transformed_dict.pop("serialized") + for k, v in serialized.items(): + transformed_dict[k] = v + + _kind = transformed_dict.get("_kind", None) + name = transformed_dict.pop("name", None) + exec_ord = transformed_dict.pop("execution_order", None) + + if not name: + name = _kind + + output_dict = { + f"{exec_ord}_{name}": transformed_dict, + } + return output_dict + + return list(map(transform_run, runs)) + + def build_tree(self, runs: List[Dict[str, Any]]) -> Dict[str, Any]: + """Builds a nested dictionary from a list of runs. + :param runs: The list of runs to build the tree from. + :return: The nested dictionary representing the langchain Run in a tree + structure compatible with WBTraceTree. + """ + id_to_data = {} + child_to_parent = {} + + for entity in runs: + for key, data in entity.items(): + id_val = data.pop("id", None) + parent_run_id = data.pop("parent_run_id", None) + id_to_data[id_val] = {key: data} + if parent_run_id: + child_to_parent[id_val] = parent_run_id + + for child_id, parent_id in child_to_parent.items(): + parent_dict = id_to_data[parent_id] + parent_dict[next(iter(parent_dict))][ + next(iter(id_to_data[child_id])) + ] = id_to_data[child_id][next(iter(id_to_data[child_id]))] + + root_dict = next( + data for id_val, data in id_to_data.items() if id_val not in child_to_parent + ) + + return root_dict class WandbRunArgs(TypedDict): @@ -201,12 +438,13 @@ class WandbTracer(BaseTracer): except ImportError as e: raise ImportError( "Could not import wandb python package." - "Please install it with `pip install wandb`." + "Please install it with `pip install -U wandb`." ) from e self._wandb = wandb self._trace_tree = trace_tree self._run_args = run_args self._ensure_run(should_print_url=(wandb.run is None)) + self.run_processor = RunProcessor(self._wandb, self._trace_tree) def finish(self) -> None: """Waits for all asynchronous processes to finish and data to upload. @@ -219,24 +457,12 @@ class WandbTracer(BaseTracer): """Logs a LangChain Run to W*B as a W&B Trace.""" self._ensure_run() - try: - root_span = _convert_lc_run_to_wb_span(self._trace_tree, run) - except Exception as e: - if PRINT_WARNINGS: - self._wandb.termwarn( - f"Skipping trace saving - unable to safely convert LangChain Run " - f"into W&B Trace due to: {e}" - ) + root_span = self.run_processor.process_span(run) + model_dict = self.run_processor.process_model(run) + + if root_span is None: return - model_dict = None - - # TODO: Add something like this once we have a way to get the clean serialized - # parent dict from a run: - # serialized_parent = safely_get_span_producing_model(run) - # if serialized_parent is not None: - # model_dict = safely_convert_model_to_dict(serialized_parent) - model_trace = self._trace_tree.WBTraceTree( root_span=root_span, model_dict=model_dict,