onnx_tf.py — example script (~1.6 KB): export an RLlib PPO TensorFlow policy to ONNX and verify the outputs match.
  1. import numpy as np
  2. import ray
  3. import ray.rllib.agents.ppo as ppo
  4. import onnxruntime
  5. import os
  6. import shutil
  7. # Configure our PPO trainer
  8. config = ppo.DEFAULT_CONFIG.copy()
  9. config["num_gpus"] = 0
  10. config["num_workers"] = 1
  11. config["framework"] = "tf"
  12. outdir = "export_tf"
  13. if os.path.exists(outdir):
  14. shutil.rmtree(outdir)
  15. np.random.seed(1234)
  16. # We will run inference with this test batch
  17. test_data = {
  18. "obs": np.random.uniform(0, 1., size=(10, 4)).astype(np.float32),
  19. }
  20. # Start Ray and initialize a PPO trainer
  21. ray.init()
  22. trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
  23. # You could train the model here
  24. # trainer.train()
  25. # Let's run inference on the tensorflow model
  26. policy = trainer.get_policy()
  27. result_tf, _ = policy.model(test_data)
  28. # Evaluate tensor to fetch numpy array
  29. with policy._sess.as_default():
  30. result_tf = result_tf.eval()
  31. # This line will export the model to ONNX
  32. res = trainer.export_policy_model(outdir, onnx=11)
  33. # Import ONNX model
  34. exported_model_file = os.path.join(outdir, "saved_model.onnx")
  35. # Start an inference session for the ONNX model
  36. session = onnxruntime.InferenceSession(exported_model_file, None)
  37. # Pass the same test batch to the ONNX model (rename to match tensor names)
  38. onnx_test_data = {f"default_policy/{k}:0": v for k, v in test_data.items()}
  39. result_onnx = session.run(["default_policy/model/fc_out/BiasAdd:0"],
  40. onnx_test_data)
  41. # These results should be equal!
  42. print("TENSORFLOW", result_tf)
  43. print("ONNX", result_onnx)
  44. assert np.allclose(result_tf, result_onnx), \
  45. "Model outputs are NOT equal. FAILED"
  46. print("Model outputs are equal. PASSED")