model_replay.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. #!/usr/bin/env python3
  2. import os
  3. import sys
  4. import time
  5. from collections import defaultdict
  6. from typing import Any
  7. import cereal.messaging as messaging
  8. from openpilot.common.params import Params
  9. from openpilot.system.hardware import PC
  10. from openpilot.selfdrive.manager.process_config import managed_processes
  11. from openpilot.tools.lib.openpilotci import BASE_URL, get_url
  12. from openpilot.selfdrive.test.process_replay.compare_logs import compare_logs, format_diff
  13. from openpilot.selfdrive.test.process_replay.process_replay import get_process_config, replay_process
  14. from openpilot.system.version import get_commit
  15. from openpilot.tools.lib.framereader import FrameReader
  16. from openpilot.tools.lib.logreader import LogReader
  17. from openpilot.tools.lib.helpers import save_log
  18. TEST_ROUTE = "2f4452b03ccb98f0|2022-12-03--13-45-30"
  19. SEGMENT = 6
  20. MAX_FRAMES = 100 if PC else 600
  21. NAV_FRAMES = 50
  22. NO_NAV = "NO_NAV" in os.environ
  23. NO_MODEL = "NO_MODEL" in os.environ
  24. SEND_EXTRA_INPUTS = bool(int(os.getenv("SEND_EXTRA_INPUTS", "0")))
  25. def get_log_fn(ref_commit, test_route):
  26. return f"{test_route}_model_tici_{ref_commit}.bz2"
  27. def trim_logs_to_max_frames(logs, max_frames, frs_types, include_all_types):
  28. all_msgs = []
  29. cam_state_counts = defaultdict(int)
  30. # keep adding messages until cam states are equal to MAX_FRAMES
  31. for msg in sorted(logs, key=lambda m: m.logMonoTime):
  32. all_msgs.append(msg)
  33. if msg.which() in frs_types:
  34. cam_state_counts[msg.which()] += 1
  35. if all(cam_state_counts[state] == max_frames for state in frs_types):
  36. break
  37. if len(include_all_types) != 0:
  38. other_msgs = [m for m in logs if m.which() in include_all_types]
  39. all_msgs.extend(other_msgs)
  40. return all_msgs
  41. def nav_model_replay(lr):
  42. sm = messaging.SubMaster(['navModel', 'navThumbnail', 'mapRenderState'])
  43. pm = messaging.PubMaster(['liveLocationKalman', 'navRoute'])
  44. nav = [m for m in lr if m.which() == 'navRoute']
  45. llk = [m for m in lr if m.which() == 'liveLocationKalman']
  46. assert len(nav) > 0 and len(llk) >= NAV_FRAMES and nav[0].logMonoTime < llk[-NAV_FRAMES].logMonoTime
  47. log_msgs = []
  48. try:
  49. assert "MAPBOX_TOKEN" in os.environ
  50. os.environ['MAP_RENDER_TEST_MODE'] = '1'
  51. Params().put_bool('DmModelInitialized', True)
  52. managed_processes['mapsd'].start()
  53. managed_processes['navmodeld'].start()
  54. # setup position and route
  55. for _ in range(10):
  56. for s in (llk[-NAV_FRAMES], nav[0]):
  57. pm.send(s.which(), s.as_builder().to_bytes())
  58. sm.update(1000)
  59. if sm.updated['navModel']:
  60. break
  61. time.sleep(1)
  62. if not sm.updated['navModel']:
  63. raise Exception("no navmodeld outputs, failed to initialize")
  64. # drain
  65. time.sleep(2)
  66. sm.update(0)
  67. # run replay
  68. for n in range(len(llk) - NAV_FRAMES, len(llk)):
  69. pm.send(llk[n].which(), llk[n].as_builder().to_bytes())
  70. m = messaging.recv_one(sm.sock['navThumbnail'])
  71. assert m is not None, f"no navThumbnail, frame={n}"
  72. log_msgs.append(m)
  73. m = messaging.recv_one(sm.sock['mapRenderState'])
  74. assert m is not None, f"no mapRenderState, frame={n}"
  75. log_msgs.append(m)
  76. m = messaging.recv_one(sm.sock['navModel'])
  77. assert m is not None, f"no navModel response, frame={n}"
  78. log_msgs.append(m)
  79. finally:
  80. managed_processes['mapsd'].stop()
  81. managed_processes['navmodeld'].stop()
  82. return log_msgs
  83. def model_replay(lr, frs):
  84. # modeld is using frame pairs
  85. modeld_logs = trim_logs_to_max_frames(lr, MAX_FRAMES, {"roadCameraState", "wideRoadCameraState"}, {"roadEncodeIdx", "wideRoadEncodeIdx", "carParams"})
  86. dmodeld_logs = trim_logs_to_max_frames(lr, MAX_FRAMES, {"driverCameraState"}, {"driverEncodeIdx", "carParams"})
  87. if not SEND_EXTRA_INPUTS:
  88. modeld_logs = [msg for msg in modeld_logs if msg.which() not in ["liveCalibration",]]
  89. dmodeld_logs = [msg for msg in dmodeld_logs if msg.which() not in ["liveCalibration",]]
  90. # initial calibration
  91. cal_msg = next(msg for msg in lr if msg.which() == "liveCalibration").as_builder()
  92. cal_msg.logMonoTime = lr[0].logMonoTime
  93. modeld_logs.insert(0, cal_msg.as_reader())
  94. dmodeld_logs.insert(0, cal_msg.as_reader())
  95. modeld = get_process_config("modeld")
  96. dmonitoringmodeld = get_process_config("dmonitoringmodeld")
  97. modeld_msgs = replay_process(modeld, modeld_logs, frs)
  98. dmonitoringmodeld_msgs = replay_process(dmonitoringmodeld, dmodeld_logs, frs)
  99. return modeld_msgs + dmonitoringmodeld_msgs
  100. if __name__ == "__main__":
  101. update = "--update" in sys.argv
  102. replay_dir = os.path.dirname(os.path.abspath(__file__))
  103. ref_commit_fn = os.path.join(replay_dir, "model_replay_ref_commit")
  104. # load logs
  105. lr = list(LogReader(get_url(TEST_ROUTE, SEGMENT)))
  106. frs = {
  107. 'roadCameraState': FrameReader(get_url(TEST_ROUTE, SEGMENT, log_type="fcamera"), readahead=True),
  108. 'driverCameraState': FrameReader(get_url(TEST_ROUTE, SEGMENT, log_type="dcamera"), readahead=True),
  109. 'wideRoadCameraState': FrameReader(get_url(TEST_ROUTE, SEGMENT, log_type="ecamera"), readahead=True)
  110. }
  111. # Update tile refs
  112. if update:
  113. import urllib
  114. import requests
  115. import threading
  116. import http.server
  117. from openpilot.tools.lib.openpilotci import upload_bytes
  118. os.environ['MAPS_HOST'] = 'http://localhost:5000'
  119. class HTTPRequestHandler(http.server.BaseHTTPRequestHandler):
  120. def do_GET(self):
  121. assert len(self.path) > 10 # Sanity check on path length
  122. r = requests.get(f'https://api.mapbox.com{self.path}', timeout=30)
  123. upload_bytes(r.content, urllib.parse.urlparse(self.path).path.lstrip('/'))
  124. self.send_response(r.status_code)
  125. self.send_header('Content-type','text/html')
  126. self.end_headers()
  127. self.wfile.write(r.content)
  128. server = http.server.HTTPServer(("127.0.0.1", 5000), HTTPRequestHandler)
  129. thread = threading.Thread(None, server.serve_forever, daemon=True)
  130. thread.start()
  131. else:
  132. os.environ['MAPS_HOST'] = BASE_URL.rstrip('/')
  133. log_msgs = []
  134. # run replays
  135. if not NO_MODEL:
  136. log_msgs += model_replay(lr, frs)
  137. if not NO_NAV:
  138. log_msgs += nav_model_replay(lr)
  139. # get diff
  140. failed = False
  141. if not update:
  142. with open(ref_commit_fn) as f:
  143. ref_commit = f.read().strip()
  144. log_fn = get_log_fn(ref_commit, TEST_ROUTE)
  145. try:
  146. all_logs = list(LogReader(BASE_URL + log_fn))
  147. cmp_log = []
  148. # logs are ordered based on type: modelV2, driverStateV2, nav messages (navThumbnail, mapRenderState, navModel)
  149. if not NO_MODEL:
  150. model_start_index = next(i for i, m in enumerate(all_logs) if m.which() in ("modelV2", "cameraOdometry"))
  151. cmp_log += all_logs[model_start_index:model_start_index + MAX_FRAMES*2]
  152. dmon_start_index = next(i for i, m in enumerate(all_logs) if m.which() == "driverStateV2")
  153. cmp_log += all_logs[dmon_start_index:dmon_start_index + MAX_FRAMES]
  154. if not NO_NAV:
  155. nav_start_index = next(i for i, m in enumerate(all_logs) if m.which() in ["navThumbnail", "mapRenderState", "navModel"])
  156. nav_logs = all_logs[nav_start_index:nav_start_index + NAV_FRAMES*3]
  157. cmp_log += nav_logs
  158. ignore = [
  159. 'logMonoTime',
  160. 'modelV2.frameDropPerc',
  161. 'modelV2.modelExecutionTime',
  162. 'driverStateV2.modelExecutionTime',
  163. 'driverStateV2.dspExecutionTime',
  164. 'navModel.dspExecutionTime',
  165. 'navModel.modelExecutionTime',
  166. 'navThumbnail.timestampEof',
  167. 'mapRenderState.locationMonoTime',
  168. 'mapRenderState.renderTime',
  169. ]
  170. if PC:
  171. ignore += [
  172. 'modelV2.laneLines.0.t',
  173. 'modelV2.laneLines.1.t',
  174. 'modelV2.laneLines.2.t',
  175. 'modelV2.laneLines.3.t',
  176. 'modelV2.roadEdges.0.t',
  177. 'modelV2.roadEdges.1.t',
  178. ]
  179. # TODO this tolerance is absurdly large
  180. tolerance = 2.0 if PC else None
  181. results: Any = {TEST_ROUTE: {}}
  182. log_paths: Any = {TEST_ROUTE: {"models": {'ref': BASE_URL + log_fn, 'new': log_fn}}}
  183. results[TEST_ROUTE]["models"] = compare_logs(cmp_log, log_msgs, tolerance=tolerance, ignore_fields=ignore)
  184. diff_short, diff_long, failed = format_diff(results, log_paths, ref_commit)
  185. print(diff_long)
  186. print('-------------\n'*5)
  187. print(diff_short)
  188. with open("model_diff.txt", "w") as f:
  189. f.write(diff_long)
  190. except Exception as e:
  191. print(str(e))
  192. failed = True
  193. # upload new refs
  194. if (update or failed) and not PC:
  195. from openpilot.tools.lib.openpilotci import upload_file
  196. print("Uploading new refs")
  197. new_commit = get_commit()
  198. log_fn = get_log_fn(new_commit, TEST_ROUTE)
  199. save_log(log_fn, log_msgs)
  200. try:
  201. upload_file(log_fn, os.path.basename(log_fn))
  202. except Exception as e:
  203. print("failed to upload", e)
  204. with open(ref_commit_fn, 'w') as f:
  205. f.write(str(new_commit))
  206. print("\n\nNew ref commit: ", new_commit)
  207. sys.exit(int(failed))