test_lr_schedulers.py

import torch
import deepspeed
import argparse
import pytest
import json
import os
from common import distributed_test
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict


@pytest.mark.parametrize("scheduler_type,params",
                         [("WarmupLR", {}),
                          ("OneCycle", {
                              'cycle_min_lr': 0,
                              'cycle_max_lr': 0
                          }),
                          ("LRRangeTest", {})])
def test_get_lr_before_train(tmpdir, scheduler_type, params):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            },
        },
        "scheduler": {
            "type": scheduler_type,
            "params": params
        },
        "gradient_clipping": 1.0
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_get_lr_before_train(args, model, hidden_dim):
        # Initialize the DeepSpeed engine; the scheduler configured in the
        # JSON config is returned alongside the wrapped model.
        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        for n, batch in enumerate(data_loader):
            # Get the lr before training starts: on the first iteration no
            # optimizer step has been taken yet, and the call should not raise
            # for any of the parametrized scheduler types.
            lr_scheduler.get_lr()
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_get_lr_before_train(args=args, model=model, hidden_dim=hidden_dim)
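

# Note: args_from_dict is imported from the shared simple_model test utilities
# and its body is not shown here. The sketch below is a hypothetical
# illustration of what such a helper might do -- assuming it serializes the
# config dict to a JSON file under tmpdir and exposes it to DeepSpeed via the
# --deepspeed_config argument. The name _args_from_dict_sketch and its exact
# behavior are assumptions for illustration, not part of the original suite.
def _args_from_dict_sketch(tmpdir, config_dict):
    # Write the DeepSpeed config to a temporary JSON file.
    config_path = os.path.join(str(tmpdir), 'ds_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)

    # Build the argparse namespace that deepspeed.initialize() reads
    # (local_rank and deepspeed_config).
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--deepspeed_config', type=str, default=config_path)
    return parser.parse_args(args=[])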