noisy_layer.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import numpy as np
  2. from ray.rllib.models.utils import get_activation_fn
  3. from ray.rllib.utils.framework import try_import_torch, TensorType
  4. torch, nn = try_import_torch()
  5. class NoisyLayer(nn.Module):
  6. """A Layer that adds learnable Noise to some previous layer's outputs.
  7. Consists of:
  8. - a common dense layer: y = w^{T}x + b
  9. - a noisy layer: y = (w + \\epsilon_w*\\sigma_w)^{T}x +
  10. (b+\\epsilon_b*\\sigma_b)
  11. , where \epsilon are random variables sampled from factorized normal
  12. distributions and \\sigma are trainable variables which are expected to
  13. vanish along the training procedure.
  14. """
  15. def __init__(self,
  16. in_size: int,
  17. out_size: int,
  18. sigma0: float,
  19. activation: str = "relu"):
  20. """Initializes a NoisyLayer object.
  21. Args:
  22. in_size: Input size for Noisy Layer
  23. out_size: Output size for Noisy Layer
  24. sigma0: Initialization value for sigma_b (bias noise)
  25. activation: Non-linear activation for Noisy Layer
  26. """
  27. super().__init__()
  28. self.in_size = in_size
  29. self.out_size = out_size
  30. self.sigma0 = sigma0
  31. self.activation = get_activation_fn(activation, framework="torch")
  32. if self.activation is not None:
  33. self.activation = self.activation()
  34. sigma_w = nn.Parameter(
  35. torch.from_numpy(
  36. np.random.uniform(
  37. low=-1.0 / np.sqrt(float(self.in_size)),
  38. high=1.0 / np.sqrt(float(self.in_size)),
  39. size=[self.in_size, out_size])).float())
  40. self.register_parameter("sigma_w", sigma_w)
  41. sigma_b = nn.Parameter(
  42. torch.from_numpy(
  43. np.full(
  44. shape=[out_size],
  45. fill_value=sigma0 / np.sqrt(float(self.in_size)))).float())
  46. self.register_parameter("sigma_b", sigma_b)
  47. w = nn.Parameter(
  48. torch.from_numpy(
  49. np.full(
  50. shape=[self.in_size, self.out_size],
  51. fill_value=6 /
  52. np.sqrt(float(in_size) + float(out_size)))).float())
  53. self.register_parameter("w", w)
  54. b = nn.Parameter(torch.from_numpy(np.zeros([out_size])).float())
  55. self.register_parameter("b", b)
  56. def forward(self, inputs: TensorType) -> TensorType:
  57. epsilon_in = self._f_epsilon(
  58. torch.normal(
  59. mean=torch.zeros([self.in_size]),
  60. std=torch.ones([self.in_size])).to(inputs.device))
  61. epsilon_out = self._f_epsilon(
  62. torch.normal(
  63. mean=torch.zeros([self.out_size]),
  64. std=torch.ones([self.out_size])).to(inputs.device))
  65. epsilon_w = torch.matmul(
  66. torch.unsqueeze(epsilon_in, -1),
  67. other=torch.unsqueeze(epsilon_out, 0))
  68. epsilon_b = epsilon_out
  69. action_activation = torch.matmul(
  70. inputs, self.w +
  71. self.sigma_w * epsilon_w) + self.b + self.sigma_b * epsilon_b
  72. if self.activation is not None:
  73. action_activation = self.activation(action_activation)
  74. return action_activation
  75. def _f_epsilon(self, x: TensorType) -> TensorType:
  76. return torch.sign(x) * torch.pow(torch.abs(x), 0.5)