# uncertainty_wrappers.py
##########
# Contribution by the Center on Long-Term Risk:
# https://github.com/longtermrisk/marltoolbox
##########
import numpy as np
  6. def add_RewardUncertaintyEnvClassWrapper(EnvClass,
  7. reward_uncertainty_std,
  8. reward_uncertainty_mean=0.0):
  9. class RewardUncertaintyEnvClassWrapper(EnvClass):
  10. def step(self, action):
  11. observations, rewards, done, info = super().step(action)
  12. return observations, self.reward_wrapper(rewards), done, info
  13. def reward_wrapper(self, reward_dict):
  14. for k in reward_dict.keys():
  15. reward_dict[k] += np.random.normal(
  16. loc=reward_uncertainty_mean,
  17. scale=reward_uncertainty_std,
  18. size=())
  19. return reward_dict
  20. return RewardUncertaintyEnvClassWrapper