# feature_importance.py
  1. import copy
  2. import numpy as np
  3. import pandas as pd
  4. from typing import Callable, Dict, Any
  5. import ray
  6. from ray.data import Dataset
  7. from ray.rllib.policy import Policy
  8. from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch
  9. from ray.rllib.utils.annotations import override, DeveloperAPI, ExperimentalAPI
  10. from ray.rllib.utils.typing import SampleBatchType
  11. from ray.rllib.offline.offline_evaluator import OfflineEvaluator
  12. @DeveloperAPI
  13. def _perturb_fn(batch: np.ndarray, index: int):
  14. # shuffle the indexth column features
  15. random_inds = np.random.permutation(batch.shape[0])
  16. batch[:, index] = batch[random_inds, index]
  17. @ExperimentalAPI
  18. def _perturb_df(batch: pd.DataFrame, index: int):
  19. obs_batch = np.vstack(batch["obs"].values)
  20. _perturb_fn(obs_batch, index)
  21. batch["perturbed_obs"] = list(obs_batch)
  22. return batch
  23. def _compute_actions(
  24. batch: pd.DataFrame,
  25. policy_state: Dict[str, Any],
  26. input_key: str = "",
  27. output_key: str = "",
  28. ):
  29. """A custom local function to do batch prediction of a policy.
  30. Given the policy state the action predictions are computed as a function of
  31. `input_key` and stored in the `output_key` column.
  32. Args:
  33. batch: A sub-batch from the dataset.
  34. policy_state: The state of the policy to use for the prediction.
  35. input_key: The key to use for the input to the policy. If not given, the
  36. default is SampleBatch.OBS.
  37. output_key: The key to use for the output of the policy. If not given, the
  38. default is "predicted_actions".
  39. Returns:
  40. The modified batch with the predicted actions added as a column.
  41. """
  42. if not input_key:
  43. input_key = SampleBatch.OBS
  44. policy = Policy.from_state(policy_state)
  45. sample_batch = SampleBatch(
  46. {
  47. SampleBatch.OBS: np.vstack(batch[input_key].values),
  48. }
  49. )
  50. actions, _, _ = policy.compute_actions_from_input_dict(sample_batch, explore=False)
  51. if not output_key:
  52. output_key = "predicted_actions"
  53. batch[output_key] = actions
  54. return batch
  55. @ray.remote
  56. def get_feature_importance_on_index(
  57. dataset: ray.data.Dataset,
  58. *,
  59. index: int,
  60. perturb_fn: Callable[[pd.DataFrame, int], None],
  61. batch_size: int,
  62. policy_state: Dict[str, Any],
  63. ):
  64. """A remote function to compute the feature importance of a given index.
  65. Args:
  66. dataset: The dataset to use for the computation. The dataset should have `obs`
  67. and `actions` columns. Each record should be flat d-dimensional array.
  68. index: The index of the feature to compute the importance for.
  69. perturb_fn: The function to use for perturbing the dataset at the given index.
  70. batch_size: The batch size to use for the computation.
  71. policy_state: The state of the policy to use for the computation.
  72. Returns:
  73. The modified dataset that contains a `delta` column which is the absolute
  74. difference between the expected output and the output due to the perturbation.
  75. """
  76. perturbed_ds = dataset.map_batches(
  77. perturb_fn,
  78. batch_size=batch_size,
  79. batch_format="pandas",
  80. fn_kwargs={"index": index},
  81. )
  82. perturbed_actions = perturbed_ds.map_batches(
  83. _compute_actions,
  84. batch_size=batch_size,
  85. batch_format="pandas",
  86. fn_kwargs={
  87. "output_key": "perturbed_actions",
  88. "input_key": "perturbed_obs",
  89. "policy_state": policy_state,
  90. },
  91. )
  92. def delta_fn(batch):
  93. # take the abs difference between columns 'ref_actions` and `perturbed_actions`
  94. # and store it in `diff`
  95. batch["delta"] = np.abs(batch["ref_actions"] - batch["perturbed_actions"])
  96. return batch
  97. delta = perturbed_actions.map_batches(
  98. delta_fn, batch_size=batch_size, batch_format="pandas"
  99. )
  100. return delta
  101. @DeveloperAPI
  102. class FeatureImportance(OfflineEvaluator):
  103. @override(OfflineEvaluator)
  104. def __init__(
  105. self,
  106. policy: Policy,
  107. repeat: int = 1,
  108. limit_fraction: float = 1.0,
  109. perturb_fn: Callable[[pd.DataFrame, int], pd.DataFrame] = _perturb_df,
  110. ):
  111. """Feature importance in a model inspection technique that can be used for any
  112. fitted predictor when the data is tablular.
  113. This implementation is also known as permutation importance that is defined to
  114. be the variation of the model's prediction when a single feature value is
  115. randomly shuffled. In RLlib it is implemented as a custom OffPolicyEstimator
  116. which is used to evaluate RLlib policies without performing environment
  117. interactions.
  118. Example usage: In the example below the feature importance module is used to
  119. evaluate the policy and the each feature's importance is computed after each
  120. training iteration. The permutation are repeated `self.repeat` times and the
  121. results are averages across repeats.
  122. ```python
  123. config = (
  124. AlgorithmConfig()
  125. .offline_data(
  126. off_policy_estimation_methods=
  127. {
  128. "feature_importance": {
  129. "type": FeatureImportance,
  130. "repeat": 10,
  131. "limit_fraction": 0.1,
  132. }
  133. }
  134. )
  135. )
  136. algorithm = DQN(config=config)
  137. results = algorithm.train()
  138. ```
  139. Args:
  140. policy: the policy to use for feature importance.
  141. repeat: number of times to repeat the perturbation.
  142. perturb_fn: function to perturb the features. By default reshuffle the
  143. features within the batch.
  144. limit_fraction: fraction of the dataset to use for feature importance
  145. This is only used in estimate_on_dataset when the dataset is too large
  146. to compute feature importance on.
  147. """
  148. super().__init__(policy)
  149. self.repeat = repeat
  150. self.perturb_fn = perturb_fn
  151. self.limit_fraction = limit_fraction
  152. def estimate(self, batch: SampleBatchType) -> Dict[str, Any]:
  153. """Estimate the feature importance of the policy.
  154. Given a batch of tabular observations, the importance of each feature is
  155. computed by perturbing each feature and computing the difference between the
  156. perturbed policy and the reference policy. The importance is computed for each
  157. feature and each perturbation is repeated `self.repeat` times.
  158. Args:
  159. batch: the batch of data to use for feature importance.
  160. Returns:
  161. A dict mapping each feature index string to its importance.
  162. """
  163. batch = convert_ma_batch_to_sample_batch(batch)
  164. obs_batch = batch["obs"]
  165. n_features = obs_batch.shape[-1]
  166. importance = np.zeros((self.repeat, n_features))
  167. ref_actions, _, _ = self.policy.compute_actions(obs_batch, explore=False)
  168. for r in range(self.repeat):
  169. for i in range(n_features):
  170. copy_obs_batch = copy.deepcopy(obs_batch)
  171. _perturb_fn(copy_obs_batch, index=i)
  172. perturbed_actions, _, _ = self.policy.compute_actions(
  173. copy_obs_batch, explore=False
  174. )
  175. importance[r, i] = np.mean(np.abs(perturbed_actions - ref_actions))
  176. # take an average across repeats
  177. importance = importance.mean(0)
  178. metrics = {f"feature_{i}": importance[i] for i in range(len(importance))}
  179. return metrics
  180. @override(OfflineEvaluator)
  181. def estimate_on_dataset(
  182. self, dataset: Dataset, *, n_parallelism: int = ...
  183. ) -> Dict[str, Any]:
  184. """Estimate the feature importance of the policy given a dataset.
  185. For each feature in the dataset, the importance is computed by applying
  186. perturbations to each feature and computing the difference between the
  187. perturbed prediction and the reference prediction. The importance
  188. computation for each feature and each perturbation is repeated `self.repeat`
  189. times. If dataset is large the user can initialize the estimator with a
  190. `limit_fraction` to limit the dataset to a fraction of the original dataset.
  191. The dataset should include a column named `obs` where each row is a vector of D
  192. dimensions. The importance is computed for each dimension of the vector.
  193. Note (Implementation detail): The computation across features are distributed
  194. with ray workers since each feature is independent of each other.
  195. Args:
  196. dataset: the dataset to use for feature importance.
  197. n_parallelism: number of parallel workers to use for feature importance.
  198. Returns:
  199. A dict mapping each feature index string to its importance.
  200. """
  201. policy_state = self.policy.get_state()
  202. # step 1: limit the dataset to a few first rows
  203. ds = dataset.limit(int(self.limit_fraction * dataset.count()))
  204. # step 2: compute the reference actions
  205. bsize = max(1, ds.count() // n_parallelism)
  206. actions_ds = ds.map_batches(
  207. _compute_actions,
  208. batch_size=bsize,
  209. fn_kwargs={
  210. "output_key": "ref_actions",
  211. "policy_state": policy_state,
  212. },
  213. )
  214. # step 3: compute the feature importance
  215. n_features = ds.take(1)[0][SampleBatch.OBS].shape[-1]
  216. importance = np.zeros((self.repeat, n_features))
  217. for r in range(self.repeat):
  218. # shuffle the entire dataset
  219. shuffled_ds = actions_ds.random_shuffle()
  220. bsize_per_task = max(1, (shuffled_ds.count() * n_features) // n_parallelism)
  221. # for each index perturb the dataset and compute the feat importance score
  222. remote_fns = [
  223. get_feature_importance_on_index.remote(
  224. dataset=shuffled_ds,
  225. index=i,
  226. perturb_fn=self.perturb_fn,
  227. bsize=bsize_per_task,
  228. policy_state=policy_state,
  229. )
  230. for i in range(n_features)
  231. ]
  232. ds_w_fi_scores = ray.get(remote_fns)
  233. importance[r] = np.array([d.mean("delta") for d in ds_w_fi_scores])
  234. importance = importance.mean(0)
  235. metrics = {f"feature_{i}": importance[i] for i in range(len(importance))}
  236. return metrics