diff --git a/libs/langchain/langchain/callbacks/mlflow_callback.py b/libs/langchain/langchain/callbacks/mlflow_callback.py index 553404b07f..baa297af54 100644 --- a/libs/langchain/langchain/callbacks/mlflow_callback.py +++ b/libs/langchain/langchain/callbacks/mlflow_callback.py @@ -1,3 +1,4 @@ +import os import random import string import tempfile @@ -127,22 +128,27 @@ class MlflowLogger: def __init__(self, **kwargs: Any): self.mlflow = import_mlflow() - tracking_uri = get_from_dict_or_env( - kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", "" - ) - self.mlflow.set_tracking_uri(tracking_uri) + if "DATABRICKS_RUNTIME_VERSION" in os.environ: + self.mlflow.set_tracking_uri("databricks") + self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id() + self.mlf_exp = self.mlflow.get_experiment(self.mlf_expid) + else: + tracking_uri = get_from_dict_or_env( + kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", "" + ) + self.mlflow.set_tracking_uri(tracking_uri) - # User can set other env variables described here - # > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server + # User can set other env variables described here + # > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server - experiment_name = get_from_dict_or_env( - kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME" - ) - self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name) - if self.mlf_exp is not None: - self.mlf_expid = self.mlf_exp.experiment_id - else: - self.mlf_expid = self.mlflow.create_experiment(experiment_name) + experiment_name = get_from_dict_or_env( + kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME" + ) + self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name) + if self.mlf_exp is not None: + self.mlf_expid = self.mlf_exp.experiment_id + else: + self.mlf_expid = self.mlflow.create_experiment(experiment_name) self.start_run(kwargs["run_name"], kwargs["run_tags"])