# multi_output_model.py
  1. import os
  2. import json
  3. import argparse
  4. import torch
  5. class MultiOutputModel(torch.nn.Module):
  6. def __init__(self, hidden_dim, weight_value):
  7. super(MultiOutputModel, self).__init__()
  8. self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
  9. self.linear.weight.data.fill_(weight_value)
  10. self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
  11. def forward(self, inputs, targets):
  12. losses = []
  13. for x, y in zip(inputs, targets):
  14. hidden_dim = self.linear(x)
  15. loss = self.cross_entropy_loss(hidden_dim, y)
  16. losses.append(loss)
  17. return tuple(losses)
  18. def multi_output_dataloader(model, total_samples, hidden_dim, device, inputs, targets):
  19. assert len(inputs) == len(targets)
  20. batch_size = model.train_micro_batch_size_per_gpu()
  21. train_data = [
  22. torch.full(size=(total_samples,
  23. hidden_dim),
  24. fill_value=x,
  25. device=device,
  26. dtype=torch.half,
  27. requires_grad=True) for x in inputs
  28. ]
  29. train_label = [
  30. torch.empty(total_samples,
  31. device=device,
  32. dtype=torch.long).fill_(y) for y in targets
  33. ]
  34. train_dataset = torch.utils.data.TensorDataset(*train_data, *train_label)
  35. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
  36. return train_loader