import argparse
import copy
import json
import math
import random
import sys
import time

import numpy as np
import pytest
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils import checkpoint

from modelingpreln import BertEncoder as BertEncoderPreln
from modeling import BertEncoder as BertEncoderPostln
from modeling import BertConfig, BertLayerNorm
from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig

def check_equal(first, second, atol=1e-2, verbose=False):
    # Compare two lists of (tensor, name) gradient pairs. Tensors sharing the
    # same name are disambiguated by an occurrence counter, and each pair is
    # compared with a tolerance scaled by the mean absolute value of x.
    diction_x = {}
    diction_y = {}

    for i, (x, y) in enumerate(zip(first, second)):
        print(x[1], y[1])

    for i, (x, y) in enumerate(zip(first, second)):
        k = 0
        while (diction_x.get((k, x[1])) is not None):
            k = k + 1
        diction_x[k, x[1]] = x[0]
        k = 0
        while (diction_y.get((k, y[1])) is not None):
            k = k + 1
        diction_y[k, y[1]] = y[0]

    if verbose:
        print()
        for i, (x, y) in enumerate(zip(diction_x, diction_y)):
            print(x, y)

    for i, (x, y) in enumerate(zip(diction_x, diction_y)):
        if (x[0] == 1):
            continue
        print("checking ", x[1], ":")
        y = diction_y[x[0], x[1]]
        x = diction_x[x[0], x[1]]
        x = x.cpu().detach().numpy()
        y = y.cpu().detach().numpy()
        print(x)
        print(y)

        avgx = np.sum(abs(x), dtype=float)
        countx = x.shape[0]
        for j in range(len(x.shape) - 1):
            countx *= x.shape[j + 1]
        avgx = np.sum(avgx)

        tolerance = 1
        if avgx != float('inf') and avgx != -float('inf'):
            avgx = avgx / countx
            tolerance = avgx * atol
        print("tolerance is ", tolerance)
        if verbose:
            print("x = {}".format(x.flatten()))
            print("y = {}".format(y.flatten()))
            print('-' * 80)
        np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i), atol=tolerance)

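# Illustrative usage sketch (not part of the original test): check_equal expects
# two lists of (tensor, name) pairs, which is exactly what the get_grads()
# methods of the encoders below return, e.g.
#
#     g = [(torch.ones(2, 2), "hidden_state")]
#     check_equal(g, g, atol=0.1)   # identical inputs always pass
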
def zero_grad(variables):
    for variable in variables:
        variable.grad.zero_()


device = torch.device("cuda")
kwargs_fp32 = {'dtype': torch.float, 'device': device, 'requires_grad': True}
kwargs_fp16 = {'dtype': torch.half, 'device': device, 'requires_grad': True}

class DSEncoder(nn.Module):
    def __init__(self, config, weights, biases):
        super(DSEncoder, self).__init__()
        self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.layer = nn.ModuleList([
            copy.deepcopy(DeepSpeedTransformerLayer(i, config, weights, biases))
            for i in range(config.num_hidden_layers)
        ])
        self.grads = []
        self.pre_or_post = config.pre_layer_norm

    def forward(self,
                hidden_states,
                attention_mask,
                output_all_encoded_layers=True,
                checkpoint_activations=False):
        all_encoder_layers = []

        def custom(start, end):
            def custom_forward(*inputs):
                layers = self.layer[start:end]
                x_ = inputs[0]
                for layer in layers:
                    x_ = layer(x_, inputs[1])
                return x_

            return custom_forward

        if checkpoint_activations:
            # Run the layers in chunks of ~sqrt(num_layers), recomputing
            # activations during the backward pass.
            l = 0
            num_layers = len(self.layer)
            chunk_length = math.ceil(math.sqrt(num_layers))
            while l < num_layers:
                hidden_states = checkpoint.checkpoint(custom(l, l + chunk_length),
                                                      hidden_states,
                                                      attention_mask * 1)
                l += chunk_length
        else:
            for i, layer_module in enumerate(self.layer):
                hidden_states = layer_module(hidden_states, attention_mask, self.grads)
                # Record the gradient flowing into each layer's output so it
                # can be compared against the baseline encoder.
                hidden_states.register_hook(
                    lambda x, self=self: self.grads.append([x, "hidden_state"]))

                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)

        if not output_all_encoded_layers or checkpoint_activations:
            if (self.pre_or_post):
                hidden_states = self.FinalLayerNorm(hidden_states)
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers

    def get_grads(self):
        return self.grads

def create_models(ds_config):
    bert_config = BertConfig(vocab_size_or_config_json_file=119547,
                             hidden_size=ds_config.hidden_size,
                             num_hidden_layers=ds_config.num_hidden_layers,
                             num_attention_heads=ds_config.heads,
                             intermediate_size=4 * ds_config.hidden_size,
                             hidden_act="gelu",
                             hidden_dropout_prob=ds_config.hidden_dropout_ratio,
                             attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
                             max_position_embeddings=ds_config.max_seq_length,
                             type_vocab_size=2,
                             initializer_range=ds_config.initializer_range)

    # Shared parameters so the baseline and DeepSpeed encoders start from
    # identical values: the linear weights use a normal init, the remaining
    # weight vectors are filled with 1.0, and all biases are zeroed.
    weights = []
    biases = []

    for i in range(4):
        weights.append(
            nn.Parameter(torch.Tensor(ds_config.hidden_size, ds_config.hidden_size)))
        weights[i].data.normal_(mean=0.0, std=ds_config.initializer_range)

    weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    weights[4].data.fill_(1.0)
    weights.append(
        nn.Parameter(torch.Tensor(4 * ds_config.hidden_size, ds_config.hidden_size)))
    weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range)
    weights.append(
        nn.Parameter(torch.Tensor(ds_config.hidden_size, 4 * ds_config.hidden_size)))
    weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range)
    weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    weights[7].data.fill_(1.0)

    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[0].data.zero_()
    for i in range(4):
        biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
        biases[i + 1].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(4 * ds_config.hidden_size)))
    biases[5].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[6].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[7].data.zero_()

    if (ds_config.pre_layer_norm):
        bert_encoder = BertEncoderPreln(bert_config, weights, biases)
    else:
        bert_encoder = BertEncoderPostln(bert_config, weights, biases)
    ds_encoder = DSEncoder(ds_config, weights, biases)

    if ds_config.fp16:
        bert_encoder.half()
        ds_encoder.half()

    bert_encoder.cuda()
    ds_encoder.cuda()

    return bert_encoder, ds_encoder

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def run_backward(ds_config, atol=1e-2, verbose=False):
    set_seed(123)
    bert_encoder, ds_encoder = create_models(ds_config)

    # prepare test data
    kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
    hidden_states = torch.randn(ds_config.batch_size,
                                ds_config.max_seq_length,
                                ds_config.hidden_size,
                                **kwargs)
    input_mask = torch.randn(ds_config.batch_size,
                             1,
                             1,
                             ds_config.max_seq_length,
                             **kwargs)
    Y = torch.randn(ds_config.batch_size,
                    ds_config.max_seq_length,
                    ds_config.hidden_size,
                    **kwargs)

    # run baseline
    base_results = bert_encoder(hidden_states,
                                input_mask,
                                output_all_encoded_layers=False,
                                checkpoint_activations=False)

    loss = (Y - base_results[0]).pow(2).sum()
    loss.backward()
    base_grads = bert_encoder.get_grads()

    # run ds
    ds_results = ds_encoder(hidden_states,
                            input_mask,
                            output_all_encoded_layers=False,
                            checkpoint_activations=False)

    loss = (Y - ds_results[0]).pow(2).sum()
    loss.backward()
    ds_grads = ds_encoder.get_grads()

    # check grads
    check_equal(base_grads, ds_grads, atol=atol, verbose=verbose)

@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
                         [
                             (3,1024,128,16,24,True,False, 0.05),
                             (3,1024,128,16,24,True,True, 0.05),
                             (3,1024,128,16,24,False,False, 0.1),
                             (3,1024,128,16,24,False,True, 0.2),
                         ]) # yapf: disable
def test_backward(batch_size,
                  hidden_size,
                  seq_len,
                  heads,
                  num_layers,
                  is_preln,
                  use_fp16,
                  atol):
    # Only run fp16 test cases on devices with 7+ capability.
    major, _ = torch.cuda.get_device_capability()
    if major < 7 and (use_fp16 is True or is_preln is False):
        return

    ds_config = DeepSpeedTransformerConfig()
    ds_config.layer_id = None
    ds_config.batch_size = batch_size
    ds_config.hidden_size = hidden_size
    ds_config.max_seq_length = seq_len
    ds_config.heads = heads
    ds_config.attn_dropout_ratio = 0.0
    ds_config.hidden_dropout_ratio = 0.0
    ds_config.num_hidden_layers = num_layers
    ds_config.pre_layer_norm = is_preln
    ds_config.initializer_range = 0.02
    ds_config.fp16 = use_fp16

    run_backward(ds_config, atol=atol)

#@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
#                         [
#                             (3,1024,128,16,24,True,False, 0.07),
#                             (3,1024,128,16,24,True,True, 0.05),
#                             (3,1024,128,16,24,False,False, 0.1),
#                             (3,1024,128,16,24,False,True, 0.2),
#                         ]) # yapf: disable
#def test_backward_stochastic(batch_size,
#                             hidden_size,
#                             seq_len,
#                             heads,
#                             num_layers,
#                             is_preln,
#                             use_fp16,
#                             atol):
#    # Only run fp16 test cases on devices with 7+ capability.
#    major, _ = torch.cuda.get_device_capability()
#    if major < 7 and (use_fp16 is True or is_preln is False):
#        return
#
#    ds_config = DeepSpeedTransformerConfig()
#    ds_config.layer_id = None
#    ds_config.batch_size = batch_size
#    ds_config.hidden_size = hidden_size
#    ds_config.max_seq_length = seq_len
#    ds_config.heads = heads
#    ds_config.attn_dropout_ratio = 0.0
#    ds_config.hidden_dropout_ratio = 0.0
#    ds_config.num_hidden_layers = num_layers
#    ds_config.pre_layer_norm = is_preln
#    ds_config.initializer_range = 0.02
#    ds_config.fp16 = use_fp16
#    ds_config.stochastic_mode = True
#
#    run_backward(ds_config, atol=atol)
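

# Illustrative convenience entry point (an assumption, not part of the pytest
# suite): running this file directly exercises a single pre-LayerNorm fp32
# configuration, mirroring the first parametrized case of test_backward above.
if __name__ == '__main__':
    config = DeepSpeedTransformerConfig()
    config.layer_id = None
    config.batch_size = 3
    config.hidden_size = 1024
    config.max_seq_length = 128
    config.heads = 16
    config.attn_dropout_ratio = 0.0
    config.hidden_dropout_ratio = 0.0
    config.num_hidden_layers = 24
    config.pre_layer_norm = True
    config.initializer_range = 0.02
    config.fp16 = False
    run_backward(config, atol=0.05, verbose=True)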