debug.py

  1. """ debug utils """
  2. import fcntl
  3. # for debug purposes map module and param objects to their fully qualified names
  4. module_names = {}
  5. param_names = {}
  6. def debug_extract_module_and_param_names(model):
  7. # extract the fully qualified names as soon as the model is acquired
  8. global module_names
  9. global param_names
  10. # XXX: can probably make a map of param2module and vice-versa
  11. module_names = {module: name for name, module in model.named_modules()}
  12. param_names = {param: name for name, param in model.named_parameters()}
def debug_module2name(module):
    if module in module_names:
        return module_names[module]
    else:
        return "unknown"


def debug_module2name_id(module):
    return f"name={debug_module2name(module)} id={module.id}"


def debug_module2name_class(module):
    return f"name={debug_module2name(module)} {module.__class__.__name__}"


def debug_param2name(param):
    if param in param_names:
        return param_names[param]
    else:
        return "unknown"


def debug_param2name_id(param):
    return f"name={debug_param2name(param)} id={param.ds_id}"


def debug_param2name_id_shape(param):
    return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape}"


def debug_param2name_id_shape_device(param):
    return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape} device={param.device}"


def debug_param2name_id_numel(param):
    return f"name={debug_param2name(param)} id={param.ds_id} numel={param.numel()}"


def debug_param2name_id_shape_status(param):
    return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape} status={param.ds_status}"
def printflock(*msgs):
    """
    Print messages from all concurrent gpus without the text getting interleaved.

    This is useful when debugging issues where multiple gpus don't stay in sync.

    1. Enable the force debug in, say, the partitioning and zero3 files
    2. Override the usual versions with ::

        def print_rank_0(message, debug=False, force=False):
            rank = torch.distributed.get_rank()
            printflock(f"[{rank}] {message}")

    3. run the program and you get the logs from all ranks non-interleaved

    But this makes it very difficult to make sense of the output, so the ``log_rank_file`` helper
    function might be more useful, as it's easier to send each log stream into a separate file and
    then compare those.

    """
    with open(__file__, "r") as fh:
        # take an exclusive lock on this file so that concurrent processes
        # cannot interleave their output
        fcntl.flock(fh, fcntl.LOCK_EX)
        try:
            print(*msgs)
        finally:
            fcntl.flock(fh, fcntl.LOCK_UN)
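

# Hypothetical sanity check (not part of the original file): spawn a few local
# processes that all call printflock concurrently; the exclusive flock keeps each
# message from being interleaved with the others. The fork start method is used
# since fcntl is POSIX-only anyway.
def _demo_printflock(num_procs=4):
    import multiprocessing as mp

    ctx = mp.get_context("fork")
    procs = [ctx.Process(target=printflock, args=(f"hello from proc {i}", )) for i in range(num_procs)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()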


# shared file handle used by log_rank_file - opened lazily, one file per rank
fh = None


def log_rank_file(rank, *msgs):
    """
    Print to a log file dedicated to the given rank.

    This is useful for debugging hangs in processes that are supposed to stay in sync.
    Here is a possible workflow:

    1. Enable the force debug in, say, the partitioning and zero3 files
    2. Override the usual versions of print_rank_0 in those files with ::

        def print_rank_0(message, debug=False, force=False):
            rank = torch.distributed.get_rank()
            log_rank_file(rank, message)

    3. run the program
    4. fix up the expected differences, e.g. different cuda numbers ::

        perl -pi -e 's|cuda:1|cuda:0|' log_rank_*

    5. now diff and see where the names and ids diverge - you will find where the gpus don't do
       the same work (e.g. when some layers get conditionally skipped on one gpu but not all) ::

        diff -u log_rank_0.txt log_rank_1.txt | less

    """
    global fh
    if fh is None:
        fh = open(f"log_rank_{rank}.txt", "w")
    for m in msgs:
        fh.write(f"{m}\n")
    fh.flush()
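

# Hypothetical usage sketch (not part of the original file), assuming the
# torch.distributed process group has already been initialized by the launcher:
# each rank appends to its own log_rank_<rank>.txt, which can then be diffed as
# described in the docstring above.
def _demo_log_rank_file(step):
    import torch.distributed as dist

    rank = dist.get_rank()
    log_rank_file(rank, f"rank {rank} reached step {step}")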


def print_backward_tensors(tensor):

    def _print_bwd_tensors(grad_fn):
        print(f"Backward tensors in {grad_fn}")
        for funcs in grad_fn.next_functions:
            if funcs[0]:
                try:
                    # leaf nodes (AccumulateGrad) expose the saved tensor as .variable
                    tensor = getattr(funcs[0], 'variable')
                    print(funcs[0])
                    print(f"Tensor - id: {id(tensor)}, shape: {tensor.shape}, data: {tensor}, grad: {tensor.grad}")
                except AttributeError:
                    # not a leaf node - recurse further up the autograd graph
                    _print_bwd_tensors(funcs[0])

    if hasattr(tensor, 'grad_fn'):
        _print_bwd_tensors(tensor.grad_fn)
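

# Hypothetical usage sketch (not part of the original file): build a tiny
# autograd graph and dump the leaf tensors reachable from the result.
def _demo_print_backward_tensors():
    import torch

    w = torch.randn(3, requires_grad=True)
    x = torch.randn(3)
    loss = (w * x).sum()
    print_backward_tensors(loss)  # walks SumBackward0 -> MulBackward0 -> AccumulateGrad(w)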