|
|
|
@ -33,5 +33,15 @@
|
|
|
|
|
# # [IMPORTANT] maybe first create an op for one batch, then a wrapper that split into batches
|
|
|
|
|
# return torch.cat(output_batches, dim=0)
|
|
|
|
|
#
|
|
|
|
|
# def backward(ctx, grad_outputs):
|
|
|
|
|
# return TODO(ctx, )
|
|
|
|
|
# def backward(ctx, *grad_outputs: torch.Tensor):
|
|
|
|
|
# return TODO_think_through(ctx, *grad_outputs)
|
|
|
|
|
|
|
|
|
|
# class RemoteSequentialInferenceSession:
|
|
|
|
|
# """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
|
|
|
|
|
# # form sequence, maintain past inputs for all stages; if any stage breaks, fix the same way as in forward:
|
|
|
|
|
# # patch = RemoteSequenceManager[spans[BROKEN_INDEX].start : spans[BROKEN_INDEX].end].form_sequence()
|
|
|
|
|
# # feed_past_key_values(into=patch, inputs=past_key_values[BROKEN_INDEX])
|
|
|
|
|
# # self.TODO_replace_in_metadata(remove=BROKEN_INDEX, replace_with=patch)
|
|
|
|
|
# #
|
|
|
|
|
# TODO_think()
|
|
|
|
|
#
|
|
|
|
|