@ -1,3 +1,4 @@
import os
import random
import string
import tempfile
@ -127,22 +128,27 @@ class MlflowLogger:
def __init__ ( self , * * kwargs : Any ) :
self . mlflow = import_mlflow ( )
tracking_uri = get_from_dict_or_env (
kwargs , " tracking_uri " , " MLFLOW_TRACKING_URI " , " "
)
self . mlflow . set_tracking_uri ( tracking_uri )
if " DATABRICKS_RUNTIME_VERSION " in os . environ :
self . mlflow . set_tracking_uri ( " databricks " )
self . mlf_expid = self . mlflow . tracking . fluent . _get_experiment_id ( )
self . mlf_exp = self . mlflow . get_experiment ( self . mlf_expid )
else :
tracking_uri = get_from_dict_or_env (
kwargs , " tracking_uri " , " MLFLOW_TRACKING_URI " , " "
)
self . mlflow . set_tracking_uri ( tracking_uri )
# User can set other env variables described here
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
# User can set other env variables described here
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
experiment_name = get_from_dict_or_env (
kwargs , " experiment_name " , " MLFLOW_EXPERIMENT_NAME "
)
self . mlf_exp = self . mlflow . get_experiment_by_name ( experiment_name )
if self . mlf_exp is not None :
self . mlf_expid = self . mlf_exp . experiment_id
else :
self . mlf_expid = self . mlflow . create_experiment ( experiment_name )
experiment_name = get_from_dict_or_env (
kwargs , " experiment_name " , " MLFLOW_EXPERIMENT_NAME "
)
self . mlf_exp = self . mlflow . get_experiment_by_name ( experiment_name )
if self . mlf_exp is not None :
self . mlf_expid = self . mlf_exp . experiment_id
else :
self . mlf_expid = self . mlflow . create_experiment ( experiment_name )
self . start_run ( kwargs [ " run_name " ] , kwargs [ " run_tags " ] )