multi_output_model.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from .common import preferred_dtype
  6. class MultiOutputModel(torch.nn.Module):
  7. def __init__(self, hidden_dim, weight_value):
  8. super(MultiOutputModel, self).__init__()
  9. self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
  10. self.linear.weight.data.fill_(weight_value)
  11. self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
  12. def forward(self, inputs, targets):
  13. losses = []
  14. for x, y in zip(inputs, targets):
  15. hidden_dim = self.linear(x)
  16. loss = self.cross_entropy_loss(hidden_dim, y)
  17. losses.append(loss)
  18. return tuple(losses)
  19. def multi_output_dataloader(model, total_samples, hidden_dim, device, inputs, targets):
  20. assert len(inputs) == len(targets)
  21. batch_size = model.train_micro_batch_size_per_gpu()
  22. train_data = [
  23. torch.full(size=(total_samples, hidden_dim),
  24. fill_value=x,
  25. device=device,
  26. dtype=preferred_dtype(),
  27. requires_grad=True) for x in inputs
  28. ]
  29. train_label = [torch.empty(total_samples, device=device, dtype=torch.long).fill_(y) for y in targets]
  30. train_dataset = torch.utils.data.TensorDataset(*train_data, *train_label)
  31. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
  32. return train_loader