gru_gate.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. from ray.rllib.utils.framework import try_import_torch
  2. from ray.rllib.utils.framework import TensorType
  3. torch, nn = try_import_torch()
  class GRUGate(nn.Module):
      """GRU-style gating layer for use in AttentionNet.

      Combines a carried state ``h`` with a new input ``X`` through a reset
      gate ``r`` and an update gate ``z``:

          r     = sigmoid(X @ W_r + h @ U_r)
          z     = sigmoid(X @ W_z + h @ U_z - b_z)
          h_hat = tanh(X @ W_h + (h * r) @ U_h)
          out   = (1 - z) * h + z * h_hat

      All six weight matrices are ``(dim, dim)`` and Xavier-uniform
      initialized; ``b_z`` is a learnable ``(dim,)`` vector.
      """

      def __init__(self, dim: int, init_bias: float = 0., **kwargs):
          """Create the gate's learnable parameters.

          Args:
              dim: Size of the last dimension of both ``h`` and ``X``.
              init_bias: Initial value of every entry of the update-gate bias
                  ``b_z``. The bias is *subtracted* before the sigmoid, so a
                  positive value pushes ``z`` toward 0 and makes the layer
                  initially pass ``h`` through mostly unchanged.
              **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
          """
          super().__init__(**kwargs)

          # Assigning an nn.Parameter to a module attribute registers it, so
          # no explicit register_parameter() call is required. The assignment
          # order below fixes both the registration order (state_dict /
          # optimizer ordering) and the order in which the RNG is consumed.

          # Weights applied to the new input X.
          self._w_r = self._xavier_weight(dim)
          self._w_z = self._xavier_weight(dim)
          self._w_h = self._xavier_weight(dim)

          # Weights applied to the carried state h.
          self._u_r = self._xavier_weight(dim)
          self._u_z = self._xavier_weight(dim)
          self._u_h = self._xavier_weight(dim)

          # zeros().fill_() keeps the default float dtype even when an integer
          # init_bias is passed (a dtype-inferring constructor would not).
          self._bias_z = nn.Parameter(torch.zeros(dim).fill_(init_bias))

      @staticmethod
      def _xavier_weight(dim: int):
          """Return a new (dim, dim) Xavier-uniform initialized parameter."""
          return nn.Parameter(nn.init.xavier_uniform_(torch.empty(dim, dim)))

      def forward(self, inputs: TensorType, **kwargs) -> TensorType:
          """Apply the gate to a ``(h, X)`` pair.

          Args:
              inputs: Two-element sequence ``(h, X)`` with the internal state
                  first. Both tensors are contracted over their last
                  dimension with ``(dim, dim)`` weights, so that dimension
                  must equal ``dim``.
              **kwargs: Accepted for interface compatibility; unused.

          Returns:
              ``(1 - z) * h + z * h_hat`` as defined in the class docstring.
          """
          # Pass in internal state first.
          h, X = inputs

          # Reset gate: how much of h feeds into the candidate state.
          r = torch.sigmoid(
              torch.tensordot(X, self._w_r, dims=1)
              + torch.tensordot(h, self._u_r, dims=1))

          # Update gate: the bias is subtracted, so a larger bias favors h.
          z = torch.sigmoid(
              torch.tensordot(X, self._w_z, dims=1)
              + torch.tensordot(h, self._u_z, dims=1)
              - self._bias_z)

          # Candidate state computed from X and the reset-scaled h.
          h_next = torch.tanh(
              torch.tensordot(X, self._w_h, dims=1)
              + torch.tensordot(h * r, self._u_h, dims=1))

          # Element-wise blend of the old state and the candidate.
          return (1 - z) * h + z * h_next