dmonitoringmodeld.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. #!/usr/bin/env python3
  2. import os
  3. import gc
  4. import math
  5. import time
  6. import ctypes
  7. import numpy as np
  8. from pathlib import Path
  9. from cereal import messaging
  10. from cereal.messaging import PubMaster, SubMaster
  11. from cereal.visionipc import VisionIpcClient, VisionStreamType, VisionBuf
  12. from openpilot.common.swaglog import cloudlog
  13. from openpilot.common.params import Params
  14. from openpilot.common.realtime import set_realtime_priority
  15. from openpilot.selfdrive.modeld.runners import ModelRunner, Runtime
  16. from openpilot.selfdrive.modeld.models.commonmodel_pyx import sigmoid
  # Length of the model's 'calib' input; main() fills it from liveCalibration.rpyCalib.
  CALIB_LEN = 3
  # Multiplier applied to the raw face orientation / position regression outputs.
  REG_SCALE = 0.25
  # Width and height of the crop that ModelState.run copies out of each camera frame.
  MODEL_WIDTH, MODEL_HEIGHT = 1440, 960
  # Number of float32 values in the model output buffer; ModelState.__init__ asserts this
  # equals sizeof(DMonitoringModelResult) / sizeof(c_float).
  OUTPUT_SIZE = 84
  # When this env var is set to any non-empty string (even "0"), the raw output bytes are
  # attached to every driverStateV2 message as rawPredictions.
  SEND_RAW_PRED = os.environ.get('SEND_RAW_PRED')
  # Model file to load for each runner backend, resolved relative to this script's directory.
  MODEL_PATHS = {
    ModelRunner.SNPE: Path(__file__).parent.joinpath('models', 'dmonitoring_model_q.dlc'),
    ModelRunner.ONNX: Path(__file__).parent.joinpath('models', 'dmonitoring_model.onnx'),
  }
  class DriverStateResult(ctypes.Structure):
    """ctypes layout of one driver's block of model outputs: 41 consecutive c_float values.

    Field order and sizes must match the output buffer exactly, because the buffer is
    reinterpreted in place. Values are raw model outputs. fill_driver_state scales the
    orientation/position regressions by REG_SCALE, applies exp to the *_std fields and
    sigmoid to the *_prob fields. The _unused_a/_unused_b spans keep later fields at the
    right offsets; fill_driver_state never reads them.
    """
    _fields_ = [
      (name, ctypes.c_float * 3)
      for name in ("face_orientation", "face_position", "face_orientation_std", "face_position_std")
    ] + [
      ("face_prob", ctypes.c_float),
      ("_unused_a", ctypes.c_float * 8),
      ("left_eye_prob", ctypes.c_float),
      ("_unused_b", ctypes.c_float * 8),
      ("right_eye_prob", ctypes.c_float),
      ("left_blink_prob", ctypes.c_float),
      ("right_blink_prob", ctypes.c_float),
      ("sunglasses_prob", ctypes.c_float),
      ("occluded_prob", ctypes.c_float),
      ("ready_prob", ctypes.c_float * 4),
      ("not_ready_prob", ctypes.c_float * 2),
    ]
  class DMonitoringModelResult(ctypes.Structure):
    """ctypes layout of the complete model output buffer.

    It holds two DriverStateResult blocks followed by two scalars. get_driverstate_packet
    feeds driver_state_lhd into leftDriverData and driver_state_rhd into rightDriverData,
    and passes both scalars through sigmoid. ModelState.__init__ asserts that the total
    size equals OUTPUT_SIZE c_float values.
    """
    _fields_ = [(f"driver_state_{side}", DriverStateResult) for side in ("lhd", "rhd")]
    _fields_ += [
      ("poor_vision_prob", ctypes.c_float),
      ("wheel_on_right_prob", ctypes.c_float),
    ]
  class ModelState:
    """Owns the driver-monitoring ModelRunner together with its preallocated input/output buffers."""
    inputs: dict[str, np.ndarray]
    output: np.ndarray
    model: ModelRunner

    def __init__(self):
      float_size = ctypes.sizeof(ctypes.c_float)
      # self.output is later reinterpreted as a DMonitoringModelResult, so the sizes must agree.
      assert ctypes.sizeof(DMonitoringModelResult) == OUTPUT_SIZE * float_size

      self.output = np.zeros(OUTPUT_SIZE, dtype=np.float32)
      self.inputs = {
        'input_img': np.zeros(MODEL_HEIGHT * MODEL_WIDTH, dtype=np.uint8),
        'calib': np.zeros(CALIB_LEN, dtype=np.float32),
      }

      self.model = ModelRunner(MODEL_PATHS, self.output, Runtime.DSP, True, None)
      # The image buffer is supplied on every run() via setInputBuffer.
      # The calib array is registered once here and then updated in place by run().
      self.model.addInput("input_img", None)
      self.model.addInput("calib", self.inputs['calib'])

    def run(self, buf: VisionBuf, calib: np.ndarray) -> tuple[np.ndarray, float]:
      """Crop the camera frame, execute the model, and return (output, model time in seconds).

      The buffer is viewed as rows of `buf.stride` bytes. The crop takes MODEL_HEIGHT rows
      ending at row `buf.height`, and MODEL_WIDTH columns centred within `buf.width`. It
      presumably targets the luma plane; confirm against the VisionBuf format.

      The returned array is `self.output` itself, so the next call overwrites it. The
      elapsed time covers only setInputBuffer + execute, not the crop copy.
      """
      self.inputs['calib'][:] = calib

      top = buf.height - MODEL_HEIGHT
      left = (buf.width - MODEL_WIDTH) // 2
      rows = buf.data.reshape(-1, buf.stride)
      crop = rows[top:top + MODEL_HEIGHT, left:left + MODEL_WIDTH]
      # reshape of the contiguous 1-D buffer is a view, so this writes into inputs['input_img'].
      self.inputs['input_img'].reshape(MODEL_HEIGHT, MODEL_WIDTH)[:] = crop

      start = time.perf_counter()
      # .view(np.float32) reinterprets the uint8 bytes without converting them.
      self.model.setInputBuffer("input_img", self.inputs['input_img'].view(np.float32))
      self.model.execute()
      elapsed = time.perf_counter() - start
      return self.output, elapsed
  def fill_driver_state(msg, ds_result: DriverStateResult):
    """Decode one driver's raw model outputs into a driverStateV2 driver-data struct.

    - Orientation/position regressions are multiplied by REG_SCALE.
    - The matching *_std outputs are passed through math.exp.
    - Every *_prob output is passed through sigmoid.
    - Only the first two components of face_position / face_position_std are published.
    """
    msg.faceOrientation = [REG_SCALE * v for v in ds_result.face_orientation]
    msg.faceOrientationStd = [math.exp(v) for v in ds_result.face_orientation_std]
    msg.facePosition = [REG_SCALE * v for v in ds_result.face_position[:2]]
    msg.facePositionStd = [math.exp(v) for v in ds_result.face_position_std[:2]]

    # Scalar probabilities, assigned in the same order as the message fields are listed.
    scalar_probs = (
      ("faceProb", "face_prob"),
      ("leftEyeProb", "left_eye_prob"),
      ("rightEyeProb", "right_eye_prob"),
      ("leftBlinkProb", "left_blink_prob"),
      ("rightBlinkProb", "right_blink_prob"),
      ("sunglassesProb", "sunglasses_prob"),
      ("occludedProb", "occluded_prob"),
    )
    for msg_field, result_field in scalar_probs:
      setattr(msg, msg_field, sigmoid(getattr(ds_result, result_field)))

    msg.readyProb = [sigmoid(v) for v in ds_result.ready_prob]
    msg.notReadyProb = [sigmoid(v) for v in ds_result.not_ready_prob]
  def get_driverstate_packet(model_output: np.ndarray, frame_id: int, location_ts: int, execution_time: float, dsp_execution_time: float):
    """Build a driverStateV2 message from one model output buffer.

    Args:
      model_output: C-contiguous float32 array laid out as DMonitoringModelResult. The
        struct is read directly from this array's memory; nothing is copied.
      frame_id: published as driverStateV2.frameId.
      location_ts: not used by this function; kept so existing call sites keep working.
      execution_time: published as driverStateV2.modelExecutionTime.
      dsp_execution_time: published as driverStateV2.dspExecutionTime.

    Returns:
      The populated driverStateV2 message (created with valid=True).

    Raises:
      ValueError: if model_output is not a C-contiguous float32 array at least
        sizeof(DMonitoringModelResult) bytes long. Without this check the raw-pointer
        cast below would read misinterpreted or out-of-bounds memory.
    """
    struct_size = ctypes.sizeof(DMonitoringModelResult)
    if (model_output.dtype != np.float32 or not model_output.flags.c_contiguous or
        model_output.nbytes < struct_size):
      raise ValueError(
        f"model_output must be a C-contiguous float32 array of at least {struct_size} bytes, "
        f"got dtype={model_output.dtype}, nbytes={model_output.nbytes}, "
        f"c_contiguous={model_output.flags.c_contiguous}")

    # Zero-copy reinterpretation of the buffer; model_output must outlive model_result's use.
    model_result = ctypes.cast(model_output.ctypes.data, ctypes.POINTER(DMonitoringModelResult)).contents

    msg = messaging.new_message('driverStateV2', valid=True)
    ds = msg.driverStateV2
    ds.frameId = frame_id
    ds.modelExecutionTime = execution_time
    ds.dspExecutionTime = dsp_execution_time
    ds.poorVisionProb = sigmoid(model_result.poor_vision_prob)
    ds.wheelOnRightProb = sigmoid(model_result.wheel_on_right_prob)
    # The raw output bytes are attached only when the SEND_RAW_PRED env var is set.
    ds.rawPredictions = model_output.tobytes() if SEND_RAW_PRED else b''
    fill_driver_state(ds.leftDriverData, model_result.driver_state_lhd)
    fill_driver_state(ds.rightDriverData, model_result.driver_state_rhd)
    return msg
  def main():
    """Run the driver-monitoring model on every driver-camera frame and publish driverStateV2."""
    # Automatic GC is turned off and realtime priority raised.
    # Presumably this avoids GC pauses in the per-frame loop.
    gc.disable()
    set_realtime_priority(1)

    dm_model = ModelState()
    cloudlog.warning("models loaded, dmonitoringmodeld starting")
    Params().put_bool("DmModelInitialized", True)

    cloudlog.warning("connecting to driver stream")
    vipc_client = VisionIpcClient("camerad", VisionStreamType.VISION_STREAM_DRIVER, True)
    while not vipc_client.connect(False):
      time.sleep(0.1)
    assert vipc_client.is_connected()
    cloudlog.warning(f"connected with buffer size: {vipc_client.buffer_len}")

    sm = SubMaster(["liveCalibration"])
    pm = PubMaster(["driverStateV2"])

    # Calibration stays at zeros until the first liveCalibration message, then keeps its latest value.
    calib = np.zeros(CALIB_LEN, dtype=np.float32)
    while True:
      buf = vipc_client.recv()
      if buf is None:
        continue

      sm.update(0)
      if sm.updated["liveCalibration"]:
        calib[:] = np.array(sm["liveCalibration"].rpyCalib)

      start = time.perf_counter()
      model_output, dsp_execution_time = dm_model.run(buf, calib)
      model_execution_time = time.perf_counter() - start

      packet = get_driverstate_packet(model_output, vipc_client.frame_id, vipc_client.timestamp_sof,
                                      model_execution_time, dsp_execution_time)
      pm.send("driverStateV2", packet)


  if __name__ == "__main__":
    main()