123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- import copy
- import numpy as np
- import pandas as pd
- from typing import Callable, Dict, Any
- import ray
- from ray.data import Dataset
- from ray.rllib.policy import Policy
- from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch
- from ray.rllib.utils.annotations import override, DeveloperAPI, ExperimentalAPI
- from ray.rllib.utils.typing import SampleBatchType
- from ray.rllib.offline.offline_evaluator import OfflineEvaluator
@DeveloperAPI
def _perturb_fn(batch: np.ndarray, index: int):
    """Shuffle a single feature column of `batch` in place.

    The values stored at position `index` along the second axis are reassigned
    to the rows in a random order drawn with `np.random.permutation`. All other
    columns keep their values.

    Args:
        batch: Array indexed as `batch[row, feature]`. It is modified in place.
        index: Position of the feature column to shuffle.

    Returns:
        None. The shuffled values are written directly into `batch`.
    """
    # Random ordering of the row indices.
    row_order = np.random.permutation(batch.shape[0])
    # `column` is a view into `batch`, so writing through it updates `batch`.
    # The fancy-indexed right-hand side is a copy, so reads and writes cannot
    # overlap during the assignment.
    column = batch[:, index]
    column[:] = column[row_order]
@ExperimentalAPI
def _perturb_df(batch: pd.DataFrame, index: int):
    """Add a `perturbed_obs` column in which feature `index` has been shuffled.

    Args:
        batch: DataFrame with an `obs` column. Its entries are stacked with
            `np.vstack` into one array before the shuffle.
        index: Position of the feature to shuffle across the rows of the batch.

    Returns:
        The same DataFrame object, with a `perturbed_obs` column added that
        holds one array per row.
    """
    # np.vstack builds a new array, so the shuffle below does not write into
    # the arrays stored in the `obs` column.
    stacked_obs = np.vstack(batch["obs"].values)
    _perturb_fn(stacked_obs, index)
    # Split the stacked array back into one entry per DataFrame row.
    batch["perturbed_obs"] = list(stacked_obs)
    return batch
def _compute_actions(
    batch: pd.DataFrame,
    policy_state: Dict[str, Any],
    input_key: str = "",
    output_key: str = "",
):
    """Run batch prediction with a policy rebuilt from `policy_state`.

    The entries of the `input_key` column are stacked and fed to the policy as
    observations (without exploration); the resulting actions are stored in the
    `output_key` column of the same DataFrame.

    Args:
        batch: A sub-batch from the dataset.
        policy_state: The state of the policy to use for the prediction. It is
            passed to `Policy.from_state` on every call.
        input_key: The column to read the policy input from. If empty, the
            default is SampleBatch.OBS.
        output_key: The column to write the policy output to. If empty, the
            default is "predicted_actions".

    Returns:
        The modified batch with the predicted actions added as a column.
    """
    # An empty key selects the corresponding default.
    obs_column = input_key or SampleBatch.OBS
    actions_column = output_key or "predicted_actions"

    restored_policy = Policy.from_state(policy_state)

    # Whatever column the observations came from, the policy always receives
    # them under the SampleBatch.OBS key.
    stacked_obs = np.vstack(batch[obs_column].values)
    model_input = SampleBatch({SampleBatch.OBS: stacked_obs})
    predicted, _, _ = restored_policy.compute_actions_from_input_dict(
        model_input, explore=False
    )

    batch[actions_column] = predicted
    return batch
