horovod_example.py

# This file is duplicated in ray/tests/horovod
import argparse
import os

from filelock import FileLock

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch.utils.data.distributed

import horovod.torch as hvd
from horovod.ray import RayExecutor

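# Helper that averages a scalar metric across all Horovod workers via
# allreduce; it is defined here but not called by the training loop below.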
def metric_average(val, name):
    tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()

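# Small LeNet-style CNN for MNIST: two conv layers (with dropout on the
# second), two fully connected layers, and a log-softmax output that pairs
# with the NLL loss used in train_fn.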
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

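# Training function executed on every Horovod worker. RayExecutor.run() calls
# it remotely on each worker process; hvd.init() establishes the worker's
# rank/size so the DistributedSampler and DistributedOptimizer can partition
# the data and average gradients.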
def train_fn(
    data_dir=None,
    seed=42,
    use_cuda=False,
    batch_size=64,
    use_adasum=False,
    lr=0.01,
    momentum=0.5,
    num_epochs=10,
    log_interval=10,
):
    # Horovod: initialize library.
    hvd.init()
    torch.manual_seed(seed)

    if use_cuda:
        # Horovod: pin GPU to local rank.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(seed)

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    data_dir = data_dir or "./data"
    with FileLock(os.path.expanduser("~/.horovod_lock")):
        train_dataset = datasets.MNIST(
            data_dir,
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )

    # Horovod: use DistributedSampler to partition the training data.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler, **kwargs
    )

    model = Net()

    # By default, Adasum doesn't need scaling up learning rate.
    lr_scaler = hvd.size() if not use_adasum else 1

    if use_cuda:
        # Move model to GPU.
        model.cuda()
        # If using GPU Adasum allreduce, scale learning rate by local_size.
        if use_adasum and hvd.nccl_built():
            lr_scaler = hvd.local_size()

    # Horovod: scale learning rate by lr_scaler.
    optimizer = optim.SGD(model.parameters(), lr=lr * lr_scaler, momentum=momentum)

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        op=hvd.Adasum if use_adasum else hvd.Average,
    )

    for epoch in range(1, num_epochs + 1):
        model.train()
        # Horovod: set epoch to sampler for shuffling.
        train_sampler.set_epoch(epoch)
        for batch_idx, (data, target) in enumerate(train_loader):
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                # Horovod: use train_sampler to determine the number of
                # examples in this worker's partition.
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_sampler),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )

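# Driver: creates RayExecutor settings, starts num_workers Horovod workers on
# the Ray cluster, and runs train_fn on each of them with the given kwargs.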
def main(
    num_workers, use_gpu, timeout_s=30, placement_group_timeout_s=100, kwargs=None
):
    kwargs = kwargs or {}
    if use_gpu:
        kwargs["use_cuda"] = True

    settings = RayExecutor.create_settings(
        timeout_s=timeout_s, placement_group_timeout_s=placement_group_timeout_s
    )
    executor = RayExecutor(settings, use_gpu=use_gpu, num_workers=num_workers)
    executor.start()
    executor.run(train_fn, kwargs=kwargs)

if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(
        description="PyTorch MNIST Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=5,
        metavar="N",
        help="number of epochs to train (default: 5)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.01,
        metavar="LR",
        help="learning rate (default: 0.01)",
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.5,
        metavar="M",
        help="SGD momentum (default: 0.5)",
    )
    parser.add_argument(
        "--use-cuda", action="store_true", default=False, help="enables CUDA training"
    )
    parser.add_argument(
        "--seed", type=int, default=42, metavar="S", help="random seed (default: 42)"
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--use-adasum",
        action="store_true",
        default=False,
        help="use adasum algorithm to do reduction",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of Ray workers to use for training.",
    )
    parser.add_argument(
        "--data-dir",
        help="location of the training dataset in the local filesystem ("
        "will be downloaded if needed)",
    )
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        default=None,
        help="Address of Ray cluster.",
    )
    args = parser.parse_args()

    import ray

    if args.address:
        ray.init(address=args.address)
    else:
        ray.init()

    kwargs = {
        "data_dir": args.data_dir,
        "seed": args.seed,
        "use_cuda": args.use_cuda,
        "batch_size": args.batch_size,
        "use_adasum": args.use_adasum,
        "lr": args.lr,
        "momentum": args.momentum,
        "num_epochs": args.num_epochs,
        "log_interval": args.log_interval,
    }

    main(
        num_workers=args.num_workers,
        use_gpu=args.use_cuda,
        kwargs=kwargs,
    )
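
# Example invocations (assuming Ray, Horovod, PyTorch, and torchvision are
# installed; the exact flag values below are illustrative):
#
#   # Local 2-worker CPU run for a single epoch:
#   python horovod_example.py --num-workers 2 --num-epochs 1
#
#   # GPU run against an existing Ray cluster:
#   python horovod_example.py --address <ray-head-address> --num-workers 4 --use-cuda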