# torch_example.py
# flake8: noqa
"""
This file holds code for the Torch best-practices guide in the documentation.
It ignores yapf because yapf doesn't allow comments right after code blocks,
but we put comments right after code blocks to prevent large white spaces
in the documentation.
"""
# yapf: disable
if __name__ == "__main__":
    # Temporarily disabled due to an MNIST mirror outage: exit immediately
    # when executed as a script so nothing below runs (it would otherwise
    # download MNIST and start Ray).
    import sys
    sys.exit(0)
  13. # __torch_model_start__
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. class Model(nn.Module):
  18. def __init__(self):
  19. super(Model, self).__init__()
  20. self.conv1 = nn.Conv2d(1, 20, 5, 1)
  21. self.conv2 = nn.Conv2d(20, 50, 5, 1)
  22. self.fc1 = nn.Linear(4 * 4 * 50, 500)
  23. self.fc2 = nn.Linear(500, 10)
  24. def forward(self, x):
  25. x = F.relu(self.conv1(x))
  26. x = F.max_pool2d(x, 2, 2)
  27. x = F.relu(self.conv2(x))
  28. x = F.max_pool2d(x, 2, 2)
  29. x = x.view(-1, 4 * 4 * 50)
  30. x = F.relu(self.fc1(x))
  31. x = self.fc2(x)
  32. return F.log_softmax(x, dim=1)
  33. # __torch_model_end__
  34. # yapf: enable
# yapf: disable
# __torch_helper_start__
# Third-party helpers for data loading: FileLock serializes the MNIST
# download across processes; torchvision provides the dataset + transforms.
from filelock import FileLock
from torchvision import datasets, transforms
  39. def train(model, device, train_loader, optimizer):
  40. model.train()
  41. for batch_idx, (data, target) in enumerate(train_loader):
  42. # This break is for speeding up the tutorial.
  43. if batch_idx * len(data) > 1024:
  44. return
  45. data, target = data.to(device), target.to(device)
  46. optimizer.zero_grad()
  47. output = model(data)
  48. loss = F.nll_loss(output, target)
  49. loss.backward()
  50. optimizer.step()
  51. def test(model, device, test_loader):
  52. model.eval()
  53. test_loss = 0
  54. correct = 0
  55. with torch.no_grad():
  56. for data, target in test_loader:
  57. data, target = data.to(device), target.to(device)
  58. output = model(data)
  59. # sum up batch loss
  60. test_loss += F.nll_loss(
  61. output, target, reduction="sum").item()
  62. pred = output.argmax(
  63. dim=1,
  64. keepdim=True)
  65. correct += pred.eq(target.view_as(pred)).sum().item()
  66. test_loss /= len(test_loader.dataset)
  67. return {
  68. "loss": test_loss,
  69. "accuracy": 100. * correct / len(test_loader.dataset)
  70. }
  71. def dataset_creator(use_cuda):
  72. kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
  73. with FileLock("./data.lock"):
  74. train_loader = torch.utils.data.DataLoader(
  75. datasets.MNIST(
  76. "./data",
  77. train=True,
  78. download=True,
  79. transform=transforms.Compose([
  80. transforms.ToTensor(),
  81. transforms.Normalize((0.1307, ), (0.3081, ))
  82. ])),
  83. 128,
  84. shuffle=True,
  85. **kwargs)
  86. test_loader = torch.utils.data.DataLoader(
  87. datasets.MNIST(
  88. "./data",
  89. train=False,
  90. transform=transforms.Compose([
  91. transforms.ToTensor(),
  92. transforms.Normalize((0.1307, ), (0.3081, ))
  93. ])),
  94. 128,
  95. shuffle=True,
  96. **kwargs)
  97. return train_loader, test_loader
# __torch_helper_end__
# yapf: enable
# yapf: disable
# __torch_net_start__
import torch.optim as optim
  103. class Network(object):
  104. def __init__(self, lr=0.01, momentum=0.5):
  105. use_cuda = torch.cuda.is_available()
  106. self.device = device = torch.device("cuda" if use_cuda else "cpu")
  107. self.train_loader, self.test_loader = dataset_creator(use_cuda)
  108. self.model = Model().to(device)
  109. self.optimizer = optim.SGD(
  110. self.model.parameters(), lr=lr, momentum=momentum)
  111. def train(self):
  112. train(self.model, self.device, self.train_loader, self.optimizer)
  113. return test(self.model, self.device, self.test_loader)
  114. def get_weights(self):
  115. return self.model.state_dict()
  116. def set_weights(self, weights):
  117. self.model.load_state_dict(weights)
  118. def save(self):
  119. torch.save(self.model.state_dict(), "mnist_cnn.pt")
# Instantiate and run one local training pass before the distributed part.
net = Network()
net.train()
# __torch_net_end__
# yapf: enable
# yapf: disable
# __torch_ray_start__
import ray
ray.init()
# Wrap the class so instances of it can be created as Ray actors.
RemoteNetwork = ray.remote(Network)
# Use the below instead of `ray.remote(network)` to leverage the GPU.
# RemoteNetwork = ray.remote(num_gpus=1)(Network)
# __torch_ray_end__
# yapf: enable
# yapf: disable
# __torch_actor_start__
# Create two independent actor instances, each with its own model/data.
NetworkActor = RemoteNetwork.remote()
NetworkActor2 = RemoteNetwork.remote()
# Block until both actors have finished one train() call.
ray.get([NetworkActor.train.remote(), NetworkActor2.train.remote()])
# __torch_actor_end__
# yapf: enable
# yapf: disable
# __weight_average_start__
# Fetch both actors' state dicts, average them elementwise, and push the
# result back to both actors before training again.
weights = ray.get(
    [NetworkActor.get_weights.remote(),
     NetworkActor2.get_weights.remote()])
from collections import OrderedDict
averaged_weights = OrderedDict(
    [(k, (weights[0][k] + weights[1][k]) / 2) for k in weights[0]])
# Put the averaged weights into the object store once so both actors
# receive the same copy by reference.
weight_id = ray.put(averaged_weights)
[
    actor.set_weights.remote(weight_id)
    for actor in [NetworkActor, NetworkActor2]
]
ray.get([actor.train.remote() for actor in [NetworkActor, NetworkActor2]])