@ray.remote
def get_feature_importance_on_index(
    dataset: ray.data.Dataset,
    *,
    index: int,
    perturb_fn: Callable[[pd.DataFrame, int], None],
    batch_size: int,
    policy_state: Dict[str, Any],
):
    """A remote function to compute the feature importance of a given index.

    The dataset goes through three `map_batches` stages, all in pandas format:
    perturb the feature at `index`, predict actions on the perturbed
    observations, then take the absolute difference to the reference actions.

    Args:
        dataset: The dataset to use for the computation. It must already hold
            a `ref_actions` column (read in the last stage) plus whatever
            columns `perturb_fn` reads.
        index: The index of the feature to compute the importance for.
        perturb_fn: The function to use for perturbing the dataset at the given
            index. It is called by `map_batches` with `index` as a keyword
            argument; the batch it produces must contain a `perturbed_obs`
            column, which is the input of the prediction stage.
        batch_size: The batch size to use for every stage.
        policy_state: The state of the policy to use for the computation.

    Returns:
        The modified dataset that contains a `delta` column which is the absolute
        difference between `ref_actions` and `perturbed_actions`.
    """

    def delta_fn(batch):
        # Absolute difference between the `ref_actions` and `perturbed_actions`
        # columns, stored in the `delta` column.
        batch["delta"] = np.abs(batch["ref_actions"] - batch["perturbed_actions"])
        return batch

    # Every stage uses the same batching settings.
    batching = {"batch_size": batch_size, "batch_format": "pandas"}

    # Stage 1: perturb the feature at `index`.
    with_perturbed_obs = dataset.map_batches(
        perturb_fn,
        fn_kwargs={"index": index},
        **batching,
    )

    # Stage 2: predict actions from the perturbed observations.
    prediction_kwargs = {
        "output_key": "perturbed_actions",
        "input_key": "perturbed_obs",
        "policy_state": policy_state,
    }
    with_perturbed_actions = with_perturbed_obs.map_batches(
        _compute_actions,
        fn_kwargs=prediction_kwargs,
        **batching,
    )

    # Stage 3: per-row deviation from the reference actions.
    return with_perturbed_actions.map_batches(delta_fn, **batching)
