multi_output_model.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
  4. import torch
  5. class MultiOutputModel(torch.nn.Module):
  6. def __init__(self, hidden_dim, weight_value):
  7. super(MultiOutputModel, self).__init__()
  8. self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
  9. self.linear.weight.data.fill_(weight_value)
  10. self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
  11. def forward(self, inputs, targets):
  12. losses = []
  13. for x, y in zip(inputs, targets):
  14. hidden_dim = self.linear(x)
  15. loss = self.cross_entropy_loss(hidden_dim, y)
  16. losses.append(loss)
  17. return tuple(losses)
  18. def multi_output_dataloader(model, total_samples, hidden_dim, device, inputs, targets):
  19. assert len(inputs) == len(targets)
  20. batch_size = model.train_micro_batch_size_per_gpu()
  21. train_data = [
  22. torch.full(size=(total_samples, hidden_dim), fill_value=x, device=device, dtype=torch.half, requires_grad=True)
  23. for x in inputs
  24. ]
  25. train_label = [torch.empty(total_samples, device=device, dtype=torch.long).fill_(y) for y in targets]
  26. train_dataset = torch.utils.data.TensorDataset(*train_data, *train_label)
  27. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
  28. return train_loader