test_checkpoint_sharding.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import pytest
import torch
import deepspeed
from deepspeed.model_implementations import DeepSpeedTransformerInference
from unit.common import DistributedTest, DistributedFixture
from transformers import AutoConfig, AutoModelForCausalLM
import deepspeed.comm as dist
from huggingface_hub import snapshot_download
from transformers.utils import is_offline_mode
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.accelerator import get_accelerator

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
    pytest.skip("This op has not been implemented on this system.", allow_module_level=True)


def check_dtype(model, expected_dtype):
    """Assert that the injected DeepSpeed transformer layers use the expected dtype."""

    def find_dtype(module):
        # Recursively search for the first DeepSpeedTransformerInference module
        # and report the dtype of its fused QKV weight.
        for child in module.children():
            if isinstance(child, DeepSpeedTransformerInference):
                return child.attention.attn_qkvw.dtype
            else:
                found_dtype = find_dtype(child)
                if found_dtype:
                    return found_dtype

    found_dtype = find_dtype(model)
    assert found_dtype, "Did not find DeepSpeedTransformerInference in model"
    assert (found_dtype == expected_dtype), f"Expected transformer dtype {expected_dtype}, but found {found_dtype}"


@pytest.fixture(params=[
    "bigscience/bloom-560m", "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neo-125M", "facebook/opt-350m", "facebook/opt-125m"
])
def model_name(request):
    return request.param


@pytest.fixture(params=[torch.float16, torch.int8], ids=["fp16", "int8"])
def dtype(request):
    if request.param not in get_accelerator().supported_dtypes():
        pytest.skip(f"{request.param} not supported by {get_accelerator().device_name()}.")
    return request.param


class save_shard(DistributedFixture):
    world_size = 2

    def run(self, model_name, class_tmpdir):
        # Only write a checkpoint if one does not exist
        if not os.path.isdir(os.path.join(class_tmpdir, model_name)):
            world_size = int(os.getenv("WORLD_SIZE", "1"))
            inf_config = {
                "replace_with_kernel_inject": True,
                "dtype": torch.float16,
                "enable_cuda_graph": False,
                "tensor_parallel": {
                    "tp_size": world_size
                },
                "save_mp_checkpoint_path": os.path.join(str(class_tmpdir), model_name),
            }

            # Load model and save sharded checkpoint
            model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
            model = deepspeed.init_inference(model, config=inf_config)


@pytest.mark.seq_inference
class TestCheckpointShard(DistributedTest):
    world_size = 2

    def test(self, model_name, dtype, class_tmpdir, save_shard):
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        inf_config = {
            "replace_with_kernel_inject": True,
            "dtype": dtype,
            "enable_cuda_graph": False,
            "tensor_parallel": {
                "tp_size": world_size
            },
            "checkpoint": os.path.join(class_tmpdir, model_name, "ds_inference_config.json"),
        }

        # Load model on meta tensors
        model_config = AutoConfig.from_pretrained(model_name)
        # Note that we use half precision to load initially, even for int8
        with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
            model = AutoModelForCausalLM.from_config(model_config, torch_dtype=torch.bfloat16)
        model = model.eval()
        model = deepspeed.init_inference(model, config=inf_config)
        check_dtype(model, dtype)


@pytest.mark.seq_inference
class TestCheckpointShardinAutoTP(DistributedTest):
    world_size = 2

    def test(self, model_name, class_tmpdir):

        def write_checkpoints_json(model_name, class_tmpdir):
            import json
            from pathlib import Path
            local_rank = int(os.getenv("LOCAL_RANK", "0"))
            if local_rank == 0:
                # download only on first process
                cached_repo_dir = snapshot_download(
                    model_name,
                    local_files_only=is_offline_mode(),
                    cache_dir=os.getenv("HF_HOME", None),
                    ignore_patterns=["*.safetensors", "*.msgpack", "*.h5"],
                )
                file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()]
                data = {"type": "ds_model", "checkpoints": file_list, "version": 1.0}
                os.makedirs(os.path.join(class_tmpdir, model_name), exist_ok=True)
                json.dump(data, open(os.path.join(class_tmpdir, model_name, "ds_inference_config.json"), "w"))
            dist.barrier()

        world_size = int(os.getenv("WORLD_SIZE", "1"))
        inf_config = {
            "replace_with_kernel_inject": False,
            "tensor_parallel": {
                "tp_size": world_size
            },
            "checkpoint": os.path.join(class_tmpdir, model_name, "ds_inference_config.json"),
        }

        write_checkpoints_json(model_name, class_tmpdir)

        # Load model on meta tensors
        model_config = AutoConfig.from_pretrained(model_name)
        # Note that we use half precision to load initially, even for int8
        with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"):
            model = AutoModelForCausalLM.from_config(model_config, torch_dtype=torch.bfloat16)
        model = model.eval()
        model = deepspeed.init_inference(model, config=inf_config)