# -*- coding: utf-8 -*-
# flake8: noqa
"""
Hyperparameter tuning with Ray Tune
===================================

Hyperparameter tuning can make the difference between an average model and a highly
accurate one. Often simple things like choosing a different learning rate or changing
a network layer size can have a dramatic impact on your model performance.

Fortunately, there are tools that help with finding the best combination of parameters.
`Ray Tune <https://docs.ray.io/en/latest/tune.html>`_ is an industry standard tool for
distributed hyperparameter tuning. Ray Tune includes the latest hyperparameter search
algorithms, integrates with TensorBoard and other analysis libraries, and natively
supports distributed training through `Ray's distributed machine learning engine
<https://ray.io/>`_.

In this tutorial, we will show you how to integrate Ray Tune into your PyTorch
training workflow. We will extend `this tutorial from the PyTorch documentation
<https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_ for training
a CIFAR10 image classifier.

As you will see, we only need to make some slight modifications. In particular, we
need to

1. wrap data loading and training in functions,
2. make some network parameters configurable,
3. add checkpointing (optional),
4. and define the search space for the model tuning.

|

To run this tutorial, please make sure the following packages are
installed:

- ``ray[tune]``: Distributed hyperparameter tuning library
- ``torchvision``: For the data transformers

Setup / Imports
---------------
Let's start with the imports:
"""
from functools import partial
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray.air import Checkpoint, session
from ray.tune.schedulers import ASHAScheduler

######################################################################
# Most of the imports are needed for building the PyTorch model. Only the last three
# imports are for Ray Tune.
#
# Data loaders
# ------------
# We wrap the data loaders in their own function and pass a global data directory.
# This way we can share a data directory between different trials.
def load_data(data_dir="./data"):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )

    return trainset, testset

######################################################################
# Configurable neural network
# ---------------------------
# We can only tune those parameters that are configurable.
# In this example, we can specify the layer sizes of the fully connected layers:
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

######################################################################
# The train function
# ------------------
# Now it gets interesting, because we introduce some changes to the example `from the PyTorch
# documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_.
#
# We wrap the training script in a function ``train_cifar(config, data_dir=None)``.
# The ``config`` parameter will receive the hyperparameters we would like to
# train with. The ``data_dir`` specifies the directory where we load and store the data,
# so that multiple runs can share the same data source.
# We also load the model and optimizer state at the start of the run, if a checkpoint
# is provided. Further down in this tutorial you will find information on how
# to save the checkpoint and what it is used for.
#
# .. code-block:: python
#
#     net = Net(config["l1"], config["l2"])
#
#     checkpoint = session.get_checkpoint()
#
#     if checkpoint:
#         checkpoint_state = checkpoint.to_dict()
#         start_epoch = checkpoint_state["epoch"]
#         net.load_state_dict(checkpoint_state["net_state_dict"])
#         optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
#     else:
#         start_epoch = 0
#
# The learning rate of the optimizer is made configurable, too:
#
# .. code-block:: python
#
#     optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
#
# We also split the training data into a training and validation subset. We thus train on
# 80% of the data and calculate the validation loss on the remaining 20%. The batch sizes
# with which we iterate through the training and test sets are configurable as well.
#
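# As a rough sketch, the split and the configurable batch size could look like the
# following (this is the same code that appears in the full training function further
# below):
#
# .. code-block:: python
#
#     test_abs = int(len(trainset) * 0.8)
#     train_subset, val_subset = random_split(
#         trainset, [test_abs, len(trainset) - test_abs]
#     )
#
#     trainloader = torch.utils.data.DataLoader(
#         train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
#     )
#     valloader = torch.utils.data.DataLoader(
#         val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
#     )
#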
# Adding (multi) GPU support with DataParallel
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Image classification benefits greatly from GPUs. Luckily, we can continue to use
# PyTorch's abstractions in Ray Tune. Thus, we can wrap our model in ``nn.DataParallel``
# to support data parallel training on multiple GPUs:
#
# .. code-block:: python
#
#     device = "cpu"
#     if torch.cuda.is_available():
#         device = "cuda:0"
#         if torch.cuda.device_count() > 1:
#             net = nn.DataParallel(net)
#     net.to(device)
#
# By using a ``device`` variable we make sure that training also works when we have
# no GPUs available. PyTorch requires us to send our data to the GPU memory explicitly,
# like this:
#
# .. code-block:: python
#
#     for i, data in enumerate(trainloader, 0):
#         inputs, labels = data
#         inputs, labels = inputs.to(device), labels.to(device)
#
# The code now supports training on CPUs, on a single GPU, and on multiple GPUs. Notably, Ray
# also supports `fractional GPUs <https://docs.ray.io/en/master/using-ray-with-gpus.html#fractional-gpus>`_,
# so we can share GPUs among trials, as long as the model still fits in GPU memory. We'll come back
# to that later.
#
# Communicating with Ray Tune
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The most interesting part is the communication with Ray Tune:
#
# .. code-block:: python
#
#     checkpoint_data = {
#         "epoch": epoch,
#         "net_state_dict": net.state_dict(),
#         "optimizer_state_dict": optimizer.state_dict(),
#     }
#     checkpoint = Checkpoint.from_dict(checkpoint_data)
#
#     session.report(
#         {"loss": val_loss / val_steps, "accuracy": correct / total},
#         checkpoint=checkpoint,
#     )
#
# Here we first save a checkpoint and then report some metrics back to Ray Tune. Specifically,
# we send the validation loss and accuracy back to Ray Tune. Ray Tune can then use these metrics
# to decide which hyperparameter configuration led to the best results. These metrics
# can also be used to stop badly performing trials early in order to avoid wasting
# resources on those trials.
#
# Saving the checkpoint is optional; however, it is necessary if we want to use advanced
# schedulers like
# `Population Based Training <https://docs.ray.io/en/master/tune/tutorials/tune-advanced-tutorial.html>`_.
# Also, by saving the checkpoint we can later load the trained models and validate them
# on a test set. Lastly, saving checkpoints is useful for fault tolerance, as it allows
# us to interrupt training and continue it later.
#
# Full training function
# ~~~~~~~~~~~~~~~~~~~~~~
#
# The full code example looks like this:
def train_cifar(config, data_dir=None):
    net = Net(config["l1"], config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    checkpoint = session.get_checkpoint()

    if checkpoint:
        checkpoint_state = checkpoint.to_dict()
        start_epoch = checkpoint_state["epoch"]
        net.load_state_dict(checkpoint_state["net_state_dict"])
        optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
    else:
        start_epoch = 0

    trainset, testset = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs]
    )

    trainloader = torch.utils.data.DataLoader(
        train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
    )
    valloader = torch.utils.data.DataLoader(
        val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
    )

    for epoch in range(start_epoch, 10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(
                    "[%d, %5d] loss: %.3f"
                    % (epoch + 1, i + 1, running_loss / epoch_steps)
                )
                running_loss = 0.0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1

        checkpoint_data = {
            "epoch": epoch,
            "net_state_dict": net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        checkpoint = Checkpoint.from_dict(checkpoint_data)

        session.report(
            {"loss": val_loss / val_steps, "accuracy": correct / total},
            checkpoint=checkpoint,
        )

    print("Finished Training")

######################################################################
# As you can see, most of the code is adapted directly from the original example.
#
# Test set accuracy
# -----------------
# Commonly the performance of a machine learning model is tested on a hold-out test
# set with data that has not been used for training the model. We also wrap this in a
# function:
def test_accuracy(net, device="cpu"):
    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2
    )

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

######################################################################
# The function also expects a ``device`` parameter, so we can do the
# test set validation on a GPU.
#
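# For example, once a trained network has been moved to a GPU, the call could simply
# look like this (the ``net`` variable here is just illustrative; the exact usage in
# this tutorial appears in the ``main`` function below):
#
# .. code-block:: python
#
#     test_acc = test_accuracy(net, device="cuda:0")
#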
# Configuring the search space
# ----------------------------
# Lastly, we need to define Ray Tune's search space. Here is an example:
#
# .. code-block:: python
#
#     config = {
#         "l1": tune.choice([2 ** i for i in range(9)]),
#         "l2": tune.choice([2 ** i for i in range(9)]),
#         "lr": tune.loguniform(1e-4, 1e-1),
#         "batch_size": tune.choice([2, 4, 8, 16])
#     }
#
# ``tune.choice()`` accepts a list of values that are uniformly sampled from.
# In this example, the ``l1`` and ``l2`` parameters
# should be powers of 2 between 1 and 256, so 1, 2, 4, 8, 16, 32, 64, 128, or 256.
# The ``lr`` (learning rate) should be sampled log-uniformly between 0.0001 and 0.1. Lastly,
# the batch size is a choice between 2, 4, 8, and 16.
#
# For each trial, Ray Tune will now randomly sample a combination of parameters from these
# search spaces. It will then train a number of models in parallel and find the best
# performing one among these. We also use the ``ASHAScheduler``, which will terminate badly
# performing trials early.
#
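# As a rough sketch, the scheduler could be configured like this. ``metric`` and ``mode``
# tell ASHA which reported value to optimize, ``max_t`` caps the number of training
# iterations (here: epochs) per trial, and ``grace_period`` is the minimum number of
# iterations a trial runs before it can be stopped. The exact values used in this
# tutorial appear in the ``main`` function below:
#
# .. code-block:: python
#
#     scheduler = ASHAScheduler(
#         metric="loss",
#         mode="min",
#         max_t=max_num_epochs,
#         grace_period=1,
#         reduction_factor=2,
#     )
#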
# We wrap the ``train_cifar`` function with ``functools.partial`` to set the constant
# ``data_dir`` parameter. We can also tell Ray Tune what resources should be
# available for each trial:
#
# .. code-block:: python
#
#     gpus_per_trial = 2
#     # ...
#     result = tune.run(
#         partial(train_cifar, data_dir=data_dir),
#         resources_per_trial={"cpu": 8, "gpu": gpus_per_trial},
#         config=config,
#         num_samples=num_samples,
#         scheduler=scheduler,
#         checkpoint_at_end=True)
#
# You can specify the number of CPUs, which are then available, e.g., to increase
# the ``num_workers`` of the PyTorch ``DataLoader`` instances. The selected
# number of GPUs is made visible to PyTorch in each trial. Trials do not have access to
# GPUs that haven't been requested for them - so you don't have to worry about two trials
# using the same set of resources.
#
# Here we can also specify fractional GPUs, so something like ``gpus_per_trial=0.5`` is
# completely valid. The trials will then share GPUs among each other.
# You just have to make sure that the models still fit in the GPU memory.
#
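# For instance, a resource request like the following variant of the call above would
# let two trials share each GPU, assuming a single model instance uses at most half
# of the GPU memory:
#
# .. code-block:: python
#
#     result = tune.run(
#         partial(train_cifar, data_dir=data_dir),
#         resources_per_trial={"cpu": 2, "gpu": 0.5},  # two concurrent trials per GPU
#         config=config,
#         num_samples=num_samples,
#         scheduler=scheduler,
#     )
#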
# After training the models, we will find the best performing one and load the trained
# network from the checkpoint file. We then obtain the test set accuracy and report
# everything by printing.
#
# The full main function looks like this:
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    data_dir = os.path.abspath("./data")
    load_data(data_dir)
    config = {
        "l1": tune.choice([2**i for i in range(9)]),
        "l2": tune.choice([2**i for i in range(9)]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
    }
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )
    result = tune.run(
        partial(train_cifar, data_dir=data_dir),
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
    )

    best_trial = result.get_best_trial("loss", "min", "last")
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

    best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if gpus_per_trial > 1:
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    best_checkpoint = best_trial.checkpoint.to_air_checkpoint()
    best_checkpoint_data = best_checkpoint.to_dict()

    best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])

    test_acc = test_accuracy(best_trained_model, device)
    print("Best trial test set accuracy: {}".format(test_acc))


if __name__ == "__main__":
    # sphinx_gallery_start_ignore
    # Fixes ``AttributeError: '_LoggingTee' object has no attribute 'fileno'``.
    # This is only needed to run with sphinx-build.
    import sys

    sys.stdout.fileno = lambda: False
    # sphinx_gallery_end_ignore
    # You can change the number of GPUs per trial here:
    main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)

######################################################################
# If you run the code, an example output could look like this:
#
# ::
#
#     Number of trials: 10/10 (10 TERMINATED)
#     +-----+--------------+------+------+-------------+--------+---------+------------+
#     | ... |   batch_size |   l1 |   l2 |          lr |   iter |    loss |   accuracy |
#     |-----+--------------+------+------+-------------+--------+---------+------------|
#     | ... |            2 |    1 |  256 | 0.000668163 |      1 | 2.31479 |     0.0977 |
#     | ... |            4 |   64 |    8 | 0.0331514   |      1 | 2.31605 |     0.0983 |
#     | ... |            4 |    2 |    1 | 0.000150295 |      1 | 2.30755 |     0.1023 |
#     | ... |           16 |   32 |   32 | 0.0128248   |     10 | 1.66912 |     0.4391 |
#     | ... |            4 |    8 |  128 | 0.00464561  |      2 | 1.7316  |     0.3463 |
#     | ... |            8 |  256 |    8 | 0.00031556  |      1 | 2.19409 |     0.1736 |
#     | ... |            4 |   16 |  256 | 0.00574329  |      2 | 1.85679 |     0.3368 |
#     | ... |            8 |    2 |    2 | 0.00325652  |      1 | 2.30272 |     0.0984 |
#     | ... |            2 |    2 |    2 | 0.000342987 |      2 | 1.76044 |     0.292  |
#     | ... |            4 |   64 |   32 | 0.003734    |      8 | 1.53101 |     0.4761 |
#     +-----+--------------+------+------+-------------+--------+---------+------------+
#
#     Best trial config: {'l1': 64, 'l2': 32, 'lr': 0.0037339984519545164, 'batch_size': 4}
#     Best trial final validation loss: 1.5310075663924216
#     Best trial final validation accuracy: 0.4761
#     Best trial test set accuracy: 0.4737
#
# Most trials have been stopped early in order to avoid wasting resources.
# The best performing trial achieved a validation accuracy of about 47%, which could
# be confirmed on the test set.
#
# So that's it! You can now tune the parameters of your PyTorch models.