# hybrid_engine_test.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from transformers import AutoModelForCausalLM
  6. import deepspeed
  7. import argparse
  8. from deepspeed.accelerator import get_accelerator
  9. deepspeed.runtime.utils.see_memory_usage('pre test', force=True)
  10. model = AutoModelForCausalLM.from_pretrained('facebook/opt-350M').half().to(get_accelerator().device_name())
  11. parser = argparse.ArgumentParser()
  12. parser = deepspeed.add_config_arguments(parser)
  13. args = parser.parse_args()
  14. deepspeed.runtime.utils.see_memory_usage('post test', force=True)
  15. m, _, _, _ = deepspeed.initialize(model=model, args=args, enable_hybrid_engine=True)
  16. m.eval()
  17. input = torch.ones(1, 16, device='cuda', dtype=torch.long)
  18. out = m(input)
  19. m.train()
  20. out = m(input)
  21. print(out['logits'], out['logits'].norm())