attention_net_supervised.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from gym.spaces import Box, Discrete
  2. import numpy as np
  3. from rllib.models.tf.attention_net import TrXLNet
  4. from ray.rllib.utils.framework import try_import_tf
  5. tf1, tf, tfv = try_import_tf()
  def bit_shift_generator(seq_length, shift, batch_size):
      """Yield endless (sequence, targets) batches for the bit-shift task.

      Every batch holds random sequences of 0./1. values together with the
      same sequences delayed by `shift` steps along the time axis.

      Args:
          seq_length: Number of time steps in every sequence.
          shift: Number of steps by which the targets lag behind the inputs.
          batch_size: Number of sequences per batch.

      Yields:
          Tuple `(seq, targets)`: `seq` is a float32 array of shape
          (batch_size, seq_length, 1); `targets` is an int32 array of shape
          (batch_size, seq_length) whose first `shift` steps are set to 0.
      """
      # Loop-invariant: the two symbols a sequence may contain.
      values = np.array([0., 1.], dtype=np.float32)
      while True:
          seq = np.random.choice(values, (batch_size, seq_length, 1))
          # Drop only the trailing feature axis. A bare np.squeeze would also
          # remove the batch (or time) axis whenever its size is 1, which
          # breaks the 2-D indexing below.
          targets = np.squeeze(
              np.roll(seq, shift, axis=1).astype(np.int32), axis=-1)
          # np.roll wraps around, so the first `shift` steps would hold values
          # taken from the end of the sequence; blank them out instead.
          targets[:, :shift] = 0
          yield seq, targets
  def train_loss(targets, outputs):
      """Return the mean sparse softmax cross-entropy of `outputs`.

      Args:
          targets: Integer class labels, passed as `labels`.
          outputs: Unnormalized class scores, passed as `logits`.

      Returns:
          The cross-entropy averaged over all elements.
      """
      per_element = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=outputs, labels=targets)
      mean_loss = tf.reduce_mean(per_element)
      return mean_loss
  def train_bit_shift(seq_length, num_iterations, print_every_n):
      """Train a TrXLNet on the bit-shift task, printing the test loss.

      Args:
          seq_length: Length of every generated sequence; also handed to the
              model as its "max_seq_len" config value.
          num_iterations: Number of training updates to run.
          print_every_n: The test loss is printed on every iteration whose
              index is a multiple of this value (including iteration 0).
      """
      optimizer = tf.keras.optimizers.Adam(1e-3)

      model = TrXLNet(
          observation_space=Box(low=0, high=1, shape=(1, ), dtype=np.int32),
          action_space=Discrete(2),
          num_outputs=2,
          model_config={"max_seq_len": seq_length},
          name="trxl",
          num_transformer_units=1,
          attention_dim=10,
          num_heads=5,
          head_dim=20,
          position_wise_mlp_dim=20,
      )

      shift = 10
      train_batch = 10
      test_batch = 100
      data_gen = bit_shift_generator(
          seq_length, shift=shift, batch_size=train_batch)
      test_gen = bit_shift_generator(
          seq_length, shift=shift, batch_size=test_batch)

      @tf.function
      def update_step(inputs, targets):
          """Run one optimizer step on a flattened training batch."""

          def loss_fn():
              # The forward pass must happen inside the loss callable:
              # `minimize` evaluates the callable under its own GradientTape,
              # and anything computed beforehand is not recorded, leaving the
              # model variables without gradients.
              model_out = model(
                  {
                      "obs": inputs
                  },
                  state=[tf.reshape(inputs, [-1, seq_length, 1])],
                  seq_lens=np.full(
                      shape=(train_batch, ), fill_value=seq_length))
              return train_loss(targets, model_out)

          optimizer.minimize(loss_fn, lambda: model.trainable_variables)

      # zip() with range() caps the endless generator at `num_iterations`.
      for i, (inputs, targets) in zip(range(num_iterations), data_gen):
          # Flatten batch and time into one leading axis for the update step.
          inputs_in = np.reshape(inputs, [-1, 1])
          targets_in = np.reshape(targets, [-1])
          update_step(
              tf.convert_to_tensor(inputs_in), tf.convert_to_tensor(targets_in))

          if i % print_every_n == 0:
              test_inputs, test_targets = next(test_gen)
              # NOTE(review): unlike the training call, the model is invoked
              # here with the raw array and without `state`/`seq_lens` --
              # confirm that TrXLNet accepts this calling convention.
              print(i, train_loss(test_targets, model(test_inputs)))
  if __name__ == "__main__":
      # Go through the `tf1` handle returned by try_import_tf():
      # `enable_eager_execution` is not part of the top-level TF2 namespace,
      # so calling it on `tf` fails when TensorFlow 2.x is installed.
      # NOTE(review): assumes `tf1` is `tf.compat.v1` under TF2 and plain
      # `tf` under TF1 -- confirm against try_import_tf().
      tf1.enable_eager_execution()
      train_bit_shift(
          seq_length=20,
          num_iterations=2000,
          print_every_n=200,
      )