community[patch]: fix CometTracer bug (#20796)

Hi! My name is Alex, I'm an SDK engineer from
[Comet](https://www.comet.com/site/)

This PR updates the `CometTracer` class.

Fixed an issue when `CometTracer` failed while logging the data to Comet
because this data is not JSON-encodable.

The problem was in some of the `Run` attributes that could contain
non-default types inside, now these attributes are taken not from the
run instance, but from the `run.dict()` return value.
pull/20801/head
Aliaksandr Kuzmik 2 months ago committed by GitHub
parent 1c89e45c14
commit 5560cc448c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -70,24 +70,26 @@ class CometTracer(BaseTracer):
self._flush: Callable[[], None] = comet_llm_api.flush self._flush: Callable[[], None] = comet_llm_api.flush
def _persist_run(self, run: "Run") -> None: def _persist_run(self, run: "Run") -> None:
run_dict: Dict[str, Any] = run.dict()
chain_ = self._chains_map[run.id] chain_ = self._chains_map[run.id]
chain_.set_outputs(outputs=run.outputs) chain_.set_outputs(outputs=run_dict["outputs"])
self._chain_api.log_chain(chain_) self._chain_api.log_chain(chain_)
def _process_start_trace(self, run: "Run") -> None: def _process_start_trace(self, run: "Run") -> None:
run_dict: Dict[str, Any] = run.dict()
if not run.parent_run_id: if not run.parent_run_id:
# This is the first run, which maps to a chain # This is the first run, which maps to a chain
chain_: "Chain" = self._chain.Chain( chain_: "Chain" = self._chain.Chain(
inputs=run.inputs, inputs=run_dict["inputs"],
metadata=None, metadata=None,
experiment_info=self._experiment_info.get(), experiment_info=self._experiment_info.get(),
) )
self._chains_map[run.id] = chain_ self._chains_map[run.id] = chain_
else: else:
span: "Span" = self._span.Span( span: "Span" = self._span.Span(
inputs=run.inputs, inputs=run_dict["inputs"],
category=_get_run_type(run), category=_get_run_type(run),
metadata=run.extra, metadata=run_dict["extra"],
name=run.name, name=run.name,
) )
span.__api__start__(self._chains_map[run.parent_run_id]) span.__api__start__(self._chains_map[run.parent_run_id])
@ -95,12 +97,13 @@ class CometTracer(BaseTracer):
self._span_map[run.id] = span self._span_map[run.id] = span
def _process_end_trace(self, run: "Run") -> None: def _process_end_trace(self, run: "Run") -> None:
run_dict: Dict[str, Any] = run.dict()
if not run.parent_run_id: if not run.parent_run_id:
pass pass
# Langchain will call _persist_run for us # Langchain will call _persist_run for us
else: else:
span = self._span_map[run.id] span = self._span_map[run.id]
span.set_outputs(outputs=run.outputs) span.set_outputs(outputs=run_dict["outputs"])
span.__api__end__() span.__api__end__()
def flush(self) -> None: def flush(self) -> None:

Loading…
Cancel
Save