weight_init.py

import torch
import torch.nn as nn


def constant_init(module, val, bias=0):
    nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def xavier_init(module, gain=1, bias=0, distribution='normal'):
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.xavier_uniform_(module.weight, gain=gain)
    else:
        nn.init.xavier_normal_(module.weight, gain=gain)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def normal_init(module, mean=0, std=1, bias=0):
    nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def uniform_init(module, a=0, b=1, bias=0):
    nn.init.uniform_(module.weight, a, b)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
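

# Hedged usage sketch (not part of the original file): the four helpers above
# share one pattern -- initialize module.weight, then fill the bias with a
# constant. `_demo_basic_inits` and `fc` are illustrative names only.
def _demo_basic_inits():
    fc = nn.Linear(16, 4)
    constant_init(fc, val=0.5)                       # all weights 0.5, bias 0
    normal_init(fc, mean=0, std=0.01)                # N(0, 0.01**2) weights
    uniform_init(fc, a=-0.1, b=0.1)                  # U(-0.1, 0.1) weights
    xavier_init(fc, gain=1, distribution='uniform')  # Glorot uniform weights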


def kaiming_init(module,
                 a=0,
                 is_rnn=False,
                 mode='fan_in',
                 nonlinearity='leaky_relu',
                 bias=0,
                 distribution='normal'):
    assert distribution in ['uniform', 'normal']
    init_fn = (nn.init.kaiming_uniform_ if distribution == 'uniform'
               else nn.init.kaiming_normal_)
    if is_rnn:
        # Recurrent modules hold several weight/bias tensors (weight_ih_l0,
        # weight_hh_l0, bias_ih_l0, ...), so walk named_parameters().
        for name, param in module.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, bias)
            elif 'weight' in name:
                init_fn(param, a=a, mode=mode, nonlinearity=nonlinearity)
    else:
        init_fn(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if not is_rnn and hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
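

# Hedged example (illustrative, not from the original file): kaiming_init on a
# plain layer touches module.weight directly; with is_rnn=True it instead walks
# the parameters of a recurrent module. `_demo_kaiming` is a made-up name.
def _demo_kaiming():
    conv = nn.Conv2d(3, 8, kernel_size=3)
    kaiming_init(conv, mode='fan_out', nonlinearity='relu')
    lstm = nn.LSTM(input_size=10, hidden_size=20)
    kaiming_init(lstm, is_rnn=True)  # weight_ih_l0, weight_hh_l0, both biases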


def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Build an (in_channels, out_channels, k, k) bilinear upsampling kernel.

    Note: the diagonal indexing on the last line assumes
    in_channels == out_channels; the two ranges must have equal length.
    """
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = (torch.arange(kernel_size).reshape(-1, 1),
          torch.arange(kernel_size).reshape(1, -1))
    filt = (1 - torch.abs(og[0] - center) / factor) * \
           (1 - torch.abs(og[1] - center) / factor)
    weight = torch.zeros((in_channels, out_channels,
                          kernel_size, kernel_size))
    weight[range(in_channels), range(out_channels), :, :] = filt
    return weight
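

# Hedged sketch: the commented-out branch in init_weights below suggests the
# intended use -- seeding a ConvTranspose2d with a fixed bilinear kernel so it
# starts out as 2x bilinear upsampling. `_demo_bilinear` is an illustrative name.
def _demo_bilinear():
    up = nn.ConvTranspose2d(3, 3, kernel_size=4, stride=2, padding=1)
    with torch.no_grad():
        up.weight.copy_(bilinear_kernel(3, 3, 4))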


def init_weights(m):
    """Initialize one module; intended to be applied recursively via
    model.apply(init_weights)."""
    if isinstance(m, nn.Conv2d):
        kaiming_init(m)
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        constant_init(m, 1)
    elif isinstance(m, nn.Linear):
        xavier_init(m)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
        kaiming_init(m, is_rnn=True)
    # elif isinstance(m, nn.ConvTranspose2d):
    #     m.weight.data.copy_(bilinear_kernel(m.in_channels, m.out_channels, 4))
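

# Hedged usage sketch (assumed, not from the original file): nn.Module.apply
# calls init_weights on every submodule; layers without a matching isinstance
# branch (ReLU, Flatten, the Sequential container) are left untouched.
if __name__ == '__main__':
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(16 * 8 * 8, 10),
    )
    model.apply(init_weights)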