pass args/kwargs via forward

pull/467/head
Your Name 9 months ago
parent 62e780c054
commit aacd8b2f9d

@ -49,13 +49,13 @@ class RemoteSequential(nn.Module):
self._active_session = ContextVar("active_session", default=None)
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, *args, **kwargs) -> torch.Tensor:
assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
if self.active_session is None:
assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args, **kwargs)
else:
return self.active_session.step(inputs, prompts, **kwargs)
return self.active_session.step(inputs, prompts, *args, **kwargs)
@property
def active_session(self) -> Optional[InferenceSession]:

Loading…
Cancel
Save