from gym.spaces import Discrete, MultiDiscrete, Space
import numpy as np
from typing import Optional, Tuple, Union

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
    TorchMultiCategorical
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import NullContextManager
from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.tf_utils import get_placeholder, one_hot as tf_one_hot
from ray.rllib.utils.torch_utils import one_hot
from ray.rllib.utils.typing import FromConfigSpec, ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
F = None
if nn is not None:
    F = nn.functional


class Curiosity(Exploration):
    """Implementation of:
    [1] Curiosity-driven Exploration by Self-supervised Prediction
    Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017.
    https://arxiv.org/pdf/1705.05363.pdf

    Learns a simplified model of the environment based on three networks:
    1) Embedding observations into latent space ("feature" network).
    2) Predicting the action, given two consecutive embedded observations
       ("inverse" network).
    3) Predicting the next embedded obs, given an obs and action
       ("forward" network).

    The less the agent is able to predict the actually observed next feature
    vector, given obs and action (through the forward network), the larger
    the "intrinsic reward", which is added to the extrinsic reward.
    Therefore, if a state transition was unexpected, the agent becomes
    "curious" and further explores this transition, leading to better
    exploration in sparse-reward environments.
    """
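
    # Usage sketch: this module is typically enabled through a Trainer's
    # `exploration_config` dict, whose keys map to the constructor arguments
    # below (exact config plumbing may vary across RLlib versions; the
    # sub-exploration type shown is an assumption for a PG-style algorithm):
    #
    #     config["num_workers"] = 0  # Required, see the check in __init__.
    #     config["exploration_config"] = {
    #         "type": "Curiosity",
    #         "eta": 1.0,
    #         "lr": 0.001,
    #         "feature_dim": 288,
    #         "feature_net_config": {
    #             "fcnet_hiddens": [],
    #             "fcnet_activation": "relu",
    #         },
    #         "inverse_net_hiddens": [256],
    #         "forward_net_hiddens": [256],
    #         "beta": 0.2,
    #         "sub_exploration": {"type": "StochasticSampling"},
    #     }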

    def __init__(self,
                 action_space: Space,
                 *,
                 framework: str,
                 model: ModelV2,
                 feature_dim: int = 288,
                 feature_net_config: Optional[ModelConfigDict] = None,
                 inverse_net_hiddens: Tuple[int] = (256, ),
                 inverse_net_activation: str = "relu",
                 forward_net_hiddens: Tuple[int] = (256, ),
                 forward_net_activation: str = "relu",
                 beta: float = 0.2,
                 eta: float = 1.0,
                 lr: float = 1e-3,
                 sub_exploration: Optional[FromConfigSpec] = None,
                 **kwargs):
        """Initializes a Curiosity object.

        Uses as defaults the hyperparameters described in [1].

        Args:
            feature_dim (int): The dimensionality of the feature (phi)
                vectors.
            feature_net_config (Optional[ModelConfigDict]): Optional model
                configuration for the feature network, producing feature
                vectors (phi) from observations. This can be used to configure
                fcnet- or conv_net setups to properly process any observation
                space.
            inverse_net_hiddens (Tuple[int]): Tuple of the layer sizes of the
                inverse (action predicting) NN head (on top of the feature
                outputs for phi and phi').
            inverse_net_activation (str): Activation specifier for the inverse
                net.
            forward_net_hiddens (Tuple[int]): Tuple of the layer sizes of the
                forward (phi' predicting) NN head.
            forward_net_activation (str): Activation specifier for the forward
                net.
            beta (float): Weight for the forward loss (over the inverse loss,
                which gets weight=1.0-beta) in the common loss term.
            eta (float): Weight for intrinsic rewards before being added to
                extrinsic ones.
            lr (float): The learning rate for the curiosity-specific
                optimizer, optimizing feature-, inverse-, and forward nets.
            sub_exploration (Optional[FromConfigSpec]): The config dict for
                the underlying Exploration to use (e.g. epsilon-greedy for
                DQN). If None, the Policy's default Exploration config should
                be used (not implemented yet, see below).
        """
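        # For reference, the quantities computed in `postprocess_trajectory`
        # below combine as described in [1]:
        #     loss = (1.0 - beta) * inverse_loss + beta * forward_loss
        #     intrinsic_reward = eta * 0.5 * ||phi(s') - predicted_phi(s')||^2
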
        if not isinstance(action_space, (Discrete, MultiDiscrete)):
            raise ValueError(
                "Only (Multi)Discrete action spaces supported for Curiosity "
                "so far!")

        super().__init__(
            action_space, model=model, framework=framework, **kwargs)

        if self.policy_config["num_workers"] != 0:
            raise ValueError(
                "Curiosity exploration currently does not support parallelism."
                " `num_workers` must be 0!")

        self.feature_dim = feature_dim
        if feature_net_config is None:
            feature_net_config = self.policy_config["model"].copy()
        self.feature_net_config = feature_net_config
        self.inverse_net_hiddens = inverse_net_hiddens
        self.inverse_net_activation = inverse_net_activation
        self.forward_net_hiddens = forward_net_hiddens
        self.forward_net_activation = forward_net_activation
        self.action_dim = self.action_space.n if isinstance(
            self.action_space, Discrete) else np.sum(self.action_space.nvec)

        self.beta = beta
        self.eta = eta
        self.lr = lr
        # TODO: (sven) if sub_exploration is None, use Trainer's default
        #  Exploration config.
        if sub_exploration is None:
            raise NotImplementedError
        self.sub_exploration = sub_exploration

        # Creates modules/layers inside the actual ModelV2.
        self._curiosity_feature_net = ModelCatalog.get_model_v2(
            self.model.obs_space,
            self.action_space,
            self.feature_dim,
            model_config=self.feature_net_config,
            framework=self.framework,
            name="feature_net",
        )
        self._curiosity_inverse_fcnet = self._create_fc_net(
            [2 * self.feature_dim] + list(self.inverse_net_hiddens) +
            [self.action_dim],
            self.inverse_net_activation,
            name="inverse_net")
        self._curiosity_forward_fcnet = self._create_fc_net(
            [self.feature_dim + self.action_dim] + list(
                self.forward_net_hiddens) + [self.feature_dim],
            self.forward_net_activation,
            name="forward_net")

        # The sub-Exploration module is only used to select the actual
        # actions from the distribution (e.g. epsilon-greedily for DQN).
        self.exploration_submodule = from_config(
            cls=Exploration,
            config=self.sub_exploration,
            action_space=self.action_space,
            framework=self.framework,
            policy_config=self.policy_config,
            model=self.model,
            num_workers=self.num_workers,
            worker_index=self.worker_index,
        )

    @override(Exploration)
    def get_exploration_action(self,
                               *,
                               action_distribution: ActionDistribution,
                               timestep: Union[int, TensorType],
                               explore: bool = True):
        # Simply delegate to the sub-Exploration module.
        return self.exploration_submodule.get_exploration_action(
            action_distribution=action_distribution,
            timestep=timestep,
            explore=explore)

    @override(Exploration)
    def get_exploration_optimizer(self, optimizers):
        # Create, but don't add Adam for curiosity NN updating to the policy.
        # If we added and returned it here, it would be used in the policy's
        # update loop, which we don't want (curiosity updating happens inside
        # `postprocess_trajectory`).
        if self.framework == "torch":
            feature_params = list(self._curiosity_feature_net.parameters())
            inverse_params = list(self._curiosity_inverse_fcnet.parameters())
            forward_params = list(self._curiosity_forward_fcnet.parameters())

            # Now that the Policy's own optimizer(s) have been created (from
            # the Model's parameters, importantly WITHOUT the curiosity
            # params), we can add our curiosity sub-modules to the Policy's
            # Model.
            self.model._curiosity_feature_net = \
                self._curiosity_feature_net.to(self.device)
            self.model._curiosity_inverse_fcnet = \
                self._curiosity_inverse_fcnet.to(self.device)
            self.model._curiosity_forward_fcnet = \
                self._curiosity_forward_fcnet.to(self.device)
            self._optimizer = torch.optim.Adam(
                forward_params + inverse_params + feature_params, lr=self.lr)
        else:
            self.model._curiosity_feature_net = self._curiosity_feature_net
            self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet
            self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet
            # The feature net is an RLlib ModelV2; the other two are Keras
            # Models.
            self._optimizer_var_list = \
                self._curiosity_feature_net.base_model.variables + \
                self._curiosity_inverse_fcnet.variables + \
                self._curiosity_forward_fcnet.variables
            self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr)

            # Create placeholders and initialize the loss (tf1 static graph
            # only).
            if self.framework == "tf":
                self._obs_ph = get_placeholder(
                    space=self.model.obs_space, name="_curiosity_obs")
                self._next_obs_ph = get_placeholder(
                    space=self.model.obs_space, name="_curiosity_next_obs")
                self._action_ph = get_placeholder(
                    space=self.model.action_space, name="_curiosity_action")
                self._forward_l2_norm_squared, self._update_op = \
                    self._postprocess_helper_tf(
                        self._obs_ph, self._next_obs_ph, self._action_ph)

        return optimizers

    @override(Exploration)
    def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
        """Calculates phi values (obs, obs', and predicted obs') and the
        intrinsic rewards (ri).

        Also calculates forward and inverse losses and updates the curiosity
        module on the provided batch using our optimizer.
        """
        if self.framework != "torch":
            self._postprocess_tf(policy, sample_batch, tf_sess)
        else:
            self._postprocess_torch(policy, sample_batch)

    def _postprocess_tf(self, policy, sample_batch, tf_sess):
        # tf1 static-graph: Perform session call on our loss and update ops.
        if self.framework == "tf":
            forward_l2_norm_squared, _ = tf_sess.run(
                [self._forward_l2_norm_squared, self._update_op],
                feed_dict={
                    self._obs_ph: sample_batch[SampleBatch.OBS],
                    self._next_obs_ph: sample_batch[SampleBatch.NEXT_OBS],
                    self._action_ph: sample_batch[SampleBatch.ACTIONS],
                })
        # tf-eager: Perform model calls, loss calculations, and optimizer
        # stepping on the fly.
        else:
            forward_l2_norm_squared, _ = self._postprocess_helper_tf(
                sample_batch[SampleBatch.OBS],
                sample_batch[SampleBatch.NEXT_OBS],
                sample_batch[SampleBatch.ACTIONS],
            )
        # Scale the intrinsic reward by the `eta` hyper-parameter and add it
        # to the extrinsic reward.
        sample_batch[SampleBatch.REWARDS] = \
            sample_batch[SampleBatch.REWARDS] + \
            self.eta * forward_l2_norm_squared
        return sample_batch

    def _postprocess_helper_tf(self, obs, next_obs, actions):
        with (tf.GradientTape()
              if self.framework != "tf" else NullContextManager()) as tape:
            # Push both observations through the feature net to get both phis.
            phis, _ = self.model._curiosity_feature_net({
                SampleBatch.OBS: tf.concat([obs, next_obs], axis=0)
            })
            phi, next_phi = tf.split(phis, 2)

            # Predict next phi with the forward model.
            predicted_next_phi = self.model._curiosity_forward_fcnet(
                tf.concat(
                    [phi, tf_one_hot(actions, self.action_space)], axis=-1))

            # Forward loss term (predicted phi', given phi and action vs
            # actually observed phi').
            forward_l2_norm_squared = 0.5 * tf.reduce_sum(
                tf.square(predicted_next_phi - next_phi), axis=-1)
            forward_loss = tf.reduce_mean(forward_l2_norm_squared)

            # Inverse loss term (predicted action that led from phi to phi'
            # vs actual action taken).
            phi_cat_next_phi = tf.concat([phi, next_phi], axis=-1)
            dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
            action_dist = Categorical(dist_inputs, self.model) if \
                isinstance(self.action_space, Discrete) else \
                MultiCategorical(
                    dist_inputs, self.model, self.action_space.nvec)
            # Neg log(p); p=probability of the observed action given the
            # inverse-NN predicted action distribution.
            inverse_loss = -action_dist.logp(tf.convert_to_tensor(actions))
            inverse_loss = tf.reduce_mean(inverse_loss)

            # Calculate the ICM loss.
            loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss

        # Step the optimizer.
        if self.framework != "tf":
            grads = tape.gradient(loss, self._optimizer_var_list)
            grads_and_vars = [(g, v)
                              for g, v in zip(grads, self._optimizer_var_list)
                              if g is not None]
            update_op = self._optimizer.apply_gradients(grads_and_vars)
        else:
            update_op = self._optimizer.minimize(
                loss, var_list=self._optimizer_var_list)

        # Return the squared l2 norm and the optimizer update op.
        return forward_l2_norm_squared, update_op

    def _postprocess_torch(self, policy, sample_batch):
        # Push both observations through the feature net to get both phis
        # (move the numpy batch data onto the policy's device first).
        phis, _ = self.model._curiosity_feature_net({
            SampleBatch.OBS: torch.cat([
                torch.from_numpy(sample_batch[SampleBatch.OBS]),
                torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS])
            ]).to(policy.device)
        })
        phi, next_phi = torch.chunk(phis, 2)
        actions_tensor = torch.from_numpy(
            sample_batch[SampleBatch.ACTIONS]).long().to(policy.device)

        # Predict next phi with the forward model.
        predicted_next_phi = self.model._curiosity_forward_fcnet(
            torch.cat(
                [phi, one_hot(actions_tensor, self.action_space).float()],
                dim=-1))

        # Forward loss term (predicted phi', given phi and action vs actually
        # observed phi').
        forward_l2_norm_squared = 0.5 * torch.sum(
            torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1)
        forward_loss = torch.mean(forward_l2_norm_squared)

        # Scale the intrinsic reward by the `eta` hyper-parameter and add it
        # to the extrinsic reward (detaching first, since the sample batch
        # stores plain numpy arrays).
        sample_batch[SampleBatch.REWARDS] = \
            sample_batch[SampleBatch.REWARDS] + \
            self.eta * forward_l2_norm_squared.detach().cpu().numpy()

        # Inverse loss term (predicted action that led from phi to phi' vs
        # actual action taken).
        phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1)
        dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
        action_dist = TorchCategorical(dist_inputs, self.model) if \
            isinstance(self.action_space, Discrete) else \
            TorchMultiCategorical(
                dist_inputs, self.model, self.action_space.nvec)
        # Neg log(p); p=probability of the observed action given the
        # inverse-NN predicted action distribution.
        inverse_loss = -action_dist.logp(actions_tensor)
        inverse_loss = torch.mean(inverse_loss)

        # Calculate the ICM loss.
        loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss

        # Perform an optimizer step.
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()

        # Return the postprocessed sample batch (with the corrected rewards).
        return sample_batch

    def _create_fc_net(self, layer_dims, activation, name=None):
        """Given a list of layer dimensions (incl. input-dim), creates FC-net.

        Args:
            layer_dims (Tuple[int]): Tuple of layer dims, including the input
                dimension.
            activation (str): An activation specifier string (e.g. "relu").

        Examples:
            If layer_dims is [4, 8, 6], we'll have a two-layer net: 4->8 (8
            nodes) and 8->6 (6 nodes), where the last layer (6 nodes) has no
            activation anymore. 4 is the input dimension.
        """
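        # For instance, the inverse net built in `__init__` above uses
        # layer_dims = [2 * feature_dim] + inverse_net_hiddens + [action_dim],
        # e.g. [576, 256, 2] with the defaults on a hypothetical env that has
        # two discrete actions (such as CartPole).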
        layers = [
            tf.keras.layers.Input(
                shape=(layer_dims[0], ), name="{}_in".format(name))
        ] if self.framework != "torch" else []

        for i in range(len(layer_dims) - 1):
            act = activation if i < len(layer_dims) - 2 else None
            if self.framework == "torch":
                layers.append(
                    SlimFC(
                        in_size=layer_dims[i],
                        out_size=layer_dims[i + 1],
                        initializer=torch.nn.init.xavier_uniform_,
                        activation_fn=act))
            else:
                layers.append(
                    tf.keras.layers.Dense(
                        units=layer_dims[i + 1],
                        activation=get_activation_fn(act),
                        name="{}_{}".format(name, i)))

        if self.framework == "torch":
            return nn.Sequential(*layers)
        else:
            return tf.keras.Sequential(layers)