#!/usr/bin/env python3
"""Show the elasticity section of a DeepSpeed config and the values derived from it.

The script loads the JSON file given with ``-c/--config``, pretty-prints its
``elasticity`` section, and then calls ``compute_elastic_config`` with the
installed DeepSpeed version:

* with ``-w/--world-size`` greater than 0, the call also receives
  ``world_size`` and the result is unpacked as
  ``(final_batch_size, valid_gpus, micro_batch_size)``;
* otherwise the result is unpacked as ``(final_batch_size, valid_gpus)``.
"""

import argparse
import json

import deepspeed
from deepspeed.elasticity import compute_elastic_config

# Horizontal rule printed above and below every section title.
_RULE = '------------------------------------------'


def _print_header(title):
    """Print ``title`` framed by two horizontal rules."""
    print(_RULE)
    print(title)
    print(_RULE)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # The config path is mandatory: without it there is nothing to load, so let
    # argparse report the usage error instead of failing later inside open().
    parser.add_argument('-c', '--config', type=str, required=True, help="DeepSpeed config json")
    parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size")
    args = parser.parse_args()

    # Use a context manager so the file handle is closed deterministically.
    with open(args.config, 'r') as config_file:
        ds_config = json.load(config_file)

    ds_version = deepspeed.__version__

    # Fail with a readable message rather than a bare KeyError traceback.
    if 'elasticity' not in ds_config:
        parser.error(f"'elasticity' section not found in {args.config}")
    elastic_config = ds_config['elasticity']

    _print_header("Elasticity config:")
    print(json.dumps(elastic_config, indent=4, sort_keys=True))

    if args.world_size > 0:
        # A positive world size is forwarded; the call then yields a third value.
        final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(
            ds_config=ds_config,
            target_deepspeed_version=ds_version,
            world_size=args.world_size)
        _print_header(f"Calculated results for world size {args.world_size}:")
        print(f'final_batch_size .... {final_batch_size}')
        print(f'valid_gpus .......... {valid_gpus}')
        print(f'micro_batch_size .... {micro_batch_size}')
    else:
        # No world size requested: only the batch size and GPU counts are returned.
        final_batch_size, valid_gpus = compute_elastic_config(
            ds_config=ds_config,
            target_deepspeed_version=ds_version)
        _print_header("Calculated results:")
        print(f'final_batch_size .... {final_batch_size}')
        print(f'valid_gpus .......... {valid_gpus}')