|
|
|
@@ -49,13 +49,13 @@ class RemoteSequential(nn.Module):
|
|
|
|
|
|
|
|
|
|
self._active_session = ContextVar("active_session", default=None)
|
|
|
|
|
|
|
|
|
|
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
|
|
|
|
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, *args, **kwargs) -> torch.Tensor:
|
|
|
|
|
assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
|
|
|
|
|
if self.active_session is None:
|
|
|
|
|
assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
|
|
|
|
|
return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
|
|
|
|
|
return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args, **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
return self.active_session.step(inputs, prompts, **kwargs)
|
|
|
|
|
return self.active_session.step(inputs, prompts, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def active_session(self) -> Optional[InferenceSession]:
|
|
|
|
|