ds_elastic 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #!/usr/bin/env python3
  2. import argparse
  3. import json
  4. import deepspeed
  5. from deepspeed.elasticity import compute_elastic_config
  6. if __name__ == '__main__':
  7. parser = argparse.ArgumentParser()
  8. parser.add_argument('-c', '--config', type=str, help="DeepSpeed config json")
  9. parser.add_argument('-w',
  10. '--world-size',
  11. type=int,
  12. default=0,
  13. help="Intended/current world size")
  14. args = parser.parse_args()
  15. ds_config = json.load(open(args.config, 'r'))
  16. ds_version = deepspeed.__version__
  17. elastic_config = ds_config['elasticity']
  18. print('------------------------------------------')
  19. print("Elasticity config:")
  20. print('------------------------------------------')
  21. print(json.dumps(elastic_config, indent=4, sort_keys=True))
  22. if args.world_size > 0:
  23. final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version, world_size=args.world_size)
  24. print('------------------------------------------')
  25. print(f"Calculated results for world size {args.world_size}:")
  26. print('------------------------------------------')
  27. print(f'final_batch_size .... {final_batch_size}')
  28. print(f'valid_gpus .......... {valid_gpus}')
  29. print(f'micro_batch_size .... {micro_batch_size}')
  30. else:
  31. final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version)
  32. print('------------------------------------------')
  33. print("Calculated results:")
  34. print('------------------------------------------')
  35. print(f'final_batch_size .... {final_batch_size}')
  36. print(f'valid_gpus .......... {valid_gpus}')