model_ensemble.py

import gym
from gym.spaces import Discrete, Box
import numpy as np

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

torch, nn = try_import_torch()


class TDModel(nn.Module):
    """Transition Dynamics Model (FC network with weight normalization)."""

    def __init__(self,
                 input_size,
                 output_size,
                 hidden_layers=(512, 512),
                 hidden_nonlinearity=None,
                 output_nonlinearity=None,
                 weight_normalization=False,
                 use_bias=True):
        super().__init__()
        assert len(hidden_layers) >= 1

        # Default to ReLU for the hidden activations.
        if not hidden_nonlinearity:
            hidden_nonlinearity = nn.ReLU

        if weight_normalization:
            weight_norm = nn.utils.weight_norm

        # Build the hidden layers (optionally weight-normalized), each
        # followed by its activation.
        self.layers = []
        cur_size = input_size
        for h_size in hidden_layers:
            layer = nn.Linear(cur_size, h_size, bias=use_bias)
            if weight_normalization:
                layer = weight_norm(layer)
            self.layers.append(layer)
            if hidden_nonlinearity:
                self.layers.append(hidden_nonlinearity())
            cur_size = h_size

        # Output layer (no activation unless explicitly requested).
        layer = nn.Linear(cur_size, output_size, bias=use_bias)
        if weight_normalization:
            layer = weight_norm(layer)
        self.layers.append(layer)
        if output_nonlinearity:
            self.layers.append(output_nonlinearity())

        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.model(x)
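
# A minimal shape sketch (hypothetical sizes, not from the original file):
# with a 17-dim observation and a 6-dim continuous action space, each
# ensemble member maps the 23-dim concatenated [obs, action] input to a
# 17-dim observation delta:
#   model = TDModel(input_size=23, output_size=17, hidden_layers=(512, 512),
#                   weight_normalization=True)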


if torch:

    class TDDataset(torch.utils.data.Dataset):
        def __init__(self, dataset: SampleBatchType, norms):
            self.count = dataset.count
            obs = dataset[SampleBatch.CUR_OBS]
            actions = dataset[SampleBatch.ACTIONS]
            delta = dataset[SampleBatch.NEXT_OBS] - obs

            if norms:
                obs = normalize(obs, norms[SampleBatch.CUR_OBS])
                actions = normalize(actions, norms[SampleBatch.ACTIONS])
                delta = normalize(delta, norms["delta"])

            self.x = np.concatenate([obs, actions], axis=1)
            self.y = delta

        def __len__(self):
            return self.count

        def __getitem__(self, index):
            return self.x[index], self.y[index]
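
    # Each dataset item is an (input, target) pair: the input is the
    # concatenated [obs, action] vector, the target the (optionally
    # normalized) next-minus-current observation delta the TD model predicts.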


def normalize(data_array, stats):
    mean, std = stats
    return (data_array - mean) / (std + 1e-10)


def denormalize(data_array, stats):
    mean, std = stats
    return data_array * (std + 1e-10) + mean
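
# Round-trip note: denormalize(normalize(x, stats), stats) recovers x exactly
# in real arithmetic; the same 1e-10 epsilon guards against a zero standard
# deviation in both directions.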


def mean_std_stats(dataset: SampleBatchType):
    norm_dict = {}
    obs = dataset[SampleBatch.CUR_OBS]
    act = dataset[SampleBatch.ACTIONS]
    delta = dataset[SampleBatch.NEXT_OBS] - obs

    norm_dict[SampleBatch.CUR_OBS] = (np.mean(obs, axis=0),
                                      np.std(obs, axis=0))
    norm_dict[SampleBatch.ACTIONS] = (np.mean(act, axis=0),
                                      np.std(act, axis=0))
    norm_dict["delta"] = (np.mean(delta, axis=0), np.std(delta, axis=0))

    return norm_dict


def process_samples(samples: SampleBatchType):
    filter_keys = [
        SampleBatch.CUR_OBS, SampleBatch.ACTIONS, SampleBatch.NEXT_OBS
    ]
    filtered = {}
    for key in filter_keys:
        filtered[key] = samples[key]
    return SampleBatch(filtered)


class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):
    """Represents an ensemble of transition dynamics (TD) models."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Initializes a DynamicsEnsemble object."""
        nn.Module.__init__(self)
        if isinstance(action_space, Discrete):
            input_space = gym.spaces.Box(
                obs_space.low[0],
                obs_space.high[0],
                shape=(obs_space.shape[0] + action_space.n, ))
        elif isinstance(action_space, Box):
            input_space = gym.spaces.Box(
                obs_space.low[0],
                obs_space.high[0],
                shape=(obs_space.shape[0] + action_space.shape[0], ))
        else:
            raise NotImplementedError
        super(DynamicsEnsembleCustomModel, self).__init__(
            input_space, action_space, num_outputs, model_config, name)

        # Keep the original Env's observation space for possible clipping.
        self.env_obs_space = obs_space

        self.num_models = model_config["ensemble_size"]
        self.max_epochs = model_config["train_epochs"]
        self.lr = model_config["lr"]
        self.valid_split = model_config["valid_split_ratio"]
        self.batch_size = model_config["batch_size"]
        self.normalize_data = model_config["normalize_data"]
        self.normalizations = {}
        self.dynamics_ensemble = [
            TDModel(
                input_size=input_space.shape[0],
                output_size=obs_space.shape[0],
                hidden_layers=model_config["fcnet_hiddens"],
                hidden_nonlinearity=nn.ReLU,
                output_nonlinearity=None,
                weight_normalization=True) for _ in range(self.num_models)
        ]
        for i in range(self.num_models):
            self.add_module("TD-model-" + str(i), self.dynamics_ensemble[i])

        self.replay_buffer_max = 10000
        self.replay_buffer = None
        self.optimizers = [
            torch.optim.Adam(
                self.dynamics_ensemble[i].parameters(), lr=self.lr)
            for i in range(self.num_models)
        ]

        # Metric reporting.
        self.metrics = {}
        self.metrics[STEPS_SAMPLED_COUNTER] = 0

        # Pick the TD model this worker samples trajectories from (based on
        # the worker index).
        worker_index = get_global_worker().worker_index
        self.sample_index = int((worker_index - 1) / self.num_models)
        self.global_itr = 0
        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))

    def forward(self, x):
        """Outputs the delta between next and current observation."""
        return self.dynamics_ensemble[self.sample_index](x)

    # Loss functions for each TD model in the ensemble (standard L2 loss).
    def loss(self, x, y):
        xs = torch.chunk(x, self.num_models)
        ys = torch.chunk(y, self.num_models)
        return [
            torch.mean(
                torch.pow(self.dynamics_ensemble[i](xs[i]) - ys[i], 2.0))
            for i in range(self.num_models)
        ]
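
    # Note: ``loss`` assumes ``x`` and ``y`` are the concatenation of one
    # minibatch per ensemble member (built via ``torch.cat`` in ``fit``
    # below), so ``torch.chunk`` recovers each model's own slice and every
    # model trains on its own data split.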

    # Fits the dynamics ensemble once per MBMPO iteration.
    def fit(self):
        # Add env samples to the replay buffer.
        local_worker = get_global_worker()
        for pid, pol in local_worker.policy_map.items():
            pol.view_requirements[
                SampleBatch.NEXT_OBS].used_for_training = True
        new_samples = local_worker.sample()
        # Initial exploration of 8000 timesteps.
        if not self.global_itr:
            extra = local_worker.sample()
            new_samples = new_samples.concat(extra)

        # Process samples.
        new_samples = process_samples(new_samples)

        # One-hot encode discrete actions (width action_space.n, so the
        # encoding matches the ensemble's input space).
        if isinstance(self.action_space, Discrete):
            act = new_samples["actions"]
            new_act = np.zeros((act.size, self.action_space.n))
            new_act[np.arange(act.size), act] = 1
            new_samples["actions"] = new_act.astype("float32")

        if self.replay_buffer is None:
            self.replay_buffer = new_samples
        else:
            self.replay_buffer = self.replay_buffer.concat(new_samples)

        # Keep the replay buffer size constant.
        self.replay_buffer = self.replay_buffer.slice(
            start=-self.replay_buffer_max, end=None)

        if self.normalize_data:
            self.normalizations = mean_std_stats(self.replay_buffer)

        # Keep track of the timesteps sampled from the real environment.
        self.metrics[STEPS_SAMPLED_COUNTER] += new_samples.count

        # Create train and validation datasets for each TD model.
        train_loaders = []
        val_loaders = []
        for i in range(self.num_models):
            t, v = self.split_train_val(self.replay_buffer)
            train_loaders.append(
                torch.utils.data.DataLoader(
                    TDDataset(t, self.normalizations),
                    batch_size=self.batch_size,
                    shuffle=True))
            val_loaders.append(
                torch.utils.data.DataLoader(
                    TDDataset(v, self.normalizations),
                    batch_size=v.count,
                    shuffle=False))

        # Indices of the models in the ensemble still being trained.
        indexes = list(range(self.num_models))

        valid_loss_roll_avg = None
        roll_avg_persistency = 0.95

        def convert_to_str(lst):
            return " ".join([str(elem) for elem in lst])

        for epoch in range(self.max_epochs):
            # Training.
            for data in zip(*train_loaders):
                x = torch.cat([d[0] for d in data], dim=0).to(self.device)
                y = torch.cat([d[1] for d in data], dim=0).to(self.device)
                train_losses = self.loss(x, y)
                for ind in indexes:
                    self.optimizers[ind].zero_grad()
                    train_losses[ind].backward()
                    self.optimizers[ind].step()

                for ind in range(self.num_models):
                    train_losses[ind] = train_losses[
                        ind].detach().cpu().numpy()

            # Validation.
            val_lists = []
            for data in zip(*val_loaders):
                x = torch.cat([d[0] for d in data], dim=0).to(self.device)
                y = torch.cat([d[1] for d in data], dim=0).to(self.device)
                val_losses = self.loss(x, y)
                val_lists.append(val_losses)

                for ind in indexes:
                    self.optimizers[ind].zero_grad()

                for ind in range(self.num_models):
                    val_losses[ind] = val_losses[ind].detach().cpu().numpy()

            val_lists = np.array(val_lists)
            avg_val_losses = np.mean(val_lists, axis=0)

            if valid_loss_roll_avg is None:
                # Make sure that training doesn't end after the first epoch.
                valid_loss_roll_avg = 1.5 * avg_val_losses
                valid_loss_roll_avg_prev = 2.0 * avg_val_losses

            valid_loss_roll_avg = \
                roll_avg_persistency * valid_loss_roll_avg + \
                (1.0 - roll_avg_persistency) * avg_val_losses

            print("Training Dynamics Ensemble - Epoch #%i: "
                  "Train loss: %s, Valid Loss: %s, Moving Avg Valid Loss: %s"
                  % (epoch, convert_to_str(train_losses),
                     convert_to_str(avg_val_losses),
                     convert_to_str(valid_loss_roll_avg)))

            # Early-stop a model once its moving-average validation loss
            # stops improving (or when the epoch budget is exhausted).
            for i in range(self.num_models):
                if (valid_loss_roll_avg_prev[i] < valid_loss_roll_avg[i]
                        or epoch == self.max_epochs - 1) and i in indexes:
                    indexes.remove(i)
                    print("Stopping Training of Model %i" % i)
            valid_loss_roll_avg_prev = valid_loss_roll_avg
            if len(indexes) == 0:
                break

        self.global_itr += 1
        # Return the metrics dictionary.
        return self.metrics

    def split_train_val(self, samples: SampleBatchType):
        dataset_size = samples.count
        indices = np.arange(dataset_size)
        np.random.shuffle(indices)
        split_idx = int(dataset_size * (1 - self.valid_split))
        idx_train = indices[:split_idx]
        idx_test = indices[split_idx:]

        train = {}
        val = {}
        for key in samples.keys():
            train[key] = samples[key][idx_train, :]
            val[key] = samples[key][idx_test, :]
        return SampleBatch(train), SampleBatch(val)
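
    # Note: ``split_train_val`` reshuffles on every call, so each ensemble
    # member in ``fit`` gets a different random train/validation partition
    # of the same replay buffer.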

    def predict_model_batches(self, obs, actions, device=None):
        """Used by workers that gather trajectories via the TD models."""
        pre_obs = obs
        if self.normalize_data:
            obs = normalize(obs, self.normalizations[SampleBatch.CUR_OBS])
            actions = normalize(actions,
                                self.normalizations[SampleBatch.ACTIONS])
        x = np.concatenate([obs, actions], axis=-1)
        x = convert_to_torch_tensor(x, device=device)
        delta = self.forward(x).detach().cpu().numpy()
        if self.normalize_data:
            delta = denormalize(delta, self.normalizations["delta"])
        new_obs = pre_obs + delta
        clipped_obs = np.clip(new_obs, self.env_obs_space.low,
                              self.env_obs_space.high)
        return clipped_obs
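
    # Shape note (assuming batched inputs): ``obs`` is (B, obs_dim) and
    # ``actions`` is (B, act_dim); the returned next-observation batch is
    # (B, obs_dim), clipped to the original env's observation space.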

    def set_norms(self, normalization_dict):
        self.normalizations = normalization_dict
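

# Hypothetical configuration sketch (not part of the original file): the
# keys below are exactly the ones ``__init__`` reads from ``model_config``;
# the values are illustrative assumptions only.
#
#   model_config = {
#       "ensemble_size": 5,
#       "train_epochs": 500,
#       "lr": 1e-3,
#       "valid_split_ratio": 0.2,
#       "batch_size": 500,
#       "normalize_data": True,
#       "fcnet_hiddens": [512, 512],
#   }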