widget.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. # coding: utf-8
  2. #
  3. # DEPRECATED
  4. #
  5. # This file is deprecated and will be removed in the future.
  6. import logging
  7. import re
  8. import time
  9. from collections import defaultdict, namedtuple
  10. from functools import partial
  11. from pprint import pprint
  12. from typing import Union
  13. import requests
  14. from lxml import etree
  15. import uiautomator2 as u2
  16. from uiautomator2.image import compare_ssim, draw_point, imread
  17. logger = logging.getLogger(__name__)
  def xml2nodes(xml_content: Union[str, bytes]):
      """Flatten an XML hierarchy into a list of node signature strings.

      Every element that carries at least one of ``text``, ``resource-id`` or
      ``content-desc`` yields one ``"key:value|key:value"`` entry (keys
      sorted).  ``bounds`` is replaced by a ``size`` pair and ``index`` is
      dropped, so the signature does not depend on absolute position.
      Elements whose ``bounds`` do not contain exactly four integers are
      skipped entirely.
      """
      raw = xml_content.encode("utf-8") if isinstance(xml_content, str) else xml_content
      tree_root = etree.fromstring(raw)
      signatures = []
      for _event, element in etree.iterwalk(tree_root):
          fields = dict(element.attrib)
          if "bounds" in fields:
              numbers = re.findall(r"(\d+)", fields.pop("bounds"))
              if len(numbers) != 4:
                  continue
              left, top, right, bottom = map(int, numbers)
              fields['size'] = (right - left, bottom - top)
          fields.pop("index", None)
          if not any(name in fields for name in ("text", "resource-id", "content-desc")):
              continue
          signatures.append('|'.join(
              key + ":" + str(value) for key, value in sorted(fields.items())))
      return signatures
  def hierarchy_sim(xml1: str, xml2: str):
      """Return a Dice-style similarity in [0, 1] between two UI hierarchies.

      Both documents are flattened with ``xml2nodes``; the result is
      ``2 * |common| / (|ns1| + |ns2|)`` where ``common`` is the multiset
      intersection of the two signature lists.

      Returns 0.0 when neither hierarchy yields any signature (the ratio is
      undefined there and used to raise ZeroDivisionError).
      """
      from collections import Counter

      ns1 = xml2nodes(xml1)
      ns2 = xml2nodes(xml2)
      # Counter "&" keeps min(count1, count2) per shared key: multiset intersection.
      same_count = sum((Counter(ns1) & Counter(ns2)).values())
      logger.debug("Same count: %d ns1: %d ns2: %d", same_count, len(ns1), len(ns2))
      total = len(ns1) + len(ns2)
      if total == 0:
          return 0.0
      return same_count / total * 2
  def read_file_content(filename: str) -> bytes:
      """Return the entire content of *filename* as raw bytes."""
      with open(filename, mode="rb") as stream:
          content = stream.read()
      return content
  def safe_xmlstr(s):
      """Make *s* usable as an XML tag name by turning every ``$`` into ``-``.

      ``$`` (found e.g. in nested Java class names such as ``Foo$Bar``) is not
      a legal XML name character.
      """
      return "-".join(s.split("$"))
  def frozendict(d: dict):
      """Serialize *d* into a canonical ``"k1:v1|k2:v2"`` string, keys sorted.

      Keys must be ``str`` (they are concatenated directly); values are passed
      through ``str()``.
      """
      ordered = sorted(d.items())
      parts = [key + ":" + str(value) for key, value in ordered]
      return "|".join(parts)
  # Outcome of comparing two nodes: total ``score`` plus per-criterion ``detail`` mapping.
  CompareResult = namedtuple("CompareResult", "score detail")
  # Screen coordinate (x, y).
  Point = namedtuple("Point", "x y")
  class Widget(object):
      """Locate UI widgets on a uiautomator2 device by fuzzy hierarchy matching.

      A widget definition is fetched from a widget server (see ``_id2url``) and
      holds a reference hierarchy dump, an xpath to the target node inside it,
      the window size of that dump and the expected activity.  At runtime every
      node of the current hierarchy is scored against the target node and the
      best candidate is returned.
      """

      # Aliases usable as the host part of a widget id, e.g. "lo/00019".
      __domains = {
          "lo": "http://localhost:17310",
      }

      # Return type of ``match``; built once instead of on every call.
      MatchResult = namedtuple(
          "MatchResult",
          ["point", "score", "detail", "xpath", "node", "next_result"])

      def __init__(self, d: "u2.Device"):
          self._d = d
          # widget id -> widget definition fetched from the server
          self._widgets = {}
          # (node_a, node_b, size_a, size_b) -> CompareResult; reset by ``match``
          self._compare_results = {}
          self.popups = []

      @property
      def wait_timeout(self):
          """Default ``wait`` timeout, read from the device's ``settings``."""
          return self._d.settings['wait_timeout']

      def _get_widget(self, id: str):
          """Fetch (and memoize) the widget definition for *id* from its server.

          Raises:
              requests.HTTPError: if the server answers with an error status.
          """
          if id in self._widgets:
              return self._widgets[id]
          r = requests.get(self._id2url(id), timeout=3)
          # Fail loudly instead of caching an error payload as a widget.
          r.raise_for_status()
          data = r.json()
          self._widgets[id] = data
          return data

      def _id2url(self, id: str):
          """Convert a widget id into its REST URL.

          Anything after ``#`` is a human-readable label and is ignored.
          Accepted forms:

          - ``"00019#label"``: default server (the ``lo`` alias)
          - ``"lo/00019#label"``: server registered under the ``lo`` alias
          - ``"host:port/00019"``: unknown aliases are used verbatim as the
            host, prefixed with ``http://``

          Raises:
              ValueError: if the id contains more than one ``/``.
          """
          # remove chars after '#' and split host and id
          fields = re.sub("#.*", "", id).split("/")
          if len(fields) > 2:
              raise ValueError(f"invalid widget id: {id!r}")
          if len(fields) == 1:
              host, widget_id = self.__domains["lo"], fields[0]
          else:
              alias, widget_id = fields
              host = self.__domains.get(alias, alias)
          if not re.match("^https?://", host):
              host = "http://" + host
          return f"{host}/api/v1/widgets/{widget_id}"

      def _eq(self, precision: float, a, b):
          """Return True if *a* and *b* differ by less than *precision*."""
          return abs(a - b) < precision

      def _percent_equal(self, precision: float, a, b, asize, bsize):
          """Compare two lengths after dividing each by the shorter side of its window.

          *asize* / *bsize* are the (width, height) of the windows *a* / *b* come from.
          """
          return abs(a / min(asize) - b / min(bsize)) < precision

      def _bounds2rect(self, bounds: str):
          """Parse a ``"[lx,ly][rx,ry]"`` bounds string.

          Returns:
              tuple: (lx, ly, width, height); all zeros for empty/None bounds
          """
          if not bounds:
              return 0, 0, 0, 0
          lx, ly, rx, ry = map(int, re.findall(r"\d+", bounds))
          return (lx, ly, rx - lx, ry - ly)

      def _compare_node(self, node_a, node_b, size_a, size_b) -> "CompareResult":
          """Score how similar two nodes are (maximum 6).

          Criteria recorded in ``detail``: same tag (1); equal ``text``,
          ``resource-id`` and ``content-desc`` (1 each, or 0.1 when equal but
          empty/absent); left-top position (1) and size (1), each compared
          relative to the window size within 1/20.

          Args:
              node_a, node_b: etree.Element
              size_a, size_b: (width, height) of the window each node comes from
          Returns:
              CompareResult(score, detail)
          """
          # Sizes are part of the key so a cached result is never reused for
          # other window sizes; tuple() because JSON window sizes are lists.
          result_key = (node_a, node_b, tuple(size_a), tuple(size_b))
          cached = self._compare_results.get(result_key)
          if cached is not None:
              return cached

          scores = {}
          # max 1
          if node_a.tag == node_b.tag:
              scores['class'] = 1
          # max 3
          for key in ('text', 'resource-id', 'content-desc'):
              if node_a.attrib.get(key) == node_b.attrib.get(key):
                  scores[key] = 1 if node_a.attrib.get(key) else 0.1
          ax, ay, aw, ah = self._bounds2rect(node_a.attrib.get("bounds"))
          bx, by, bw, bh = self._bounds2rect(node_b.attrib.get("bounds"))
          # max 2
          peq = partial(self._percent_equal, 1 / 20, asize=size_a, bsize=size_b)
          if peq(ax, bx) and peq(ay, by):
              scores['left_top'] = 1
          if peq(aw, bw) and peq(ah, bh):
              scores['size'] = 1
          score = round(sum(scores.values()), 1)
          result = self._compare_results[result_key] = CompareResult(score, scores)
          return result

      def node2string(self, node: "etree._Element"):
          """Return a ``tag:text|resource-id|content-desc`` label for *node*."""
          return node.tag + ":" + '|'.join([
              node.attrib.get(key, "")
              for key in ["text", "resource-id", "content-desc"]
          ])

      def hybird_compare_node(self, node_a, node_b, size_a, size_b):
          """Compare two nodes together with their parents and siblings.

          Returns:
              list: ``[node_result, parent_result]`` when the parents have a
              different number of children, otherwise
              ``[node_result, parent_result, [sibling_result, ...]]`` where the
              remaining siblings are paired by position.  Each ``*_result`` is
              a CompareResult, e.g.
              ``[CompareResult(3.0, ...), CompareResult(3.2, ...), [...]]``.
          """
          cmp_node = partial(self._compare_node, size_a=size_a, size_b=size_b)
          parent_a, parent_b = node_a.getparent(), node_b.getparent()
          results = [cmp_node(node_a, node_b), cmp_node(parent_a, parent_b)]
          # list(parent) replaces the deprecated getchildren(); both build a new list.
          a_children = list(parent_a)
          b_children = list(parent_b)
          if len(a_children) != len(b_children):
              return results
          a_children.remove(node_a)
          b_children.remove(node_b)
          results.append([cmp_node(a, b) for a, b in zip(a_children, b_children)])
          return results

      def _hybird_result_to_score(self, obj: Union[list, "CompareResult"]):
          """Replace every CompareResult in a (nested) hybird_compare_node result by its score."""
          if isinstance(obj, CompareResult):
              return obj.score
          return [self._hybird_result_to_score(item) for item in obj]

      def replace_etree_node_to_class(self, root: "etree._Element"):
          """Rename every ``node`` element to its (XML-safe) ``class`` attribute, in place."""
          for node in root.xpath("//node"):
              node.tag = safe_xmlstr(node.attrib.pop("class", "") or "node")
          return root

      def compare_hierarchy(self, node, root, node_wsize, root_wsize):
          """Compare *node* with every element below ``/hierarchy`` in *root*.

          Returns:
              dict: element of *root* -> hybird_compare_node result
          """
          return {
              node2: self.hybird_compare_node(node, node2, node_wsize, root_wsize)
              for node2 in root.xpath("/hierarchy//*")
          }

      def etree_fromstring(self, s: str):
          """Parse a hierarchy dump and rename its ``node`` tags to class names."""
          root = etree.fromstring(s.encode('utf-8'))
          return self.replace_etree_node_to_class(root)

      def node_center_point(self, node) -> "Point":
          """Return the center of *node*'s ``bounds`` attribute."""
          lx, ly, rx, ry = map(int, re.findall(r"\d+",
                                               node.attrib.get("bounds")))
          return Point((lx + rx) // 2, (ly + ry) // 2)

      def match(self, widget: dict, hierarchy=None, window_size: tuple = None):
          """Find the node of the current page that best matches *widget*.

          Args:
              widget (dict): widget definition with at least ``hierarchy``,
                  ``xpath`` and ``window_size`` keys
              hierarchy (optional): current page hierarchy; dumped from the
                  device when omitted
              window_size (tuple): width and height; read from the device when
                  omitted
          Returns:
              None if the current hierarchy has no candidate node, otherwise
              MatchResult(point, score, detail, xpath, node, next_result) for the
              best candidate, with ``next_result`` the runner-up (or None).
          """
          window_size = window_size or self._d.window_size()
          hierarchy = hierarchy or self._d.dump_hierarchy()
          # Cached comparisons hold nodes of earlier dumps; drop them so the
          # cache does not grow without bound across repeated matches.
          self._compare_results.clear()

          widget_root = self.etree_fromstring(widget['hierarchy'])
          widget_node = widget_root.xpath(widget['xpath'])[0]
          # Score every node of the current page against the recorded node.
          target_root = self.etree_fromstring(hierarchy)
          results = self.compare_hierarchy(widget_node, target_root, widget['window_size'], window_size)  # yapf: disable
          # Nested score lists, e.g. [3.2, 2.2, [1.0, 1.2]]; lists compare
          # lexicographically, so the node's own score dominates the ranking.
          scores = {node: self._hybird_result_to_score(result) for node, result in results.items()}
          nodes = sorted(scores, key=lambda n: scores[n], reverse=True)
          # NOTE: an SSIM image re-ranking of the top candidates was prototyped
          # here but is disabled.
          if not nodes:
              return None

          def get_result(node, next_result=None):
              point = self.node_center_point(node)
              xpath = node.getroottree().getpath(node)
              return self.MatchResult(point, scores[node], results[node], xpath,
                                      node, next_result)

          runner_up = get_result(nodes[1]) if len(nodes) > 1 else None
          return get_result(nodes[0], runner_up)

      def exists(self, id: str) -> bool:
          """Not implemented yet: currently does nothing and returns None."""
          # TODO: implement, e.g. on top of ``wait`` with a short timeout.
          pass

      def update_widget(self, id, hierarchy, xpath):
          """PUT a new reference *hierarchy*/*xpath* for widget *id* and print the reply."""
          url = self._id2url(id)
          # Same timeout as _get_widget: never block forever on the widget server.
          r = requests.put(url, json={"hierarchy": hierarchy, "xpath": xpath}, timeout=3)
          # The memoized definition is now stale; refetch it on next use.
          self._widgets.pop(id, None)
          print(r.json())

      def wait(self, id: str, timeout=None):
          """Wait until widget *id* can be located on the current page.

          The page is accepted when the foreground activity equals the
          widget's and the hierarchy similarity exceeds 0.7 (0.6 once more than
          10 seconds have passed).  A match needs a node score of at least 2.
          When the page similarity is below 0.8 the widget's reference on the
          server is updated with the current hierarchy.

          Args:
              id (str): widget id, see ``_id2url``
              timeout (float): seconds to wait; falls back to ``wait_timeout``
          Returns:
              None if nothing matched before the deadline, otherwise a MatchResult
          """
          timeout = timeout or self.wait_timeout
          widget = self._get_widget(id)  # fetch the widget definition
          begin_time = time.time()
          deadline = begin_time + timeout
          while time.time() < deadline:
              hierarchy = self._d.dump_hierarchy()
              hsim = hierarchy_sim(hierarchy, widget['hierarchy'])
              app = self._d.app_current()
              is_same_activity = widget['activity'] == app['activity']
              if not is_same_activity:
                  print("activity different:", "got", app['activity'], 'expect', widget['activity'])
              # hsim is a ratio in [0, 1]; scale it for the percentage display.
              print("hierarchy: %.1f%%" % (hsim * 100))
              print("----------------------")
              window_size = self._d.window_size()
              page_ok = is_same_activity and (
                  hsim > 0.7 or (time.time() - begin_time > 10.0 and hsim > 0.6))
              if page_ok:
                  result = self.match(widget, hierarchy, window_size)
                  if result is None or result.score[0] < 2:
                      time.sleep(0.5)
                      continue
                  if hsim < 0.8:
                      self.update_widget(id, hierarchy, result.xpath)
                  return result
              time.sleep(1.0)
          return None

      def click(self, id: str, debug: bool = False, timeout=10):
          """Wait for widget *id* and click the center of the matched node.

          Args:
              debug: first show a screenshot with the click position marked
              timeout: seconds passed on to ``wait``
          Raises:
              RuntimeError: if the widget is not found before *timeout*
          """
          print("Click", id)
          result = self.wait(id, timeout=timeout)
          if result is None:
              raise RuntimeError("target not found")
          x, y = result.point
          if debug:
              show_click_position(self._d, Point(x, y))
          self._d.click(x, y)
  def show_click_position(d: u2.Device, point: Point):
      """Screenshot device *d*, mark *point* on the image and display it."""
      screenshot = d.screenshot()
      marked = draw_point(screenshot, point.x, point.y)
      marked.show()
  def main(serial: str = "30.10.93.26"):
      """Manual smoke test: connect to the device at *serial* and probe one widget id.

      Args:
          serial: address passed to ``u2.connect`` (defaults to the previously
              hard-coded development device)
      """
      d = u2.connect(serial)
      # NOTE(review): assumes ``d.widget`` is a ``Widget``; its ``exists`` is
      # currently an unimplemented stub, so this call has no effect yet.
      d.widget.exists("lo/00019#播放全部")


  if __name__ == "__main__":
      main()