utils.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. import abc
  2. import asyncio
  3. import datetime
  4. import functools
  5. import importlib
  6. import json
  7. import logging
  8. import pkgutil
  9. import socket
  10. from abc import ABCMeta, abstractmethod
  11. from base64 import b64decode
  12. from collections import namedtuple
  13. from collections.abc import MutableMapping, Mapping, Sequence
  14. import aioredis # noqa: F401
  15. import aiosignal # noqa: F401
  16. from google.protobuf.json_format import MessageToDict
  17. from frozenlist import FrozenList # noqa: F401
  18. from ray._private.utils import binary_to_hex, check_dashboard_dependencies_installed
  19. try:
  20. create_task = asyncio.create_task
  21. except AttributeError:
  22. create_task = asyncio.ensure_future
  23. logger = logging.getLogger(__name__)
  24. class FrontendNotFoundError(OSError):
  25. pass
  26. class DashboardAgentModule(abc.ABC):
  27. def __init__(self, dashboard_agent):
  28. """
  29. Initialize current module when DashboardAgent loading modules.
  30. :param dashboard_agent: The DashboardAgent instance.
  31. """
  32. self._dashboard_agent = dashboard_agent
  33. @abc.abstractmethod
  34. async def run(self, server):
  35. """
  36. Run the module in an asyncio loop. An agent module can provide
  37. servicers to the server.
  38. :param server: Asyncio GRPC server.
  39. """
  40. @staticmethod
  41. @abc.abstractclassmethod
  42. def is_minimal_module():
  43. """
  44. Return True if the module is minimal, meaning it
  45. should work with `pip install ray` that doesn't requires additonal
  46. dependencies.
  47. """
  48. class DashboardHeadModule(abc.ABC):
  49. def __init__(self, dashboard_head):
  50. """
  51. Initialize current module when DashboardHead loading modules.
  52. :param dashboard_head: The DashboardHead instance.
  53. """
  54. self._dashboard_head = dashboard_head
  55. @abc.abstractmethod
  56. async def run(self, server):
  57. """
  58. Run the module in an asyncio loop. A head module can provide
  59. servicers to the server.
  60. :param server: Asyncio GRPC server.
  61. """
  62. @staticmethod
  63. @abc.abstractclassmethod
  64. def is_minimal_module():
  65. """
  66. Return True if the module is minimal, meaning it
  67. should work with `pip install ray` that doesn't requires additonal
  68. dependencies.
  69. """
  70. def dashboard_module(enable):
  71. """A decorator for dashboard module."""
  72. def _cls_wrapper(cls):
  73. cls.__ray_dashboard_module_enable__ = enable
  74. return cls
  75. return _cls_wrapper
  76. def get_all_modules(module_type):
  77. """
  78. Get all importable modules that are subclass of a given module type.
  79. """
  80. logger.info(f"Get all modules by type: {module_type.__name__}")
  81. import ray.dashboard.modules
  82. should_only_load_minimal_modules = not check_dashboard_dependencies_installed()
  83. for module_loader, name, ispkg in pkgutil.walk_packages(
  84. ray.dashboard.modules.__path__, ray.dashboard.modules.__name__ + "."
  85. ):
  86. try:
  87. importlib.import_module(name)
  88. except ModuleNotFoundError as e:
  89. logger.info(
  90. f"Module {name} cannot be loaded because "
  91. "we cannot import all dependencies. Download "
  92. "`pip install ray[default]` for the full "
  93. f"dashboard functionality. Error: {e}"
  94. )
  95. if not should_only_load_minimal_modules:
  96. logger.info(
  97. "Although `pip install ray[default] is downloaded, "
  98. "module couldn't be imported`"
  99. )
  100. raise e
  101. imported_modules = []
  102. # module_type.__subclasses__() should contain modules that
  103. # we could successfully import.
  104. for m in module_type.__subclasses__():
  105. if not getattr(m, "__ray_dashboard_module_enable__", True):
  106. continue
  107. if should_only_load_minimal_modules and not m.is_minimal_module():
  108. continue
  109. imported_modules.append(m)
  110. logger.info(f"Available modules: {imported_modules}")
  111. return imported_modules
  112. def to_posix_time(dt):
  113. return (dt - datetime.datetime(1970, 1, 1)).total_seconds()
  114. def address_tuple(address):
  115. if isinstance(address, tuple):
  116. return address
  117. ip, port = address.split(":")
  118. return ip, int(port)
  119. class CustomEncoder(json.JSONEncoder):
  120. def default(self, obj):
  121. if isinstance(obj, bytes):
  122. return binary_to_hex(obj)
  123. if isinstance(obj, Immutable):
  124. return obj.mutable()
  125. # Let the base class default method raise the TypeError
  126. return json.JSONEncoder.default(self, obj)
  127. def to_camel_case(snake_str):
  128. """Convert a snake str to camel case."""
  129. components = snake_str.split("_")
  130. # We capitalize the first letter of each component except the first one
  131. # with the 'title' method and join them together.
  132. return components[0] + "".join(x.title() for x in components[1:])
  133. def to_google_style(d):
  134. """Recursive convert all keys in dict to google style."""
  135. new_dict = {}
  136. for k, v in d.items():
  137. if isinstance(v, dict):
  138. new_dict[to_camel_case(k)] = to_google_style(v)
  139. elif isinstance(v, list):
  140. new_list = []
  141. for i in v:
  142. if isinstance(i, dict):
  143. new_list.append(to_google_style(i))
  144. else:
  145. new_list.append(i)
  146. new_dict[to_camel_case(k)] = new_list
  147. else:
  148. new_dict[to_camel_case(k)] = v
  149. return new_dict
  150. def message_to_dict(message, decode_keys=None, **kwargs):
  151. """Convert protobuf message to Python dict."""
  152. def _decode_keys(d):
  153. for k, v in d.items():
  154. if isinstance(v, dict):
  155. d[k] = _decode_keys(v)
  156. if isinstance(v, list):
  157. new_list = []
  158. for i in v:
  159. if isinstance(i, dict):
  160. new_list.append(_decode_keys(i))
  161. else:
  162. new_list.append(i)
  163. d[k] = new_list
  164. else:
  165. if k in decode_keys:
  166. d[k] = binary_to_hex(b64decode(v))
  167. else:
  168. d[k] = v
  169. return d
  170. if decode_keys:
  171. return _decode_keys(
  172. MessageToDict(message, use_integers_for_enums=False, **kwargs)
  173. )
  174. else:
  175. return MessageToDict(message, use_integers_for_enums=False, **kwargs)
  176. class SignalManager:
  177. _signals = FrozenList()
  178. @classmethod
  179. def register(cls, sig):
  180. cls._signals.append(sig)
  181. @classmethod
  182. def freeze(cls):
  183. cls._signals.freeze()
  184. for sig in cls._signals:
  185. sig.freeze()
  186. class Signal(aiosignal.Signal):
  187. __slots__ = ()
  188. def __init__(self, owner):
  189. super().__init__(owner)
  190. SignalManager.register(self)
  191. class Bunch(dict):
  192. """A dict with attribute-access."""
  193. def __getattr__(self, key):
  194. try:
  195. return self.__getitem__(key)
  196. except KeyError:
  197. raise AttributeError(key)
  198. def __setattr__(self, key, value):
  199. self.__setitem__(key, value)
  200. class Change:
  201. """Notify change object."""
  202. def __init__(self, owner=None, old=None, new=None):
  203. self.owner = owner
  204. self.old = old
  205. self.new = new
  206. def __str__(self):
  207. return (
  208. f"Change(owner: {type(self.owner)}), " f"old: {self.old}, new: {self.new}"
  209. )
  210. class NotifyQueue:
  211. """Asyncio notify queue for Dict signal."""
  212. _queue = asyncio.Queue()
  213. @classmethod
  214. def put(cls, co):
  215. cls._queue.put_nowait(co)
  216. @classmethod
  217. async def get(cls):
  218. return await cls._queue.get()
  219. """
  220. https://docs.python.org/3/library/json.html?highlight=json#json.JSONEncoder
  221. +-------------------+---------------+
  222. | Python | JSON |
  223. +===================+===============+
  224. | dict | object |
  225. +-------------------+---------------+
  226. | list, tuple | array |
  227. +-------------------+---------------+
  228. | str | string |
  229. +-------------------+---------------+
  230. | int, float | number |
  231. +-------------------+---------------+
  232. | True | true |
  233. +-------------------+---------------+
  234. | False | false |
  235. +-------------------+---------------+
  236. | None | null |
  237. +-------------------+---------------+
  238. """
  239. _json_compatible_types = {dict, list, tuple, str, int, float, bool, type(None), bytes}
  240. def is_immutable(self):
  241. raise TypeError("%r objects are immutable" % self.__class__.__name__)
  242. def make_immutable(value, strict=True):
  243. value_type = type(value)
  244. if value_type is dict:
  245. return ImmutableDict(value)
  246. if value_type is list:
  247. return ImmutableList(value)
  248. if strict:
  249. if value_type not in _json_compatible_types:
  250. raise TypeError("Type {} can't be immutable.".format(value_type))
  251. return value
  252. class Immutable(metaclass=ABCMeta):
  253. @abstractmethod
  254. def mutable(self):
  255. pass
  256. class ImmutableList(Immutable, Sequence):
  257. """Makes a :class:`list` immutable."""
  258. __slots__ = ("_list", "_proxy")
  259. def __init__(self, list_value):
  260. if type(list_value) not in (list, ImmutableList):
  261. raise TypeError(f"{type(list_value)} object is not a list.")
  262. if isinstance(list_value, ImmutableList):
  263. list_value = list_value.mutable()
  264. self._list = list_value
  265. self._proxy = [None] * len(list_value)
  266. def __reduce_ex__(self, protocol):
  267. return type(self), (self._list,)
  268. def mutable(self):
  269. return self._list
  270. def __eq__(self, other):
  271. if isinstance(other, ImmutableList):
  272. other = other.mutable()
  273. return list.__eq__(self._list, other)
  274. def __ne__(self, other):
  275. if isinstance(other, ImmutableList):
  276. other = other.mutable()
  277. return list.__ne__(self._list, other)
  278. def __contains__(self, item):
  279. if isinstance(item, Immutable):
  280. item = item.mutable()
  281. return list.__contains__(self._list, item)
  282. def __getitem__(self, item):
  283. proxy = self._proxy[item]
  284. if proxy is None:
  285. proxy = self._proxy[item] = make_immutable(self._list[item])
  286. return proxy
  287. def __len__(self):
  288. return len(self._list)
  289. def __repr__(self):
  290. return "%s(%s)" % (self.__class__.__name__, list.__repr__(self._list))
  291. class ImmutableDict(Immutable, Mapping):
  292. """Makes a :class:`dict` immutable."""
  293. __slots__ = ("_dict", "_proxy")
  294. def __init__(self, dict_value):
  295. if type(dict_value) not in (dict, ImmutableDict):
  296. raise TypeError(f"{type(dict_value)} object is not a dict.")
  297. if isinstance(dict_value, ImmutableDict):
  298. dict_value = dict_value.mutable()
  299. self._dict = dict_value
  300. self._proxy = {}
  301. def __reduce_ex__(self, protocol):
  302. return type(self), (self._dict,)
  303. def mutable(self):
  304. return self._dict
  305. def get(self, key, default=None):
  306. try:
  307. return self[key]
  308. except KeyError:
  309. return make_immutable(default)
  310. def __eq__(self, other):
  311. if isinstance(other, ImmutableDict):
  312. other = other.mutable()
  313. return dict.__eq__(self._dict, other)
  314. def __ne__(self, other):
  315. if isinstance(other, ImmutableDict):
  316. other = other.mutable()
  317. return dict.__ne__(self._dict, other)
  318. def __contains__(self, item):
  319. if isinstance(item, Immutable):
  320. item = item.mutable()
  321. return dict.__contains__(self._dict, item)
  322. def __getitem__(self, item):
  323. proxy = self._proxy.get(item, None)
  324. if proxy is None:
  325. proxy = self._proxy[item] = make_immutable(self._dict[item])
  326. return proxy
  327. def __len__(self) -> int:
  328. return len(self._dict)
  329. def __iter__(self):
  330. if len(self._proxy) != len(self._dict):
  331. for key in self._dict.keys() - self._proxy.keys():
  332. self._proxy[key] = make_immutable(self._dict[key])
  333. return iter(self._proxy)
  334. def __repr__(self):
  335. return "%s(%s)" % (self.__class__.__name__, dict.__repr__(self._dict))
  336. class Dict(ImmutableDict, MutableMapping):
  337. """A simple descriptor for dict type to notify data changes.
  338. :note: Only the first level data report change.
  339. """
  340. ChangeItem = namedtuple("DictChangeItem", ["key", "value"])
  341. def __init__(self, *args, **kwargs):
  342. super().__init__(dict(*args, **kwargs))
  343. self.signal = Signal(self)
  344. def __setitem__(self, key, value):
  345. old = self._dict.pop(key, None)
  346. self._proxy.pop(key, None)
  347. self._dict[key] = value
  348. if len(self.signal) and old != value:
  349. if old is None:
  350. co = self.signal.send(
  351. Change(owner=self, new=Dict.ChangeItem(key, value))
  352. )
  353. else:
  354. co = self.signal.send(
  355. Change(
  356. owner=self,
  357. old=Dict.ChangeItem(key, old),
  358. new=Dict.ChangeItem(key, value),
  359. )
  360. )
  361. NotifyQueue.put(co)
  362. def __delitem__(self, key):
  363. old = self._dict.pop(key, None)
  364. self._proxy.pop(key, None)
  365. if len(self.signal) and old is not None:
  366. co = self.signal.send(Change(owner=self, old=Dict.ChangeItem(key, old)))
  367. NotifyQueue.put(co)
  368. def reset(self, d):
  369. assert isinstance(d, Mapping)
  370. for key in self._dict.keys() - d.keys():
  371. del self[key]
  372. for key, value in d.items():
  373. self[key] = value
