dict_ops.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import itertools as it
  2. import numpy as np
  3. def merge_dicts_recursively(*dicts):
  4. """
  5. Creates a dict whose keyset is the union of all the
  6. input dictionaries. The value for each key is based
  7. on the first dict in the list with that key.
  8. dicts later in the list have higher priority
  9. When values are dictionaries, it is applied recursively
  10. """
  11. result = dict()
  12. all_items = it.chain(*[d.items() for d in dicts])
  13. for key, value in all_items:
  14. if key in result and isinstance(result[key], dict) and isinstance(value, dict):
  15. result[key] = merge_dicts_recursively(result[key], value)
  16. else:
  17. result[key] = value
  18. return result
  19. def soft_dict_update(d1, d2):
  20. """
  21. Adds key values pairs of d2 to d1 only when d1 doesn't
  22. already have that key
  23. """
  24. for key, value in list(d2.items()):
  25. if key not in d1:
  26. d1[key] = value
  27. def dict_eq(d1, d2):
  28. if len(d1) != len(d2):
  29. return False
  30. for key in d1:
  31. value1 = d1[key]
  32. value2 = d2[key]
  33. if type(value1) != type(value2):
  34. return False
  35. if type(d1[key]) == np.ndarray:
  36. if any(d1[key] != d2[key]):
  37. return False
  38. elif d1[key] != d2[key]:
  39. return False
  40. return True