design doc

sequence
justheuristic 2 years ago
parent 0613095168
commit 3b6d94ffef

@@ -33,5 +33,15 @@
# # [IMPORTANT] maybe first create an op for one batch, then a wrapper that splits into batches
#     return torch.cat(output_batches, dim=0)
#
# def backward(ctx, *grad_outputs: torch.Tensor):
#     return TODO_think_through(ctx, *grad_outputs)
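The note above already names the shape: a differentiable op for a single batch, plus a wrapper that splits the input and re-concatenates with torch.cat. Below is a minimal runnable sketch of that shape; run_remote_forward, run_remote_backward, and MAX_BATCH_SIZE are placeholder assumptions standing in for the eventual RPC layer, not part of these notes.

import torch

MAX_BATCH_SIZE = 8  # assumed per-call limit; a real value would come from server metadata


def run_remote_forward(batch: torch.Tensor) -> torch.Tensor:
    # placeholder for the RPC that runs one batch through the remote block sequence
    return batch.clone()


def run_remote_backward(batch: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    # placeholder for the matching backward RPC
    return grad_output.clone()


class _SingleBatchRemoteCall(torch.autograd.Function):
    """The "op for one batch": forward and backward each map to one remote call."""

    @staticmethod
    def forward(ctx, batch: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(batch)
        return run_remote_forward(batch)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (batch,) = ctx.saved_tensors
        return run_remote_backward(batch, grad_output)


def remote_sequential_forward(inputs: torch.Tensor) -> torch.Tensor:
    """The wrapper: split into batches, apply the op per batch, concatenate."""
    input_batches = torch.split(inputs, MAX_BATCH_SIZE, dim=0)
    output_batches = [_SingleBatchRemoteCall.apply(batch) for batch in input_batches]
    return torch.cat(output_batches, dim=0)

With this split, autograd routes each batch's grad_output back through its own remote call, which is one answer to the backward TODO above; remote_sequential_forward(torch.randn(20, 16, requires_grad=True)).sum().backward() exercises the per-batch backward path.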
# class RemoteSequentialInferenceSession:
#     """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
#     # form sequence, maintain past inputs for all stages; if any stage breaks, fix the same way as in forward:
#     # patch = RemoteSequenceManager[spans[BROKEN_INDEX].start : spans[BROKEN_INDEX].end].form_sequence()
#     # feed_past_key_values(into=patch, inputs=past_key_values[BROKEN_INDEX])
#     # self.TODO_replace_in_metadata(remove=BROKEN_INDEX, replace_with=patch)
#     #
#     TODO_think()
#
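A hedged sketch of that recovery path follows. form_sequence, the sequence-manager slicing, and the feed-past-key-values step come from the notes; the Span dataclass, the method names, and everything else are illustrative assumptions.

import dataclasses
from typing import Any, List, Optional


@dataclasses.dataclass
class Span:
    """Assumed metadata for one server's contiguous range of blocks."""
    start: int
    end: int


class RemoteSequentialInferenceSession:
    """Sketch of a multi-step inference session with per-span recovery."""

    def __init__(self, sequence_manager: Any, spans: List[Span]):
        self.sequence_manager = sequence_manager  # stands in for RemoteSequenceManager
        self.spans = spans
        # cached attention keys/values per span, accumulated across inference steps
        self.past_key_values: List[Optional[Any]] = [None] * len(spans)

    def replace_broken_span(self, broken_index: int) -> None:
        """Fix a failed stage the same way as in forward: re-form only that span."""
        old = self.spans[broken_index]
        # re-provision the failed range of blocks (form_sequence per the notes)
        patch = self.sequence_manager[old.start : old.end].form_sequence()
        # replay the cached keys/values so generation can resume mid-sequence
        self._feed_past_key_values(into=patch, inputs=self.past_key_values[broken_index])
        # replace the failed span in the session metadata (the notes' last TODO)
        self.spans[broken_index] = patch

    def _feed_past_key_values(self, into: Any, inputs: Any) -> None:
        """Assumed helper: upload cached attention state to the re-formed span."""
        raise NotImplementedError  # depends on the eventual server-side inference API

Keeping the patch local to the broken span is the design point: healthy stages keep their caches, so only the failed range pays the cost of re-uploading past keys and values.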
