Added Databricks support to MLflow Callback (#7906)

Added a quick check to make integration easier with Databricks; another
option would be to make a new class, but this seemed more
straightfoward.

cc: @liangz1 Can this be done in a more straightfoward way?
pull/4403/head^2
Rithwik Ediga Lakhamsani 1 year ago committed by GitHub
parent 479cc086ba
commit d1d691caa4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,3 +1,4 @@
import os
import random import random
import string import string
import tempfile import tempfile
@ -127,22 +128,27 @@ class MlflowLogger:
def __init__(self, **kwargs: Any): def __init__(self, **kwargs: Any):
self.mlflow = import_mlflow() self.mlflow = import_mlflow()
tracking_uri = get_from_dict_or_env( if "DATABRICKS_RUNTIME_VERSION" in os.environ:
kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", "" self.mlflow.set_tracking_uri("databricks")
) self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
self.mlflow.set_tracking_uri(tracking_uri) self.mlf_exp = self.mlflow.get_experiment(self.mlf_expid)
else:
tracking_uri = get_from_dict_or_env(
kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", ""
)
self.mlflow.set_tracking_uri(tracking_uri)
# User can set other env variables described here # User can set other env variables described here
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server # > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
experiment_name = get_from_dict_or_env( experiment_name = get_from_dict_or_env(
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME" kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
) )
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name) self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
if self.mlf_exp is not None: if self.mlf_exp is not None:
self.mlf_expid = self.mlf_exp.experiment_id self.mlf_expid = self.mlf_exp.experiment_id
else: else:
self.mlf_expid = self.mlflow.create_experiment(experiment_name) self.mlf_expid = self.mlflow.create_experiment(experiment_name)
self.start_run(kwargs["run_name"], kwargs["run_tags"]) self.start_run(kwargs["run_name"], kwargs["run_tags"])

Loading…
Cancel
Save