@ -179,6 +179,9 @@ class ImageLoggingContext:
self.summary_context.stop()
self.timing_contexts["total"] = self.summary_context
# move total to the end
self.timing_contexts["total"] = self.timing_contexts.pop("total")
if torch.cuda.is_available():
self.summary_context.memory_peak = max(
max(context.memory_peak, context.memory_start, context.memory_end)