ds_elastic 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. #!/usr/bin/env python3
  2. import argparse
  3. import json
  4. import deepspeed
  5. from deepspeed.elasticity import compute_elastic_config
  6. if __name__ == '__main__':
  7. parser = argparse.ArgumentParser()
  8. parser.add_argument('-c', '--config', type=str, help="DeepSpeed config json")
  9. parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size")
  10. args = parser.parse_args()
  11. ds_config = json.load(open(args.config, 'r'))
  12. ds_version = deepspeed.__version__
  13. elastic_config = ds_config['elasticity']
  14. print('------------------------------------------')
  15. print("Elasticity config:")
  16. print('------------------------------------------')
  17. print(json.dumps(elastic_config, indent=4, sort_keys=True))
  18. if args.world_size > 0:
  19. final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config,
  20. target_deepspeed_version=ds_version,
  21. world_size=args.world_size)
  22. print('------------------------------------------')
  23. print(f"Calculated results for world size {args.world_size}:")
  24. print('------------------------------------------')
  25. print(f'final_batch_size .... {final_batch_size}')
  26. print(f'valid_gpus .......... {valid_gpus}')
  27. print(f'micro_batch_size .... {micro_batch_size}')
  28. else:
  29. final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version)
  30. print('------------------------------------------')
  31. print("Calculated results:")
  32. print('------------------------------------------')
  33. print(f'final_batch_size .... {final_batch_size}')
  34. print(f'valid_gpus .......... {valid_gpus}')