Fix serialization issue with W&B (#5693)

The chain input_documents are not displaying properly in W&B, due to
serialization issue:

<img width="1164" alt="Screenshot 2023-06-04 at 11 58 26 AM"
src="https://github.com/hwchase17/langchain/assets/134809928/f31f14f6-0935-4cca-9913-6760cd40eadf">

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
jjzhuo 2023-06-07 20:44:59 -07:00 committed by GitHub
parent ec0dd6e34a
commit 78aa59c68b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -60,10 +60,20 @@ def _convert_llm_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
return base_span return base_span
def _serialize_inputs(run_inputs: dict) -> Union[dict, list]:
if "input_documents" in run_inputs:
docs = run_inputs["input_documents"]
return [doc.json() for doc in docs]
else:
return run_inputs
def _convert_chain_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span: def _convert_chain_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
base_span = _convert_run_to_wb_span(trace_tree, run) base_span = _convert_run_to_wb_span(trace_tree, run)
base_span.results = [trace_tree.Result(inputs=run.inputs, outputs=run.outputs)] base_span.results = [
trace_tree.Result(inputs=_serialize_inputs(run.inputs), outputs=run.outputs)
]
base_span.child_spans = [ base_span.child_spans = [
_convert_lc_run_to_wb_span(trace_tree, child_run) _convert_lc_run_to_wb_span(trace_tree, child_run)
for child_run in run.child_runs for child_run in run.child_runs
@ -79,7 +89,9 @@ def _convert_chain_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
def _convert_tool_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span: def _convert_tool_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
base_span = _convert_run_to_wb_span(trace_tree, run) base_span = _convert_run_to_wb_span(trace_tree, run)
base_span.results = [trace_tree.Result(inputs=run.inputs, outputs=run.outputs)] base_span.results = [
trace_tree.Result(inputs=_serialize_inputs(run.inputs), outputs=run.outputs)
]
base_span.child_spans = [ base_span.child_spans = [
_convert_lc_run_to_wb_span(trace_tree, child_run) _convert_lc_run_to_wb_span(trace_tree, child_run)
for child_run in run.child_runs for child_run in run.child_runs