action_matching.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. '''
  2. Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
  3. '''
  4. import jax
  5. import jax.numpy as jnp
  6. import numpy as np
  7. # import action_type as action_type_lib
  8. import enum
class ActionType(enum.IntEnum):
    """Integer codes identifying the kind of action in an episode step.

    Values 3-7 are agent actions and values 10-11 are episode status actions.
    The remaining values are declared only as placeholders.
    """

    # Placeholders for unused enum values.
    # NOTE(review): `pred_2_format` in this file treats the raw codes 0, 1, 8
    # and 9 as scroll predictions and rewrites them to DUAL_POINT (4).
    UNUSED_0 = 0
    UNUSED_1 = 1
    UNUSED_2 = 2
    UNUSED_8 = 8
    UNUSED_9 = 9

    ########### Agent actions ###########

    # A type action that sends text to the emulator. Note that this simply
    # sends text and does not perform any clicks for element focus or enter
    # presses for submitting text.
    TYPE = 3

    # The dual point action used to represent all gestures.
    DUAL_POINT = 4

    # These actions differentiate pressing the home and back button from
    # touches. They represent explicit presses of back and home performed
    # using ADB.
    PRESS_BACK = 5
    PRESS_HOME = 6

    # An action representing that the ADB command for hitting enter was
    # performed.
    PRESS_ENTER = 7

    ########### Episode status actions ###########

    # An action used to indicate the desired task has been completed and
    # resets the environment. This action should also be used in the case that
    # the task has already been completed and there is nothing to do.
    # e.g. The task is to turn on the Wi-Fi when it is already on.
    STATUS_TASK_COMPLETE = 10

    # An action used to indicate that the desired task is impossible to
    # complete and resets the environment. This can be a result of many
    # different things including UI changes, Android version differences, etc.
    STATUS_TASK_IMPOSSIBLE = 11
# Default for `tap_distance_threshold` in `check_actions_match`: the largest
# Euclidean distance between two taps that still counts as a match when the
# taps do not share an annotation bounding box.
_TAP_DISTANCE_THRESHOLD = 0.14  # Fraction of the screen

# Default growth factors applied to annotation bounding boxes before taps are
# tested against them (see `_resize_annotation_bounding_boxes`).
ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4
ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4