@DeveloperAPI
class FeatureImportance(OfflineEvaluator):
    """Offline evaluator that reports permutation feature importance of a policy."""

    @override(OfflineEvaluator)
    def __init__(
        self,
        policy: Policy,
        repeat: int = 1,
        limit_fraction: float = 1.0,
        perturb_fn: Callable[[pd.DataFrame, int], pd.DataFrame] = _perturb_df,
    ):
        """Feature importance is a model inspection technique that can be used for
        any fitted predictor when the data is tabular.

        This implementation is also known as permutation importance that is defined to
        be the variation of the model's prediction when a single feature value is
        randomly shuffled. In RLlib it is implemented as a custom OfflineEvaluator
        which is used to evaluate RLlib policies without performing environment
        interactions.

        Example usage: In the example below the feature importance module is used to
        evaluate the policy and each feature's importance is computed after each
        training iteration. The permutations are repeated `self.repeat` times and the
        results are averaged across repeats.

        ```python
        config = (
            AlgorithmConfig()
            .offline_data(
                off_policy_estimation_methods=
                    {
                        "feature_importance": {
                            "type": FeatureImportance,
                            "repeat": 10,
                            "limit_fraction": 0.1,
                        }
                    }
            )
        )

        algorithm = DQN(config=config)
        results = algorithm.train()
        ```

        Args:
            policy: the policy to use for feature importance.
            repeat: number of times to repeat the perturbation.
            limit_fraction: fraction of the dataset to use for feature importance.
                This is only used in estimate_on_dataset when the dataset is too
                large to compute feature importance on.
            perturb_fn: function to perturb the features of a pandas batch. By
                default reshuffle the features within the batch. It is only used
                by estimate_on_dataset; estimate always shuffles the observation
                array with the module-level `_perturb_fn`.
        """
        super().__init__(policy)
        self.repeat = repeat
        self.perturb_fn = perturb_fn
        self.limit_fraction = limit_fraction

    def estimate(self, batch: SampleBatchType) -> Dict[str, Any]:
        """Estimate the feature importance of the policy.

        Given a batch of tabular observations, the importance of each feature is
        computed by perturbing each feature and computing the difference between the
        perturbed policy and the reference policy. The importance is computed for each
        feature and each perturbation is repeated `self.repeat` times.

        Args:
            batch: the batch of data to use for feature importance.

        Returns:
            A dict mapping each feature index string (`feature_<i>`) to its
            importance, averaged across repeats.
        """
        batch = convert_ma_batch_to_sample_batch(batch)
        obs_batch = batch["obs"]
        n_features = obs_batch.shape[-1]
        importance = np.zeros((self.repeat, n_features))

        # Actions on the unperturbed observations serve as the reference.
        ref_actions, _, _ = self.policy.compute_actions(obs_batch, explore=False)
        for r in range(self.repeat):
            for i in range(n_features):
                # _perturb_fn shuffles in place, so work on a copy to keep the
                # original observations intact for the next feature.
                copy_obs_batch = copy.deepcopy(obs_batch)
                _perturb_fn(copy_obs_batch, index=i)
                perturbed_actions, _, _ = self.policy.compute_actions(
                    copy_obs_batch, explore=False
                )
                importance[r, i] = np.mean(np.abs(perturbed_actions - ref_actions))

        # take an average across repeats
        importance = importance.mean(0)
        return {f"feature_{i}": score for i, score in enumerate(importance)}

    @override(OfflineEvaluator)
    def estimate_on_dataset(
        self, dataset: Dataset, *, n_parallelism: int = ...
    ) -> Dict[str, Any]:
        """Estimate the feature importance of the policy given a dataset.

        For each feature in the dataset, the importance is computed by applying
        perturbations to each feature and computing the difference between the
        perturbed prediction and the reference prediction. The importance
        computation for each feature and each perturbation is repeated `self.repeat`
        times. If dataset is large the user can initialize the estimator with a
        `limit_fraction` to limit the dataset to a fraction of the original dataset.

        The dataset should include a column named `obs` where each row is a vector of D
        dimensions. The importance is computed for each dimension of the vector.

        Note (Implementation detail): The computation across features are distributed
        with ray workers since each feature is independent of each other.

        Args:
            dataset: the dataset to use for feature importance.
            n_parallelism: number of parallel workers to use for feature importance.
                It is used to derive the batch sizes. When left at the placeholder
                default (`...`), the local CPU count is used.

        Returns:
            A dict mapping each feature index string (`feature_<i>`) to its
            importance, averaged across repeats.
        """
        # Local import: only needed to resolve the placeholder default below.
        import os

        # `...` is a placeholder default; without a real number the batch size
        # divisions below would fail with an opaque TypeError.
        if n_parallelism is ...:
            n_parallelism = os.cpu_count() or 1

        policy_state = self.policy.get_state()

        # step 1: limit the dataset to a few first rows
        ds = dataset.limit(int(self.limit_fraction * dataset.count()))

        # step 2: compute the reference actions
        bsize = max(1, ds.count() // n_parallelism)
        actions_ds = ds.map_batches(
            _compute_actions,
            batch_size=bsize,
            # _compute_actions works on pandas DataFrames, so request that
            # format explicitly, as the other map_batches calls in this
            # module do.
            batch_format="pandas",
            fn_kwargs={
                "output_key": "ref_actions",
                "policy_state": policy_state,
            },
        )

        # step 3: compute the feature importance
        n_features = ds.take(1)[0][SampleBatch.OBS].shape[-1]
        importance = np.zeros((self.repeat, n_features))
        for r in range(self.repeat):
            # shuffle the entire dataset
            shuffled_ds = actions_ds.random_shuffle()
            bsize_per_task = max(1, (shuffled_ds.count() * n_features) // n_parallelism)

            # for each index perturb the dataset and compute the feat importance score
            remote_fns = [
                get_feature_importance_on_index.remote(
                    dataset=shuffled_ds,
                    index=i,
                    perturb_fn=self.perturb_fn,
                    # The remote function's keyword-only parameter is named
                    # `batch_size`; any other keyword raises a TypeError.
                    batch_size=bsize_per_task,
                    policy_state=policy_state,
                )
                for i in range(n_features)
            ]
            ds_w_fi_scores = ray.get(remote_fns)
            # Mean of the per-row `delta` column is the score of each feature.
            importance[r] = np.array([d.mean("delta") for d in ds_w_fi_scores])

        # take an average across repeats
        importance = importance.mean(0)
        return {f"feature_{i}": score for i, score in enumerate(importance)}
|