wis_estimator.py

from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
    OffPolicyEstimate
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType


class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
    """The weighted step-wise IS estimator.

    Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf
    """
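
    # Summary of what estimate() computes for the given batch (treated as a
    # single trajectory), in the notation of the variables used below:
    #   p[t]       = prod over t' <= t of new_prob[t'] / old_prob[t']
    #   w_t        = filter_values[t] / filter_counts[t]
    #                (running mean of p[t] over all batches seen so far)
    #   V_prev     = sum_t gamma^t * rewards[t]
    #   V_step_WIS = sum_t gamma^t * (p[t] / w_t) * rewards[t]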

    def __init__(self, policy: Policy, gamma: float):
        super().__init__(policy, gamma)
        # Running sum and count of the cumulative importance ratio at each
        # time step, accumulated across all batches estimated so far.
        self.filter_values = []
        self.filter_counts = []

    @override(OffPolicyEstimator)
    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
        self.check_can_estimate_for(batch)

        # Behavior-policy action probabilities are logged in the batch; the
        # target-policy probabilities are recomputed from the current policy.
        rewards, old_prob = batch["rewards"], batch["action_prob"]
        new_prob = self.action_log_likelihood(batch)

        # calculate importance ratios
        p = []
        for t in range(batch.count):
            if t == 0:
                pt_prev = 1.0
            else:
                pt_prev = p[t - 1]
            p.append(pt_prev * new_prob[t] / old_prob[t])

        # Update the running per-time-step totals used to normalize the
        # ratios (the "weighted" part of WIS).
        for t, v in enumerate(p):
            if t >= len(self.filter_values):
                self.filter_values.append(v)
                self.filter_counts.append(1.0)
            else:
                self.filter_values[t] += v
                self.filter_counts[t] += 1.0

        # calculate stepwise weighted IS estimate
        V_prev, V_step_WIS = 0.0, 0.0
        for t in range(batch.count):
            V_prev += rewards[t] * self.gamma**t
            w_t = self.filter_values[t] / self.filter_counts[t]
            V_step_WIS += p[t] / w_t * rewards[t] * self.gamma**t

        estimation = OffPolicyEstimate(
            "wis", {
                "V_prev": V_prev,
                "V_step_WIS": V_step_WIS,
                "V_gain_est": V_step_WIS / max(1e-8, V_prev),
            })
        return estimation
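

# Usage sketch (assumptions, not defined in this file: `policy` is an
# already-built RLlib Policy, `episode_batch` is a SampleBatch holding one
# episode with "action_prob" logged, and the returned OffPolicyEstimate
# exposes its metrics dict as `.metrics`):
#
#     estimator = WeightedImportanceSamplingEstimator(policy, gamma=0.99)
#     estimate = estimator.estimate(episode_batch)
#     # ratio of the estimated new-policy value to the behavior-policy return
#     print(estimate.metrics["V_gain_est"])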