# test_he_lora.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import math
import torch
import torch.nn.functional as F
import pytest
import deepspeed
from deepspeed.runtime.zero import GatheredParameters
from deepspeed.ops.op_builder import OpBuilder
from deepspeed.utils import safe_get_full_grad
import numpy.testing as npt
from unit.common import DistributedTest
from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM)

rocm_version = OpBuilder.installed_rocm_version()
if rocm_version != (0, 0):
    pytest.skip("skip inference tests on rocm for now", allow_module_level=True)


def to_device(batch, device):
    output = {}
    for k, v in batch.items():
        try:
            output[k] = v.to(device)
        except AttributeError:
            # Non-tensor entries are passed through unchanged.
            output[k] = v
    return output


def convert_linear_layer_to_lora(model, part_module_name, lora_dim=0, lora_scaling=1, lora_dropout=0):
    from deepspeed.compression.helper import recursive_getattr, recursive_setattr
    replace_name = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and part_module_name in name:
            replace_name.append(name)
    for name in replace_name:
        module = recursive_getattr(model, name)
        tmp = LinearLayer_LoRA(module.weight, lora_dim, lora_scaling, lora_dropout,
                               module.bias).to(module.weight.device).to(module.weight.dtype)
        recursive_setattr(model, name, tmp)
    return model
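

# A tiny usage sketch (illustrative only, not invoked by the test below). An empty
# part_module_name matches every torch.nn.Linear in the model, which is exactly how
# the test injects LoRA everywhere. The helper name _convert_toy_model_example is
# hypothetical.
def _convert_toy_model_example():
    toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
    toy = convert_linear_layer_to_lora(toy, "", lora_dim=4)
    # Both Linear layers should now be LoRA-wrapped.
    assert isinstance(toy[0], LinearLayer_LoRA) and isinstance(toy[2], LinearLayer_LoRA)
    return toy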


class LinearLayer_LoRA(torch.nn.Module):
    # A simple implementation of LoRA; for now only the Linear layer is supported.

    def __init__(self, weight, lora_dim=0, lora_scaling=1, lora_dropout=0, bias=None):
        super(LinearLayer_LoRA, self).__init__()
        self.weight = weight
        self.bias = bias

        if lora_dim <= 0:
            raise ValueError("You are trying to use LoRA, whose reduced dim should be at least 1")

        try:
            # For ZeRO stage 3, the full shape lives on the partitioned parameter.
            rows, columns = weight.ds_shape
        except AttributeError:
            rows, columns = weight.shape
        self.lora_right_weight = torch.nn.Parameter(torch.zeros(
            columns, lora_dim))  # stored pre-transposed so the forward pass needs no transpose
        self.lora_left_weight = torch.nn.Parameter(torch.zeros(lora_dim, rows))
        self.lora_scaling = lora_scaling / lora_dim

        if lora_dropout > 0:
            self.lora_dropout = torch.nn.Dropout(lora_dropout)
        else:
            self.lora_dropout = torch.nn.Identity()

        self.reset_parameters()
        # Disable the original weight's gradient; only the LoRA factors are trained.
        self.weight.requires_grad = False
        # Whether the LoRA update has been fused into the original weight.
        self.fuse_lora = False

    def eval(self):
        self.lora_dropout.eval()

    def train(self, mode=True):
        self.lora_dropout.train(mode)

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_left_weight)

    def forward(self, input):
        if self.fuse_lora:
            return F.linear(input, self.weight, self.bias)
        else:
            return F.linear(input, self.weight, self.bias) + (
                self.lora_dropout(input) @ self.lora_right_weight @ self.lora_left_weight) * self.lora_scaling
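

# A minimal equivalence sketch (illustrative only, not invoked by the test): fusing
# LoRA into the base weight, W_fused = W + lora_scaling * (right @ left)^T, must
# reproduce the unfused forward exactly when dropout is disabled. The helper name
# _check_lora_fusion_equivalence is hypothetical.
def _check_lora_fusion_equivalence():
    torch.manual_seed(0)
    layer = LinearLayer_LoRA(torch.randn(4, 6), lora_dim=2)
    # The left factor initializes to zero, so give the LoRA branch a non-zero contribution.
    torch.nn.init.normal_(layer.lora_left_weight)
    x = torch.randn(3, 6)
    unfused = layer(x)
    # Fold the low-rank update into the base weight, mirroring what fuse_lora_weight()
    # does inside the hybrid engine.
    with torch.no_grad():
        layer.weight += layer.lora_scaling * (layer.lora_right_weight @ layer.lora_left_weight).t()
    layer.fuse_lora = True
    fused = layer(x)
    assert torch.allclose(unfused, fused, atol=1e-5)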


def only_optimize_lora_parameters(model):
    # Turn off gradients for all parameters except the LoRA factors.
    for name, param in model.named_parameters():
        if "lora_right_weight" in name or "lora_left_weight" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
    return model
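

# Illustrative check (hypothetical helper, not called by the test): after the two
# conversion steps above, the only trainable tensors should be the LoRA factors.
def _trainable_parameter_names(model):
    names = [name for name, param in model.named_parameters() if param.requires_grad]
    assert all("lora_right_weight" in n or "lora_left_weight" in n for n in names)
    return names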


@pytest.mark.seq_inference
@pytest.mark.parametrize("batch_size", [1], ids=["bsz=1"])
@pytest.mark.parametrize("zero_stage", [2, 3], ids=["zero_stage=2", "zero_stage=3"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neo-125m", "facebook/opt-350m", "bigscience/bloom-560m"])
@pytest.mark.parametrize("offload_device", ["none", "cpu"])
class TestHybridEngineLoRA(DistributedTest):
    world_size = 1

    def get_model(self, model_name):
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        model_config = AutoConfig.from_pretrained(model_name)
        model_config.dropout = 0.0
        model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config)
        model = model.half()
        model = model.to(f'cuda:{local_rank}')
        return model

    def get_tokenizer(self, model_name):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def get_train_sentences(self, batch_size):
        sentences = [
            r"\n\nHuman: I am trying to write a fairy tale. What is the most popular plot?\n\n"
            r"Assistant: The most popular plot might be a princess goes to a faraway land, falls in love",
            r"\n\nHuman: What flowers should I grow to attract bees?\n\nAssistant: The reason you want bees "
            r"in your garden is to attract pollinators and get more fruit or vegetable production."
        ]
        if batch_size <= 2:
            return sentences[:batch_size]
        else:
            raise NotImplementedError(f"batch_size {batch_size} not implemented")

    def test_lora(self, batch_size, model_name, zero_stage, offload_device):
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        model = self.get_model(model_name)
        tokenizer = self.get_tokenizer(model_name)
        train_sentences = self.get_train_sentences(batch_size)

        # Inject LoRA into every Linear layer and train only the LoRA parameters.
        model = convert_linear_layer_to_lora(model, "", 8)
        model = only_optimize_lora_parameters(model)

        ds_config = {
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1.0,
                    "betas": [0.9, 0.95]
                }
            },
            "train_batch_size": batch_size,
            "fp16": {
                "enabled": True,
                "initial_scale_power": 12
            },
            "hybrid_engine": {
                "enabled": True,
                "pin_parameters": True
            },
            "zero_optimization": {
                "stage": zero_stage,
                "offload_optimizer": {
                    "device": offload_device
                }
            }
        }
        model, *_ = deepspeed.initialize(model=model, config=ds_config)

        # Verify that the gradient norm of the trainable (LoRA) parameters is larger than 0.
        before_grad_update_layer0_params = [
            ele.detach().cpu().float().numpy() for ele in model.layer_params[0]
            if ele is not None and len(ele.shape) > 1
        ]
        model.train()
        batch = tokenizer(train_sentences, max_length=16, padding="max_length", truncation=True, return_tensors="pt")
        batch = to_device(batch, f'cuda:{local_rank}')
        batch["labels"] = batch["input_ids"]
        outputs = model(**batch, use_cache=False)
        loss = outputs.loss
        model.backward(loss)

        grad_norm_dict = dict()
        for name, param in model.named_parameters():
            if param.requires_grad is True:
                grad_norm_dict[name] = torch.linalg.norm(safe_get_full_grad(param))
        model.step()
        grad_norm = sum([ele.detach().cpu().numpy() for ele in grad_norm_dict.values()])
        assert grad_norm > 1E-5

        # Verify that the base weights remain unchanged after the optimizer step,
        # since only the LoRA factors are trainable.
        after_grad_update_layer0_params = [
            ele.detach().cpu().float().numpy() for ele in model.layer_params[0]
            if ele is not None and len(ele.shape) > 1
        ]
        for lhs, rhs in zip(before_grad_update_layer0_params, after_grad_update_layer0_params):
            npt.assert_allclose(lhs, rhs, 1E-5, 1E-5)

        # Verify that fusing the LoRA weights mutates layer_params.
        model.eval()
        with GatheredParameters(model.parameters()):
            model.fuse_lora_weight()
        after_grad_update_layer0_params_lora_fused = [
            ele.detach().cpu().float().numpy() for ele in model.layer_params[0]
            if ele is not None and len(ele.shape) > 1
        ]
        for lhs, rhs in zip(before_grad_update_layer0_params, after_grad_update_layer0_params_lora_fused):
            with pytest.raises(AssertionError):
                npt.assert_allclose(lhs, rhs, 1E-5, 1E-5)
        with GatheredParameters(model.parameters()):
            model.unfuse_lora_weight()
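
# To run this file locally (illustrative command; the path assumes the usual
# DeepSpeed unit-test layout under tests/unit/):
#   pytest -x tests/unit/hybrid_engine/test_he_lora.py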