diff --git a/langchain/callbacks/comet_ml_callback.py b/langchain/callbacks/comet_ml_callback.py index f057cea84f..7917d26a50 100644 --- a/langchain/callbacks/comet_ml_callback.py +++ b/langchain/callbacks/comet_ml_callback.py @@ -96,7 +96,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self, task_type: Optional[str] = "inference", workspace: Optional[str] = None, - project_name: Optional[str] = "comet-langchain-demo", + project_name: Optional[str] = None, tags: Optional[Sequence] = None, name: Optional[str] = None, visualizations: Optional[List[str]] = None, @@ -106,7 +106,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): ) -> None: """Initialize callback handler.""" - comet_ml = import_comet_ml() + self.comet_ml = import_comet_ml() super().__init__() self.task_type = task_type @@ -133,7 +133,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): "https://github.com/comet-ml/issue_tracking/issues with the tag " "`langchain`." ) - comet_ml.LOGGER.warning(warning) + self.comet_ml.LOGGER.warning(warning) self.callback_columns: list = [] self.action_records: list = [] @@ -242,8 +242,6 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): resp.update(flatten_dict(serialized)) resp.update(self.get_custom_callback_meta()) - comet_ml = import_comet_ml() - for chain_input_key, chain_input_val in inputs.items(): if isinstance(chain_input_val, str): input_resp = deepcopy(resp) @@ -253,7 +251,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.action_records.append(input_resp) else: - comet_ml.LOGGER.warning( + self.comet_ml.LOGGER.warning( f"Unexpected data format provided! " f"Input Value for {chain_input_key} will not be logged" ) @@ -268,8 +266,6 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): resp.update({"action": "on_chain_end"}) resp.update(self.get_custom_callback_meta()) - comet_ml = import_comet_ml() - for chain_output_key, chain_output_val in outputs.items(): if isinstance(chain_output_val, str): output_resp = deepcopy(resp) @@ -278,7 +274,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): output_resp.update({chain_output_key: chain_output_val}) self.action_records.append(output_resp) else: - comet_ml.LOGGER.warning( + self.comet_ml.LOGGER.warning( f"Unexpected data format provided! " f"Output Value for {chain_output_key} will not be logged" ) @@ -449,7 +445,14 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self._log_session(langchain_asset) if langchain_asset: - self._log_model(langchain_asset) + try: + self._log_model(langchain_asset) + except Exception: + self.comet_ml.LOGGER.error( + "Failed to export agent or LLM to Comet", + exc_info=True, + extra={"show_traceback": True}, + ) if finish: self.experiment.end() @@ -470,8 +473,6 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.experiment.log_text(prompt, metadata=metadata, step=step) def _log_model(self, langchain_asset: Any) -> None: - comet_ml = import_comet_ml() - model_parameters = self._get_llm_parameters(langchain_asset) self.experiment.log_parameters(model_parameters, prefix="model") @@ -487,24 +488,45 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): langchain_asset.save_agent(langchain_asset_path) self.experiment.log_model(model_name, str(langchain_asset_path)) else: - comet_ml.LOGGER.warning( + self.comet_ml.LOGGER.error( f"{e}" " Could not save Langchain Asset " f"for {langchain_asset.__class__.__name__}" ) def _log_session(self, langchain_asset: Optional[Any] = None) -> None: - llm_session_df = self._create_session_analysis_dataframe(langchain_asset) - # Log the cleaned dataframe as a table - self.experiment.log_table("langchain-llm-session.csv", llm_session_df) + try: + llm_session_df = self._create_session_analysis_dataframe(langchain_asset) + # Log the cleaned dataframe as a table + self.experiment.log_table("langchain-llm-session.csv", llm_session_df) + except Exception: + self.comet_ml.LOGGER.warning( + "Failed to log session data to Comet", + exc_info=True, + extra={"show_traceback": True}, + ) - metadata = {"langchain_version": str(langchain.__version__)} - # Log the langchain low-level records as a JSON file directly - self.experiment.log_asset_data( - self.action_records, "langchain-action_records.json", metadata=metadata - ) + try: + metadata = {"langchain_version": str(langchain.__version__)} + # Log the langchain low-level records as a JSON file directly + self.experiment.log_asset_data( + self.action_records, "langchain-action_records.json", metadata=metadata + ) + except Exception: + self.comet_ml.LOGGER.warning( + "Failed to log session data to Comet", + exc_info=True, + extra={"show_traceback": True}, + ) - self._log_visualizations(llm_session_df) + try: + self._log_visualizations(llm_session_df) + except Exception: + self.comet_ml.LOGGER.warning( + "Failed to log visualizations to Comet", + exc_info=True, + extra={"show_traceback": True}, + ) def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None: if not metrics: @@ -519,7 +541,6 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): return spacy = import_spacy() - comet_ml = import_comet_ml() prompts = session_df["prompts"].tolist() outputs = session_df["text"].tolist() @@ -544,7 +565,9 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): step=idx, ) except Exception as e: - comet_ml.LOGGER.warning(e) + self.comet_ml.LOGGER.warning( + e, exc_info=True, extra={"show_traceback": True} + ) return