# neural_computer.py

from collections import OrderedDict
from typing import Dict, List, Tuple, Union

import gym

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

try:
    from dnc import DNC
except ModuleNotFoundError:
    print("dnc module not found. Did you forget to 'pip install dnc'?")
    raise

torch, nn = try_import_torch()

class DNCMemory(TorchModelV2, nn.Module):
    """Differentiable Neural Computer wrapper around ixaxaar's DNC
    implementation, see https://github.com/ixaxaar/pytorch-dnc"""

    DEFAULT_CONFIG = {
        "dnc_model": DNC,
        # Number of hidden layers in the controller
        "num_hidden_layers": 1,
        # Number of weights per controller hidden layer
        "hidden_size": 64,
        # Number of recurrent (LSTM) layers in the controller
        "num_layers": 1,
        # Number of read heads, i.e. how many addrs are read at once
        "read_heads": 4,
        # Number of cells in memory
        "nr_cells": 32,
        # Size of each memory cell
        "cell_size": 16,
        # LSTM activation function
        "nonlinearity": "tanh",
        # Observation goes through this torch.nn.Module before
        # being fed to the DNC
        "preprocessor": torch.nn.Sequential(
            torch.nn.Linear(64, 64), torch.nn.Tanh()),
        # Input size of the preprocessor
        "preprocessor_input_size": 64,
        # The output size of the preprocessor
        # and the input size of the DNC
        "preprocessor_output_size": 64,
    }

    MEMORY_KEYS = [
        "memory",
        "link_matrix",
        "precedence",
        "read_weights",
        "write_weights",
        "usage_vector",
    ]

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        **custom_model_kwargs,
    ):
        nn.Module.__init__(self)
        super(DNCMemory, self).__init__(obs_space, action_space, num_outputs,
                                        model_config, name)
        self.num_outputs = num_outputs
        self.obs_dim = gym.spaces.utils.flatdim(obs_space)
        self.act_dim = gym.spaces.utils.flatdim(action_space)
        self.cfg = dict(self.DEFAULT_CONFIG, **custom_model_kwargs)
        assert (self.cfg["num_layers"] == 1
                ), "num_layers != 1 has not been implemented yet"
        self.cur_val = None

        self.preprocessor = torch.nn.Sequential(
            torch.nn.Linear(self.obs_dim, self.cfg["preprocessor_input_size"]),
            self.cfg["preprocessor"],
        )
        self.logit_branch = SlimFC(
            in_size=self.cfg["hidden_size"],
            out_size=self.num_outputs,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )
        self.value_branch = SlimFC(
            in_size=self.cfg["hidden_size"],
            out_size=1,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )
        # The DNC itself is built lazily on the first forward() call,
        # once the device of the observations is known
        self.dnc: Union[None, DNC] = None

    def get_initial_state(self) -> List[TensorType]:
        ctrl_hidden = [
            torch.zeros(self.cfg["num_hidden_layers"],
                        self.cfg["hidden_size"]),
            torch.zeros(self.cfg["num_hidden_layers"],
                        self.cfg["hidden_size"]),
        ]
        m = self.cfg["nr_cells"]
        r = self.cfg["read_heads"]
        w = self.cfg["cell_size"]
        memory = [
            torch.zeros(m, w),  # memory
            torch.zeros(1, m, m),  # link_matrix
            torch.zeros(1, m),  # precedence
            torch.zeros(r, m),  # read_weights
            torch.zeros(1, m),  # write_weights
            torch.zeros(m),  # usage_vector
        ]
        read_vecs = torch.zeros(w * r)
        state = [*ctrl_hidden, read_vecs, *memory]
        assert len(state) == 9
        return state

    def value_function(self) -> TensorType:
        assert self.cur_val is not None, "must call forward() first"
        return self.cur_val

    def unpack_state(
        self,
        state: List[TensorType],
    ) -> Tuple[List[Tuple[TensorType, TensorType]], Dict[str, TensorType],
               TensorType]:
        """Given a list of tensors, reformat for self.dnc input"""
        assert len(state) == 9, "Failed to verify unpacked state"
        ctrl_hidden: List[Tuple[TensorType, TensorType]] = [(
            state[0].permute(1, 0, 2).contiguous(),
            state[1].permute(1, 0, 2).contiguous(),
        )]
        read_vecs: TensorType = state[2]
        memory: List[TensorType] = state[3:]
        memory_dict: OrderedDict[str, TensorType] = OrderedDict(
            zip(self.MEMORY_KEYS, memory))
        return ctrl_hidden, memory_dict, read_vecs

    def pack_state(
        self,
        ctrl_hidden: List[Tuple[TensorType, TensorType]],
        memory_dict: Dict[str, TensorType],
        read_vecs: TensorType,
    ) -> List[TensorType]:
        """Given the dnc output, pack it into a list of tensors
        for rllib state. Order is ctrl_hidden, read_vecs, memory_dict"""
        state = []
        ctrl_hidden = [
            ctrl_hidden[0][0].permute(1, 0, 2),
            ctrl_hidden[0][1].permute(1, 0, 2),
        ]
        state += ctrl_hidden
        assert len(state) == 2, "Failed to verify packed state"
        state.append(read_vecs)
        assert len(state) == 3, "Failed to verify packed state"
        state += memory_dict.values()
        assert len(state) == 9, "Failed to verify packed state"
        return state

    def validate_unpack(self, dnc_output, unpacked_state):
        """Ensure the unpacked state shapes match the DNC output"""
        s_ctrl_hidden, s_memory_dict, s_read_vecs = unpacked_state
        ctrl_hidden, memory_dict, read_vecs = dnc_output

        for i in range(len(ctrl_hidden)):
            for j in range(len(ctrl_hidden[i])):
                assert s_ctrl_hidden[i][j].shape == ctrl_hidden[i][j].shape, (
                    "Controller state mismatch: got "
                    f"{s_ctrl_hidden[i][j].shape} should be "
                    f"{ctrl_hidden[i][j].shape}")

        for k in memory_dict:
            assert s_memory_dict[k].shape == memory_dict[k].shape, (
                "Memory state mismatch at key "
                f"{k}: got {s_memory_dict[k].shape} should be "
                f"{memory_dict[k].shape}")

        assert s_read_vecs.shape == read_vecs.shape, (
            "Read state mismatch: got "
            f"{s_read_vecs.shape} should be "
            f"{read_vecs.shape}")

    def build_dnc(self, device_idx: Union[int, None]) -> None:
        self.dnc = self.cfg["dnc_model"](
            input_size=self.cfg["preprocessor_output_size"],
            hidden_size=self.cfg["hidden_size"],
            num_layers=self.cfg["num_layers"],
            num_hidden_layers=self.cfg["num_hidden_layers"],
            read_heads=self.cfg["read_heads"],
            cell_size=self.cfg["cell_size"],
            nr_cells=self.cfg["nr_cells"],
            nonlinearity=self.cfg["nonlinearity"],
            gpu_id=device_idx,
        )

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        flat = input_dict["obs_flat"]
        # Batch and Time
        # Forward expects outputs as [B, T, logits]
        B = len(seq_lens)
        T = flat.shape[0] // B

        # Deconstruct batch into batch and time dimensions: [B, T, feats]
        flat = torch.reshape(flat, [-1, T] + list(flat.shape[1:]))

        # First run: build the DNC on the same device as the observations
        if self.dnc is None:
            gpu_id = flat.device.index if flat.device.index is not None else -1
            self.build_dnc(gpu_id)
            hidden = (None, None, None)  # let the DNC initialize its own state
        else:
            hidden = self.unpack_state(state)  # type: ignore

        # Run through the preprocessor before the DNC
        z = self.preprocessor(flat.reshape(B * T, self.obs_dim))
        z = z.reshape(B, T, self.cfg["preprocessor_output_size"])
        output, hidden = self.dnc(z, hidden)
        packed_state = self.pack_state(*hidden)

        # Compute action logits and value from the DNC output
        logits = self.logit_branch(output.view(B * T, -1))
        values = self.value_branch(output.view(B * T, -1))

        self.cur_val = values.squeeze(1)
        return logits, packed_state
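

# A minimal usage sketch (not part of the original file): register DNCMemory
# as an RLlib custom model and build a trainer-style config that points at it.
# This assumes a recent RLlib where per-model kwargs are passed through
# "custom_model_config"; the hyperparameter values below are placeholders and
# would normally be merged into a full trainer config (e.g. for PPO).
if __name__ == "__main__":
    from ray.rllib.models import ModelCatalog

    ModelCatalog.register_custom_model("dnc", DNCMemory)

    example_config = {
        "framework": "torch",
        "model": {
            "custom_model": "dnc",
            # Overrides are merged into DNCMemory.DEFAULT_CONFIG via
            # **custom_model_kwargs in __init__
            "custom_model_config": {
                "hidden_size": 64,
                "nr_cells": 32,
                "cell_size": 16,
                "read_heads": 4,
            },
            # Maximum sequence length RLlib feeds to the recurrent model
            "max_seq_len": 32,
        },
    }
    print(example_config)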