eigenvalue.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.utils import log_dist
import numpy as np
import logging

from deepspeed.utils.torch import required_torch_version

class Eigenvalue(object):
    """Estimate the dominant eigenvalue of each layer block's Hessian by power
    iteration on Hessian-vector products computed with torch.autograd."""

    def __init__(self,
                 verbose=False,
                 max_iter=100,
                 tol=1e-2,
                 stability=0,
                 gas_boundary_resolution=1,
                 layer_name='',
                 layer_num=0):
        super().__init__()

        self.verbose = verbose
        self.max_iter = max_iter
        self.tol = tol
        self.stability = stability
        self.gas_boundary_resolution = gas_boundary_resolution
        self.layer_name = layer_name
        self.layer_num = layer_num

        assert len(self.layer_name) > 0 and layer_num > 0

        log_dist(
            f'enabled eigenvalue with verbose={verbose}, max_iter={max_iter}, tol={tol}, stability={stability}, gas_boundary_resolution={gas_boundary_resolution}, layer_name={layer_name}, layer_num={layer_num}',
            ranks=[0])

    # Replace all nan/pos-inf/neg-inf values with zero.
    def nan_to_num(self, x):
        if required_torch_version(min_version=1.8):
            return torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        else:
            # Fall back to a numpy-based implementation for backwards compatibility with PyTorch 1.7 or older.
            device = x.device
            x = x.cpu().numpy()
            x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
            return torch.from_numpy(x).to(device)
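
    # For example (illustrative only):
    #   self.nan_to_num(torch.tensor([float('nan'), float('inf'), 1.0]))
    #   -> tensor([0., 0., 1.])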

    def normalize(self, v):
        norm_squared = self.inner_product(v, v)
        norm = norm_squared**0.5 + self.stability
        normalized_vectors = [vector / norm for vector in v]
        normalized_vectors = [self.nan_to_num(vector) for vector in normalized_vectors]
        return normalized_vectors

    def inner_product(self, xs, ys):
        return sum([torch.sum(x * y) for (x, y) in zip(xs, ys)])

    def get_layers(self, module):
        scope_names = self.layer_name.split('.')
        assert len(scope_names) > 0

        m = module
        for name in scope_names:
            assert hasattr(m, name), "layer_name configuration is invalid."
            m = getattr(m, name)

        return m
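
    # Example (illustrative; the attribute path is hypothetical): with
    # layer_name='transformer.blocks' and layer_num=24, get_layers(module)
    # returns module.transformer.blocks, which compute_eigenvalue() then
    # indexes as blocks[0] .. blocks[23], so it should be an indexable
    # container such as nn.ModuleList.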

    def compute_eigenvalue(self, module, device=None, scale=1.0):
        block_eigenvalue = []
        param_keys = []
        layers = self.get_layers(module)

        for block in range(self.layer_num):
            model_block = layers[block]

            # We found that this randn() has a noticeable accuracy impact in some cases,
            # so save/restore the random state around it.
            rng_state = torch.random.get_rng_state()
            if device is None:
                v = [
                    torch.randn(p.size()) for p in model_block.parameters()
                    if p.grad is not None and p.grad.grad_fn is not None
                ]
            else:
                v = [
                    torch.randn(p.size(), device=device) for p in model_block.parameters()
                    if p.grad is not None and p.grad.grad_fn is not None
                ]
            torch.random.set_rng_state(rng_state)

            grads = [
                param.grad for param in model_block.parameters()
                if param.grad is not None and param.grad.grad_fn is not None
            ]
            params = [
                param for param in model_block.parameters()
                if param.grad is not None and param.grad.grad_fn is not None
            ]

            layer_keys = [id(p) for p in model_block.parameters()]
            param_keys.append(layer_keys)

            v = self.normalize(v)

            # Disable eigenvalue computation if the model doesn't support second-order gradients,
            # e.g. when the DS transformer kernel is enabled.
            if len(grads) == 0 or len(params) == 0:
                log_dist('The model does NOT support eigenvalue computation.', ranks=[0], level=logging.WARNING)
                return []

            i = 0
            eigenvalue_current, eigenvalue_previous = 1., 0.

            while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs(
                    (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) >= self.tol):  # test convergence criteria
                eigenvalue_previous = eigenvalue_current

                # Hessian-vector product via double backprop: gradient of (grads . v) w.r.t. the parameters.
                Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=True)
                Hv = [self.nan_to_num(hv).float() for hv in Hv]

                eigenvalue_current = self.inner_product(Hv, v).item()

                v = self.normalize(Hv)
                v = [x / scale for x in v]
                i += 1

            eigenvalue_current *= scale
            block_eigenvalue.append(eigenvalue_current)

            if self.verbose:
                log_dist(f'block: {block}, power iteration: {i}, eigenvalue: {eigenvalue_current}', ranks=[0])

        block_eigenvalue = self.post_process(block_eigenvalue)

        if self.verbose:
            log_dist(f'post processed block_eigenvalue: {block_eigenvalue}', ranks=[0])

        # {param_id: (eigenvalue, layer_id)}
        ev_dict = {}
        for i, (layer_keys, value) in enumerate(zip(param_keys, block_eigenvalue)):
            ev_dict.update(dict.fromkeys(layer_keys, (value, i)))

        return ev_dict

    # 1. Map all eigenvalues to [0, 1.0].
    # 2. Some layers can't generate valid eigenvalues in fp16 precision; use 1.0 for those instead.
    def post_process(self, value_list):
        max_value = abs(max(value_list, key=abs))
        return [abs(v) / max_value if v != 0.0 else 1.0 for v in value_list]
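
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the toy model, tensor shapes and
# hyperparameters below are assumptions, not part of this module. The key
# requirement is that backward() is called with create_graph=True, so each
# param.grad keeps its grad_fn and the Hessian-vector products in
# compute_eigenvalue() can be formed; otherwise it returns [].
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    class ToyModel(nn.Module):
        # Hypothetical model exposing a 'blocks' ModuleList, matching
        # layer_name='blocks' and layer_num=2 below.
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    model = ToyModel()
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward(create_graph=True)  # keep the graph so each p.grad has a grad_fn

    eig = Eigenvalue(verbose=True, layer_name='blocks', layer_num=2)
    # Returns {id(param): (normalized eigenvalue in [0, 1], block index)}.
    ev_dict = eig.compute_eigenvalue(model)
    print(ev_dict)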