# Register immutable types: let make_immutable(strict=True) accept the direct
# subclasses of Immutable defined above (ImmutableList, ImmutableDict) as-is.
# NOTE(review): __subclasses__() lists only *direct* subclasses, so Dict
# (a subclass of ImmutableDict) is not added. Confirm that is intended.
for immutable_type in Immutable.__subclasses__():
    _json_compatible_types.add(immutable_type)
  377. def async_loop_forever(interval_seconds, cancellable=False):
  378. def _wrapper(coro):
  379. @functools.wraps(coro)
  380. async def _looper(*args, **kwargs):
  381. while True:
  382. try:
  383. await coro(*args, **kwargs)
  384. except asyncio.CancelledError as ex:
  385. if cancellable:
  386. logger.info(
  387. f"An async loop forever coroutine " f"is cancelled {coro}."
  388. )
  389. raise ex
  390. else:
  391. logger.exception(
  392. f"Can not cancel the async loop "
  393. f"forever coroutine {coro}."
  394. )
  395. except Exception:
  396. logger.exception(f"Error looping coroutine {coro}.")
  397. await asyncio.sleep(interval_seconds)
  398. return _looper
  399. return _wrapper
  400. async def get_aioredis_client(
  401. redis_address, redis_password, retry_interval_seconds, retry_times
  402. ):
  403. for x in range(retry_times):
  404. try:
  405. return await aioredis.create_redis_pool(
  406. address=redis_address, password=redis_password
  407. )
  408. except (socket.gaierror, ConnectionError) as ex:
  409. logger.error("Connect to Redis failed: %s, retry...", ex)
  410. await asyncio.sleep(retry_interval_seconds)
  411. # Raise exception from create_redis_pool
  412. return await aioredis.create_redis_pool(
  413. address=redis_address, password=redis_password
  414. )