# flake8: noqa
"""
This file holds code for the TF best-practices guide in the documentation.
It ignores yapf because yapf doesn't allow comments right after code blocks,
but we put comments right after code blocks to prevent large white spaces
in the documentation.
"""
# yapf: disable
# __tf_model_start__
def create_keras_model():
    from tensorflow import keras
    from tensorflow.keras import layers
    model = keras.Sequential()
    # Adds a densely-connected layer with 64 units to the model:
    model.add(layers.Dense(64, activation="relu", input_shape=(32, )))
    # Add another:
    model.add(layers.Dense(64, activation="relu"))
    # Add a softmax layer with 10 output units:
    model.add(layers.Dense(10, activation="softmax"))

    model.compile(
        optimizer=keras.optimizers.RMSprop(0.01),
        loss=keras.losses.categorical_crossentropy,
        metrics=[keras.metrics.categorical_accuracy])
    return model
# __tf_model_end__
# yapf: enable
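
# Optional sanity check, not pulled into the rendered guide: build the model
# once locally to confirm TensorFlow is installed and the layer shapes line
# up before wrapping it in a Ray actor. A minimal sketch, assuming TF 2.x
# with the bundled Keras API; the variable name is illustrative.
sanity_model = create_keras_model()
sanity_model.summary()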

# yapf: disable
# __ray_start__
import ray
import numpy as np

ray.init()


def random_one_hot_labels(shape):
    n, n_class = shape
    classes = np.random.randint(0, n_class, n)
    labels = np.zeros((n, n_class))
    labels[np.arange(n), classes] = 1
    return labels

# To run this actor on a GPU, decorate it with
# @ray.remote(num_gpus=1)
@ray.remote
class Network(object):
    def __init__(self):
        self.model = create_keras_model()
        self.dataset = np.random.random((1000, 32))
        self.labels = random_one_hot_labels((1000, 10))

    def train(self):
        history = self.model.fit(self.dataset, self.labels, verbose=False)
        return history.history

    def get_weights(self):
        return self.model.get_weights()

    def set_weights(self, weights):
        # Note that for simplicity this does not handle the optimizer state.
        self.model.set_weights(weights)
# __ray_end__
# yapf: enable
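
# A hedged variant, not part of the rendered guide: instead of editing the
# decorator above, the GPU request can be made at instantiation time via the
# standard Ray `.options()` call. Whether the actor actually schedules
# depends on a GPU being available in the cluster, so it is left commented
# out here and the name is illustrative.
# gpu_network_actor = Network.options(num_gpus=1).remote()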

# yapf: disable
# __actor_start__
NetworkActor = Network.remote()
result_object_ref = NetworkActor.train.remote()
ray.get(result_object_ref)
# __actor_end__
# yapf: enable
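
# Not part of the rendered guide: a small check that the actor's weights can
# be pulled back to the driver. ray.get on get_weights.remote() returns the
# same list of numpy arrays that Keras' get_weights() returns locally. A
# minimal sketch; the variable name is illustrative.
driver_side_weights = ray.get(NetworkActor.get_weights.remote())
print("Layers returned:", len(driver_side_weights))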

# yapf: disable
# __weight_average_start__
NetworkActor2 = Network.remote()
NetworkActor2.train.remote()
weights = ray.get(
    [NetworkActor.get_weights.remote(),
     NetworkActor2.get_weights.remote()])

averaged_weights = [(layer1 + layer2) / 2
                    for layer1, layer2 in zip(weights[0], weights[1])]

weight_id = ray.put(averaged_weights)
[
    actor.set_weights.remote(weight_id)
    for actor in [NetworkActor, NetworkActor2]
]
ray.get([actor.train.remote() for actor in [NetworkActor, NetworkActor2]])
# __weight_average_end__
# yapf: enable
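
# Not part of the rendered guide: the ray.get above already returns the Keras
# History dicts from both actors' extra training round, so nothing is left
# pending. For a standalone run, ray.shutdown() (standard Ray API) tears down
# the local cluster started by ray.init() earlier in this file.
ray.shutdown()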