You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
77 lines
2.9 KiB
Python
77 lines
2.9 KiB
Python
import torch
|
|
from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten
|
|
|
|
|
|
def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3):
    """Capture ``callable`` into a CUDA graph and return a replay wrapper (inference only).

    Similar to ``torch.cuda.make_graphed_callables``, but takes only one function and
    does not build a graph for the backward pass.

    Args:
        callable: A plain function (NOT an ``nn.Module``) to capture. Note: the
            parameter name shadows the ``callable`` builtin; kept for backward
            compatibility with keyword callers.
        sample_args: Positional args used for warmup and capture. May be a nested
            pytree; every leaf must be a ``torch.Tensor``. The returned wrapper must
            be called with args of the same structure, shapes and dtypes.
        num_warmup_iters: Number of eager warmup iterations run on a side stream
            before capture, to keep lazy CUDA init (e.g. cudnn benchmarking) out of
            the captured graph.

    Returns:
        A function with the same signature as ``callable`` that copies its inputs
        into the captured static buffers, replays the graph, and returns detached
        outputs unflattened to the original output structure.

    Raises:
        RuntimeError: If autocast is enabled with its cache on (incompatible with
            graph capture).
    """
    assert not isinstance(callable, torch.nn.Module)
    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        raise RuntimeError(
            "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
        )

    flatten_arg, _ = _tree_flatten(sample_args)
    flatten_sample_args = tuple(flatten_arg)
    assert all(
        isinstance(arg, torch.Tensor) for arg in flatten_arg
    ), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed."

    # BUGFIX: count the *flattened* leaves, not the top-level args. replay_graph
    # iterates range(len_user_args) over the flattened input surface; with nested
    # sample_args the old `len(sample_args)` undercounted and trailing tensors were
    # never copied into the static capture buffers (silently stale inputs).
    len_user_args = len(flatten_sample_args)
    static_input_surface = flatten_sample_args

    graph = torch.cuda.CUDAGraph()

    # Warmup on a side stream.
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
    # from ending up in any captures.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(num_warmup_iters):
            outputs, _ = _tree_flatten(callable(*sample_args))
            del outputs
    torch.cuda.current_stream().wait_stream(s)

    # Capture forward graph. `sample_args` tensors become the static input surface;
    # the outputs produced here become the static output buffers returned on replay.
    with torch.cuda.graph(graph):
        outputs = callable(*sample_args)

    flatten_outputs, output_unflatten_spec = _tree_flatten(outputs)
    static_outputs = tuple(flatten_outputs)

    def make_graphed_function(
        graph,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
    ):
        # Closure factory: binds the capture-time state so the returned wrapper
        # carries no reference to this outer function's other locals.
        def replay_graph(*inputs):
            # At this stage, only the user args may (potentially) be new tensors.
            # Copy any input whose storage differs from the captured buffer; if the
            # caller reuses the original tensors, the copy is skipped.
            for i in range(len_user_args):
                if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                    static_input_surface[i].copy_(inputs[i])
            graph.replay()
            assert isinstance(static_outputs, tuple)
            # Detach: replayed outputs are static buffers, not autograd-tracked.
            return tuple(o.detach() for o in static_outputs)

        def functionalized(*user_args):
            # Runs the graph with inputs == all inputs to the graph
            # (explicit user args; there are no module parameters here).
            # Assumes input structure/shapes didn't change since capture.
            flatten_user_args, _ = _tree_flatten(user_args)
            out = replay_graph(*flatten_user_args)
            return _tree_unflatten(out, output_unflatten_spec)

        return functionalized

    # Put together the final graphed callable
    graphed = make_graphed_function(
        graph,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
    )
    return graphed
|