# Interval determining if an action is a tap or a swipe: a dual-point action
# whose touch-to-lift distance is at most this value is treated as a tap
# (see `is_tap_action`).
_SWIPE_DISTANCE_THRESHOLD = 0.04
  44. def _yx_in_bounding_boxes(
  45. yx, bounding_boxes
  46. ):
  47. """Check if the (y,x) point is contained in each bounding box.
  48. Args:
  49. yx: The (y, x) coordinate in pixels of the point.
  50. bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row
  51. represents a bounding box: (y_top_left, x_top_left, box_height,
  52. box_width). Note: containment is inclusive of the bounding box edges.
  53. Returns:
  54. is_inside: A 1D bool array where each element specifies if the point is
  55. contained within the respective box.
  56. """
  57. y, x = yx
  58. # `bounding_boxes` has shape (n_elements, 4); we extract each array along the
  59. # last axis into shape (n_elements, 1), then squeeze unneeded dimension.
  60. top, left, height, width = [
  61. jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1)
  62. ]
  63. # The y-axis is inverted for AndroidEnv, so bottom = top + height.
  64. bottom, right = top + height, left + width
  65. return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and(
  66. x >= left, x <= right)
  67. def _resize_annotation_bounding_boxes(
  68. annotation_positions, annotation_width_augment_fraction,
  69. annotation_height_augment_fraction):
  70. """Resize the bounding boxes by the given fractions.
  71. Args:
  72. annotation_positions: Array of shape (N, 4), where each row represents the
  73. (y, x, height, width) of the bounding boxes.
  74. annotation_width_augment_fraction: The fraction to augment the box widths,
  75. E.g., 1.4 == 240% total increase.
  76. annotation_height_augment_fraction: Same as described for width, but for box
  77. height.
  78. Returns:
  79. Resized bounding box.
  80. """
  81. height_change = (
  82. annotation_height_augment_fraction * annotation_positions[:, 2])
  83. width_change = (
  84. annotation_width_augment_fraction * annotation_positions[:, 3])
  85. # Limit bounding box positions to the screen.
  86. resized_annotations = jnp.stack([
  87. jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)),
  88. jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)),
  89. jnp.minimum(1, annotation_positions[:, 2] + height_change),
  90. jnp.minimum(1, annotation_positions[:, 3] + width_change),
  91. ],
  92. axis=1)
  93. return resized_annotations
  94. def is_tap_action(normalized_start_yx,
  95. normalized_end_yx):
  96. distance = jnp.linalg.norm(
  97. jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx))
  98. return distance <= _SWIPE_DISTANCE_THRESHOLD
  99. def _is_non_dual_point_action(action_type):
  100. return jnp.not_equal(action_type, ActionType.DUAL_POINT)
  101. def _check_tap_actions_match(
  102. tap_1_yx,
  103. tap_2_yx,
  104. annotation_positions,
  105. matching_tap_distance_threshold_screen_percentage,
  106. annotation_width_augment_fraction,
  107. annotation_height_augment_fraction,
  108. ):
  109. """Determines if two tap actions are the same."""
  110. resized_annotation_positions = _resize_annotation_bounding_boxes(
  111. annotation_positions,
  112. annotation_width_augment_fraction,
  113. annotation_height_augment_fraction,
  114. )
  115. # Check if the ground truth tap action falls in an annotation's bounding box.
  116. tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions)
  117. tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions)
  118. both_in_box = jnp.max(tap1_in_box & tap2_in_box)
  119. # If the ground-truth tap action falls outside any of the annotation
  120. # bounding boxes or one of the actions is inside a bounding box and the other
  121. # is outside bounding box or vice versa, compare the points using Euclidean
  122. # distance.
  123. within_threshold = (
  124. jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
  125. <= matching_tap_distance_threshold_screen_percentage
  126. )
  127. return jnp.logical_or(both_in_box, within_threshold)
  128. def _check_drag_actions_match(
  129. drag_1_touch_yx,
  130. drag_1_lift_yx,
  131. drag_2_touch_yx,
  132. drag_2_lift_yx,
  133. ):
  134. """Determines if two drag actions are the same."""
  135. # Store drag deltas (the change in the y and x coordinates from touch to
  136. # lift), magnitudes, and the index of the main axis, which is the axis with
  137. # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
  138. # ending at (0.3, 0.5) has a main axis index of 1).
  139. drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
  140. drag_1_magnitudes = jnp.abs(drag_1_deltas)
  141. drag_1_main_axis = np.argmax(drag_1_magnitudes)
  142. drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
  143. drag_2_magnitudes = jnp.abs(drag_2_deltas)
  144. drag_2_main_axis = np.argmax(drag_2_magnitudes)
  145. return jnp.equal(drag_1_main_axis, drag_2_main_axis)
  146. def check_actions_match(
  147. action_1_touch_yx,
  148. action_1_lift_yx,
  149. action_1_action_type,
  150. action_2_touch_yx,
  151. action_2_lift_yx,
  152. action_2_action_type,
  153. annotation_positions,
  154. tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
  155. annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
  156. annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
  157. ):
  158. """Determines if two actions are considered to be the same.
  159. Two actions being "the same" is defined here as two actions that would result
  160. in a similar screen state.
  161. Args:
  162. action_1_touch_yx: The (y, x) coordinates of the first action's touch.
  163. action_1_lift_yx: The (y, x) coordinates of the first action's lift.
  164. action_1_action_type: The action type of the first action.
  165. action_2_touch_yx: The (y, x) coordinates of the second action's touch.
  166. action_2_lift_yx: The (y, x) coordinates of the second action's lift.
  167. action_2_action_type: The action type of the second action.
  168. annotation_positions: The positions of the UI annotations for the screen. It
  169. is A 2D int array of shape (num_bboxes, 4), where each row represents a
  170. bounding box: (y_top_left, x_top_left, box_height, box_width). Note that
  171. containment is inclusive of the bounding box edges.
  172. tap_distance_threshold: The threshold that determines if two taps result in
  173. a matching screen state if they don't fall the same bounding boxes.
  174. annotation_width_augment_fraction: The fraction to increase the width of the
  175. bounding box by.
  176. annotation_height_augment_fraction: The fraction to increase the height of
  177. of the bounding box by.
  178. Returns:
  179. A boolean representing whether the two given actions are the same or not.
  180. """
  181. action_1_touch_yx = jnp.asarray(action_1_touch_yx)
  182. action_1_lift_yx = jnp.asarray(action_1_lift_yx)
  183. action_2_touch_yx = jnp.asarray(action_2_touch_yx)
  184. action_2_lift_yx = jnp.asarray(action_2_lift_yx)
  185. # Checks if at least one of the actions is global (i.e. not DUAL_POINT),
  186. # because if that is the case, only the actions' types need to be compared.
  187. has_non_dual_point_action = jnp.logical_or(
  188. _is_non_dual_point_action(action_1_action_type),
  189. _is_non_dual_point_action(action_2_action_type),
  190. )
  191. #print("non dual point: "+str(has_non_dual_point_action))
  192. different_dual_point_types = jnp.logical_xor(
  193. is_tap_action(action_1_touch_yx, action_1_lift_yx),
  194. is_tap_action(action_2_touch_yx, action_2_lift_yx),
  195. )
  196. #print("different dual type: "+str(different_dual_point_types))
  197. is_tap = jnp.logical_and(
  198. is_tap_action(action_1_touch_yx, action_1_lift_yx),
  199. is_tap_action(action_2_touch_yx, action_2_lift_yx),
  200. )
  201. #print("is tap: "+str(is_tap))
  202. taps_match = _check_tap_actions_match(
  203. action_1_touch_yx,
  204. action_2_touch_yx,
  205. annotation_positions,
  206. tap_distance_threshold,
  207. annotation_width_augment_fraction,
  208. annotation_height_augment_fraction,
  209. )
  210. #print("tap match: "+str(taps_match))
  211. taps_match = jnp.logical_and(is_tap, taps_match)
  212. #print("tap match: "+str(taps_match))
  213. drags_match = _check_drag_actions_match(
  214. action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx
  215. )
  216. drags_match = jnp.where(is_tap, False, drags_match)
  217. #print("drag match: "+str(drags_match))
  218. return jnp.where(
  219. has_non_dual_point_action,
  220. jnp.equal(action_1_action_type, action_2_action_type),
  221. jnp.where(
  222. different_dual_point_types,
  223. False,
  224. jnp.logical_or(taps_match, drags_match),
  225. ),
  226. )
  227. def action_2_format(step_data):
  228. # 把test数据集中的动作格式转换为计算matching score的格式
  229. action_type = step_data["action_type_id"]
  230. if action_type == 4:
  231. if step_data["action_type_text"] == 'click': # 点击
  232. touch_point = step_data["touch"]
  233. lift_point = step_data["lift"]
  234. else: # 上下左右滑动
  235. if step_data["action_type_text"] == 'scroll down':
  236. touch_point = [0.5, 0.8]
  237. lift_point = [0.5, 0.2]
  238. elif step_data["action_type_text"] == 'scroll up':
  239. touch_point = [0.5, 0.2]
  240. lift_point = [0.5, 0.8]
  241. elif step_data["action_type_text"] == 'scroll left':
  242. touch_point = [0.2, 0.5]
  243. lift_point = [0.8, 0.5]
  244. elif step_data["action_type_text"] == 'scroll right':
  245. touch_point = [0.8, 0.5]
  246. lift_point = [0.2, 0.5]
  247. else:
  248. touch_point = [-1.0, -1.0]
  249. lift_point = [-1.0, -1.0]
  250. if action_type == 3:
  251. typed_text = step_data["type_text"]
  252. else:
  253. typed_text = ""
  254. action = {"action_type": action_type, "touch_point": touch_point, "lift_point": lift_point,
  255. "typed_text": typed_text}
  256. action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
  257. action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
  258. action["typed_text"] = action["typed_text"].lower()
  259. return action
  260. def pred_2_format(step_data):
  261. # 把模型输出的内容转换为计算action_matching的格式
  262. action_type = step_data["action_type"]
  263. if action_type == 4: # 点击
  264. action_type_new = 4
  265. touch_point = step_data["click_point"]
  266. lift_point = step_data["click_point"]
  267. typed_text = ""
  268. elif action_type == 0:
  269. action_type_new = 4
  270. touch_point = [0.5, 0.8]
  271. lift_point = [0.5, 0.2]
  272. typed_text = ""
  273. elif action_type == 1:
  274. action_type_new = 4
  275. touch_point = [0.5, 0.2]
  276. lift_point = [0.5, 0.8]
  277. typed_text = ""
  278. elif action_type == 8:
  279. action_type_new = 4
  280. touch_point = [0.2, 0.5]
  281. lift_point = [0.8, 0.5]
  282. typed_text = ""
  283. elif action_type == 9:
  284. action_type_new = 4
  285. touch_point = [0.8, 0.5]
  286. lift_point = [0.2, 0.5]
  287. typed_text = ""
  288. else:
  289. action_type_new = action_type
  290. touch_point = [-1.0, -1.0]
  291. lift_point = [-1.0, -1.0]
  292. typed_text = ""
  293. if action_type_new == 3:
  294. typed_text = step_data["typed_text"]
  295. action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point,
  296. "typed_text": typed_text}
  297. action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
  298. action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
  299. action["typed_text"] = action["typed_text"].lower()
  300. return action
  301. def pred_2_format_simplified(step_data):
  302. # 把模型输出的内容转换为计算action_matching的格式
  303. action_type = step_data["action_type"]
  304. if action_type == 'click' : # 点击
  305. action_type_new = 4
  306. touch_point = step_data["click_point"]
  307. lift_point = step_data["click_point"]
  308. typed_text = ""
  309. elif action_type == 'scroll' and step_data["direction"] == 'down':
  310. action_type_new = 4
  311. touch_point = [0.5, 0.8]
  312. lift_point = [0.5, 0.2]
  313. typed_text = ""
  314. elif action_type == 'scroll' and step_data["direction"] == 'up':
  315. action_type_new = 4
  316. touch_point = [0.5, 0.2]
  317. lift_point = [0.5, 0.8]
  318. typed_text = ""
  319. elif action_type == 'scroll' and step_data["direction"] == 'left':
  320. action_type_new = 4
  321. touch_point = [0.2, 0.5]
  322. lift_point = [0.8, 0.5]
  323. typed_text = ""
  324. elif action_type == 'scroll' and step_data["direction"] == 'right':
  325. action_type_new = 4
  326. touch_point = [0.8, 0.5]
  327. lift_point = [0.2, 0.5]
  328. typed_text = ""
  329. elif action_type == 'type':
  330. action_type_new = 3
  331. touch_point = [-1.0, -1.0]
  332. lift_point = [-1.0, -1.0]
  333. typed_text = step_data["text"]
  334. elif action_type == 'navigate_back':
  335. action_type_new = 5
  336. touch_point = [-1.0, -1.0]
  337. lift_point = [-1.0, -1.0]
  338. typed_text = ""
  339. elif action_type == 'navigate_home':
  340. action_type_new = 6
  341. touch_point = [-1.0, -1.0]
  342. lift_point = [-1.0, -1.0]
  343. typed_text = ""
  344. else:
  345. action_type_new = action_type
  346. touch_point = [-1.0, -1.0]
  347. lift_point = [-1.0, -1.0]
  348. typed_text = ""
  349. # if action_type_new == 'type':
  350. # typed_text = step_data["text"]
  351. action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point,
  352. "typed_text": typed_text}
  353. action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
  354. action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
  355. action["typed_text"] = action["typed_text"].lower()
  356. return action