diff --git a/langchain/callbacks/tracers/wandb.py b/langchain/callbacks/tracers/wandb.py index 98adc0d0..4a6bdc89 100644 --- a/langchain/callbacks/tracers/wandb.py +++ b/langchain/callbacks/tracers/wandb.py @@ -60,10 +60,20 @@ def _convert_llm_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span: return base_span +def _serialize_inputs(run_inputs: dict) -> Union[dict, list]: + if "input_documents" in run_inputs: + docs = run_inputs["input_documents"] + return [doc.json() for doc in docs] + else: + return run_inputs + + def _convert_chain_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span: base_span = _convert_run_to_wb_span(trace_tree, run) - base_span.results = [trace_tree.Result(inputs=run.inputs, outputs=run.outputs)] + base_span.results = [ + trace_tree.Result(inputs=_serialize_inputs(run.inputs), outputs=run.outputs) + ] base_span.child_spans = [ _convert_lc_run_to_wb_span(trace_tree, child_run) for child_run in run.child_runs @@ -79,7 +89,9 @@ def _convert_chain_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span: def _convert_tool_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span: base_span = _convert_run_to_wb_span(trace_tree, run) - base_span.results = [trace_tree.Result(inputs=run.inputs, outputs=run.outputs)] + base_span.results = [ + trace_tree.Result(inputs=_serialize_inputs(run.inputs), outputs=run.outputs) + ] base_span.child_spans = [ _convert_lc_run_to_wb_span(trace_tree, child_run) for child_run in run.child_runs