test_he_lora.py 8.7 KB

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import math
import torch
import torch.nn.functional as F
import pytest
import deepspeed
from deepspeed.runtime.zero import GatheredParameters
from deepspeed.ops.op_builder import OpBuilder
from deepspeed.utils import safe_get_full_grad
import numpy.testing as npt
from unit.common import DistributedTest
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.accelerator import get_accelerator

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
    pytest.skip("This op has not been implemented on this system.", allow_module_level=True)

from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM)

rocm_version = OpBuilder.installed_rocm_version()
if rocm_version != (0, 0):
    pytest.skip("skip inference tests on ROCm for now", allow_module_level=True)


def to_device(batch, device):
    # Move every tensor in the tokenizer output to the target device;
    # non-tensor entries are passed through unchanged.
    output = {}
    for k, v in batch.items():
        try:
            output[k] = v.to(device)
        except AttributeError:
            output[k] = v
    return output


def convert_linear_layer_to_lora(model, part_module_name, lora_dim=0, lora_scaling=1, lora_dropout=0):
    from deepspeed.compression.helper import recursive_getattr, recursive_setattr
    replace_name = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and part_module_name in name:
            replace_name.append(name)
    for name in replace_name:
        module = recursive_getattr(model, name)
        tmp = LinearLayer_LoRA(module.weight, lora_dim, lora_scaling, lora_dropout,
                               module.bias).to(module.weight.device).to(module.weight.dtype)
        recursive_setattr(model, name, tmp)
    return model
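

# Note (descriptive only): convert_linear_layer_to_lora swaps every matching
# torch.nn.Linear for a LinearLayer_LoRA that reuses the original (frozen)
# weight and bias, so the wrapped model starts out numerically identical to the
# base model. The empty part_module_name used by the test below ("") matches
# every Linear layer in the model.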


class LinearLayer_LoRA(torch.nn.Module):
    # a simple implementation of LoRA
    # for now only supports Linear layers

    def __init__(self, weight, lora_dim=0, lora_scaling=1, lora_dropout=0, bias=None):
        super(LinearLayer_LoRA, self).__init__()
        self.weight = weight
        self.bias = bias
        if lora_dim <= 0:
            raise ValueError("You are trying to use LoRA, whose reduced dim should be larger than 0")
        try:
            # for zero stage 3
            rows, columns = weight.ds_shape
        except AttributeError:
            rows, columns = weight.shape
        self.lora_right_weight = torch.nn.Parameter(torch.zeros(
            columns, lora_dim))  # apply transpose so in forward we do not need to transpose again
        self.lora_left_weight = torch.nn.Parameter(torch.zeros(lora_dim, rows))
        self.lora_scaling = lora_scaling / lora_dim
        if lora_dropout > 0:
            self.lora_dropout = torch.nn.Dropout(lora_dropout)
        else:
            self.lora_dropout = torch.nn.Identity()
        self.reset_parameters()
        # disable gradients for the original (frozen) weight
        self.weight.requires_grad = False
        # whether the LoRA branch has been fused into the original weight
        self.fuse_lora = False

    def eval(self):
        self.lora_dropout.eval()

    def train(self, mode=True):
        self.lora_dropout.train(mode)

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_left_weight)

    def forward(self, input):
        if self.fuse_lora:
            return F.linear(input, self.weight, self.bias)
        else:
            return F.linear(input, self.weight, self.bias) + (
                self.lora_dropout(input) @ self.lora_right_weight @ self.lora_left_weight) * self.lora_scaling
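

# Forward-path sketch (explanatory, not exercised directly): for an input x of
# shape (..., columns), the LoRA branch computes
#     dropout(x) @ lora_right_weight @ lora_left_weight
# scaled by self.lora_scaling (= lora_scaling / lora_dim), where
# lora_right_weight has shape (columns, lora_dim) and lora_left_weight has
# shape (lora_dim, rows), matching the (..., rows) output of the frozen base
# F.linear. Because lora_left_weight is zero-initialized, the wrapped layer is
# initially equivalent to the original Linear layer.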


def only_optimize_lora_parameters(model):
    # turn off the gradient of all the parameters except the LoRA parameters
    for name, param in model.named_parameters():
        if "lora_right_weight" in name or "lora_left_weight" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
    return model
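

# Usage sketch (assumed workflow, mirrored by the test below):
#
#     model = convert_linear_layer_to_lora(model, "", lora_dim=8)
#     model = only_optimize_lora_parameters(model)
#
# After these two calls only lora_right_weight / lora_left_weight require
# gradients, so training updates the LoRA matrices while the base transformer
# weights stay frozen.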


@pytest.mark.seq_inference
@pytest.mark.parametrize("batch_size", [1], ids=["bsz=1"])
@pytest.mark.parametrize("zero_stage", [2, 3], ids=["zero_stage=2", "zero_stage=3"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neo-125m", "facebook/opt-350m", "bigscience/bloom-560m"])
@pytest.mark.parametrize("offload_device", ["none", "cpu"])
class TestHybridEngineLoRA(DistributedTest):
    world_size = 1

    def get_model(self, model_name):
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        model_config = AutoConfig.from_pretrained(model_name)
        model_config.dropout = 0.0
        model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config)
        model = model.half()
        device = get_accelerator().device_name()
        model = model.to(f'{device}:{local_rank}')
        return model

    def get_tokenizer(self, model_name):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def get_train_sentences(self, batch_size):
        sentences = [
            r"\n\nHuman: I am trying to write a fairy tale. What is the most popular plot?\n\n"
            r"Assistant: The most popular plot might be a princess goes to a faraway land, falls in love",
            r"\n\nHuman: What flowers should I grow to attract bees?\n\nAssistant: The reason you want bees "
            r"in your garden is to attract pollinators and get more fruit or vegetable production."
        ]
        if batch_size <= 2:
            return sentences[:batch_size]
        else:
            raise NotImplementedError(f"batch_size {batch_size} not implemented")

    def test_lora(self, batch_size, model_name, zero_stage, offload_device):
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        model = self.get_model(model_name)
        tokenizer = self.get_tokenizer(model_name)
        train_sentences = self.get_train_sentences(batch_size)

        # Inject LoRA
        model = convert_linear_layer_to_lora(model, "", 8)
        model = only_optimize_lora_parameters(model)

        ds_config = {
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1.0,
                    "betas": [0.9, 0.95]
                }
            },
            "train_batch_size": batch_size,
            "fp16": {
                "enabled": True,
                "initial_scale_power": 12
            },
            "hybrid_engine": {
                "enabled": True,
                "pin_parameters": True
            },
            "zero_optimization": {
                "stage": zero_stage,
                "offload_optimizer": {
                    "device": offload_device
                }
            }
        }
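
        # Config notes (descriptive only): hybrid_engine enables the DeepSpeed
        # hybrid training/inference path under test; zero_stage and
        # offload_device parametrize whether optimizer state stays on the
        # accelerator ("none") or is offloaded to CPU ("cpu"). The large lr of
        # 1.0 presumably makes a single optimizer step perturb the LoRA weights
        # enough for the fused weights to differ measurably later on.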

        model, *_ = deepspeed.initialize(model=model, config=ds_config)

        # Verify gradient norm is larger than 0
        before_grad_update_layer0_params = [
            ele.detach().cpu().float().numpy() for ele in model.layer_params[0]
            if ele is not None and len(ele.shape) > 1
        ]

        model.train()
        batch = tokenizer(train_sentences, max_length=16, padding="max_length", truncation=True, return_tensors="pt")
        device = get_accelerator().device_name()
        batch = to_device(batch, f'{device}:{local_rank}')
        batch["labels"] = batch["input_ids"]
        outputs = model(**batch, use_cache=False)
        loss = outputs.loss
        model.backward(loss)

        grad_norm_dict = dict()
        for name, param in model.named_parameters():
            if param.requires_grad:
                grad_norm_dict[name] = torch.linalg.norm(safe_get_full_grad(param))
        model.step()

        grad_norm = sum([ele.detach().cpu().numpy() for ele in grad_norm_dict.values()])
        assert grad_norm > 1E-5
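
        # Only the LoRA matrices require gradients at this point, so the frozen
        # transformer weights tracked in model.layer_params should be left
        # untouched by the optimizer step above (checked next), while the
        # non-zero grad norm confirms the LoRA branch actually received
        # gradients.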

        # Verify parameter remains the same
        after_grad_update_layer0_params = [
            ele.detach().cpu().float().numpy() for ele in model.layer_params[0]
            if ele is not None and len(ele.shape) > 1
        ]
        for lhs, rhs in zip(before_grad_update_layer0_params, after_grad_update_layer0_params):
            npt.assert_allclose(lhs, rhs, 1E-5, 1E-5)

        # Verify fuse will mutate layer_params
        model.eval()
        with GatheredParameters(model.parameters()):
            model.fuse_lora_weight()

        after_grad_update_layer0_params_lora_fused = [
            ele.detach().cpu().float().numpy() for ele in model.layer_params[0]
            if ele is not None and len(ele.shape) > 1
        ]
        for lhs, rhs in zip(before_grad_update_layer0_params, after_grad_update_layer0_params_lora_fused):
            with pytest.raises(AssertionError):
                npt.assert_allclose(lhs, rhs, 1E-5, 1E-5)

        with GatheredParameters(model.parameters()):
            model.unfuse_lora_weight()
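
        # The fuse/unfuse calls above rely on the hybrid engine folding the
        # updated LoRA matrices into the frozen base weights for inference and
        # restoring them afterwards; the pytest.raises(AssertionError) loop is
        # how the test asserts that fusing really mutated layer_params.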