# offline_evaluation_utils.py

import numpy as np
import pandas as pd
from typing import Any, Dict, Type, TYPE_CHECKING

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy import Policy
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.annotations import DeveloperAPI

if TYPE_CHECKING:
    from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
    from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator


@DeveloperAPI
def compute_q_and_v_values(
    batch: pd.DataFrame,
    model_class: Type["FQETorchModel"],
    model_state: Dict[str, Any],
    compute_q_values: bool = True,
) -> pd.DataFrame:
    """Computes the Q and V values for the given batch of samples.

    This function is to be used with map_batches() to perform a batch prediction on a
    dataset of records with `obs` and `actions` columns (see the usage sketch below
    this function).

    Args:
        batch: A sub-batch from the dataset.
        model_class: The model class to use for the prediction. This class should be a
            sub-class of FQEModel that implements the estimate_q() and estimate_v()
            methods.
        model_state: The state of the model to use for the prediction.
        compute_q_values: Whether to compute the Q values or not. If False, only the V
            values are computed and returned.

    Returns:
        The modified batch with the Q and V values added as columns.
    """
    model = model_class.from_state(model_state)
    sample_batch = SampleBatch(
        {
            SampleBatch.OBS: np.vstack(batch[SampleBatch.OBS]),
            SampleBatch.ACTIONS: np.vstack(batch[SampleBatch.ACTIONS]).squeeze(-1),
        }
    )
    v_values = model.estimate_v(sample_batch)
    v_values = convert_to_numpy(v_values)
    batch["v_values"] = v_values
    if compute_q_values:
        q_values = model.estimate_q(sample_batch)
        q_values = convert_to_numpy(q_values)
        batch["q_values"] = q_values
    return batch
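
# A minimal usage sketch (not part of the original module), assuming a Ray Dataset
# `ds` of records with `obs` and `actions` columns, and a trained FQETorchModel
# `fqe` whose state is obtained via get_state() -- both hypothetical here:
#
#     from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
#
#     ds = ds.map_batches(
#         compute_q_and_v_values,
#         batch_format="pandas",
#         fn_kwargs={
#             "model_class": FQETorchModel,
#             "model_state": fqe.get_state(),
#         },
#     )
#     # Each record now carries a "v_values" column (and "q_values", unless
#     # compute_q_values=False was passed).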


@DeveloperAPI
def compute_is_weights(
    batch: pd.DataFrame,
    policy_state: Dict[str, Any],
    estimator_class: Type["OffPolicyEstimator"],
) -> pd.DataFrame:
    """Computes the importance sampling weights for the given batch of samples.

    For many off-policy estimators, the importance sampling weights are computed as
    the propensity score ratio between the new and old policies
    (i.e. new_pi(act|obs) / old_pi(act|obs)). This function is to be used with
    map_batches() to perform a batch prediction on a dataset of records with `obs`,
    `actions`, `action_prob` and `rewards` columns (see the usage sketch below this
    function).

    Args:
        batch: A sub-batch from the dataset.
        policy_state: The state of the policy to use for the prediction.
        estimator_class: The estimator class to use for the prediction. This class
            should be a sub-class of OffPolicyEstimator.

    Returns:
        The modified batch with the importance sampling weights, weighted rewards, new
        and old propensities added as columns.
    """
    policy = Policy.from_state(policy_state)
    estimator = estimator_class(policy=policy, gamma=0, epsilon_greedy=0)
    sample_batch = SampleBatch(
        {
            SampleBatch.OBS: np.vstack(batch["obs"].values),
            SampleBatch.ACTIONS: np.vstack(batch["actions"].values).squeeze(-1),
            SampleBatch.ACTION_PROB: np.vstack(batch["action_prob"].values).squeeze(
                -1
            ),
            SampleBatch.REWARDS: np.vstack(batch["rewards"].values).squeeze(-1),
        }
    )
    new_prob = estimator.compute_action_probs(sample_batch)
    old_prob = sample_batch[SampleBatch.ACTION_PROB]
    rewards = sample_batch[SampleBatch.REWARDS]
    weights = new_prob / old_prob
    weighted_rewards = weights * rewards
    batch["weights"] = weights
    batch["weighted_rewards"] = weighted_rewards
    batch["new_prob"] = new_prob
    batch["old_prob"] = old_prob
    return batch
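
# A minimal usage sketch (not part of the original module), assuming a Ray Dataset
# `ds` with `obs`, `actions`, `action_prob` and `rewards` columns and an RLlib
# Policy object `policy` to evaluate -- both hypothetical here:
#
#     from ray.rllib.offline.estimators import ImportanceSampling
#
#     ds = ds.map_batches(
#         compute_is_weights,
#         batch_format="pandas",
#         fn_kwargs={
#             "policy_state": policy.get_state(),
#             "estimator_class": ImportanceSampling,
#         },
#     )
#     # Each record now carries "weights", "weighted_rewards", "new_prob" and
#     # "old_prob" columns, from which IS/WIS estimates can be aggregated.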


@DeveloperAPI
def remove_time_dim(batch: pd.DataFrame) -> pd.DataFrame:
    """Removes the time dimension from the given sub-batch of the dataset.

    If each row in a dataset has a time dimension ([T, D]), and T=1, this function
    will remove the T dimension to convert each row to shape [D]. If T > 1, the row
    is left unchanged. This function is to be used with map_batches() (see the usage
    sketch below this function).

    Args:
        batch: The batch to remove the time dimension from.

    Returns:
        The modified batch with the time dimension removed (when applicable).
    """
    BATCHED_KEYS = {
        SampleBatch.OBS,
        SampleBatch.ACTIONS,
        SampleBatch.ACTION_PROB,
        SampleBatch.REWARDS,
        SampleBatch.NEXT_OBS,
        SampleBatch.DONES,
    }
    for k in batch.columns:
        if k in BATCHED_KEYS:
            batch[k] = batch[k].apply(lambda x: x[0] if len(x) == 1 else x)
    return batch
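
# A minimal usage sketch (not part of the original module), assuming a Ray Dataset
# `ds` whose records store each column as a [T, D] array with T == 1 (hypothetical
# setup, e.g. data read from RLlib's offline JSON format):
#
#     ds = ds.map_batches(remove_time_dim, batch_format="pandas")
#     # A row's "obs" of shape [1, 4] becomes shape [4]; rows with T > 1 are
#     # left unchanged.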