# test_dead_actors.py
#!/usr/bin/env python
# Long-running stress test: actors randomly kill themselves and are respawned.
import argparse
import json
import logging
import numpy as np
import os
import sys
import time
import ray
# Module-level logger for progress reporting during the trial loop.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Connect to an already-running Ray cluster (address="auto" requires one).
ray.init(address="auto")
  13. @ray.remote
  14. class Child(object):
  15. def __init__(self, death_probability):
  16. self.death_probability = death_probability
  17. def ping(self):
  18. # Exit process with some probability.
  19. exit_chance = np.random.rand()
  20. if exit_chance > self.death_probability:
  21. sys.exit(-1)
  22. @ray.remote
  23. class Parent(object):
  24. def __init__(self, num_children, death_probability):
  25. self.death_probability = death_probability
  26. self.children = [Child.remote(death_probability) for _ in range(num_children)]
  27. def ping(self, num_pings):
  28. children_outputs = []
  29. for _ in range(num_pings):
  30. children_outputs += [child.ping.remote() for child in self.children]
  31. try:
  32. ray.get(children_outputs)
  33. except Exception:
  34. # Replace the children if one of them died.
  35. self.__init__(len(self.children), self.death_probability)
  36. def kill(self):
  37. # Clean up children.
  38. ray.get([child.__ray_terminate__.remote() for child in self.children])
  39. def parse_script_args():
  40. parser = argparse.ArgumentParser()
  41. parser.add_argument("--num-nodes", type=int, default=100)
  42. parser.add_argument("--num-parents", type=int, default=10)
  43. parser.add_argument("--num-children", type=int, default=10)
  44. parser.add_argument("--death-probability", type=int, default=0.95)
  45. return parser.parse_known_args()
if __name__ == "__main__":
    # Driver: run 100 trials of parents pinging (possibly dying) children,
    # record per-iteration latency, and write a JSON result file.
    args, unknown = parse_script_args()
    result = {"success": 0}
    # These numbers need to correspond with the autoscaler config file.
    # The number of remote nodes in the autoscaler should upper bound
    # these because sometimes nodes fail to update.
    num_remote_nodes = args.num_nodes
    num_parents = args.num_parents
    num_children = args.num_children
    death_probability = args.death_probability
    # Wait until the expected number of nodes have joined the cluster.
    # NOTE(review): despite the comment above, this checks only once and
    # asserts — there is no polling loop. Confirm callers pre-wait.
    num_nodes = len(ray.nodes())
    assert (
        num_nodes >= num_remote_nodes + 1
    ), f"Expect {num_remote_nodes+1}, but only {num_nodes} joined."
    logger.info(
        "Nodes have all joined. There are %s resources.", ray.cluster_resources()
    )
    # One Parent actor per configured parent, each with its own child pool.
    parents = [
        Parent.remote(num_children, death_probability) for _ in range(num_parents)
    ]
    start = time.time()
    loop_times = []
    for i in range(100):
        loop_start = time.time()
        # Each parent pings all of its children 10 times this trial.
        ray.get([parent.ping.remote(10) for parent in parents])
        # Kill a parent actor with some probability.
        exit_chance = np.random.rand()
        if exit_chance > death_probability:
            parent_index = np.random.randint(len(parents))
            # Fire-and-forget kill, then immediately replace the parent.
            parents[parent_index].kill.remote()
            parents[parent_index] = Parent.remote(num_children, death_probability)
        logger.info("Finished trial %s", i)
        loop_times.append(time.time() - loop_start)
    # Summary statistics over the 100 trial iterations.
    print("Finished in: {}s".format(time.time() - start))
    print("Average iteration time: {}s".format(sum(loop_times) / len(loop_times)))
    print("Max iteration time: {}s".format(max(loop_times)))
    print("Min iteration time: {}s".format(min(loop_times)))
    result["total_time"] = time.time() - start
    result["avg_iteration_time"] = sum(loop_times) / len(loop_times)
    result["max_iteration_time"] = max(loop_times)
    result["min_iteration_time"] = min(loop_times)
    result["success"] = 1
    # Perf metrics are reported only for full (non-smoke) runs.
    if os.environ.get("IS_SMOKE_TEST") != "1":
        result["perf_metrics"] = [
            {
                "perf_metric_name": "avg_iteration_time",
                "perf_metric_value": result["avg_iteration_time"],
                "perf_metric_type": "LATENCY",
            }
        ]
    print("PASSED.")
    # Write the result JSON to the path the harness provides via env var.
    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        f.write(json.dumps(result))