get_model_metadata.py 1003 B

12345678910111213141516171819202122232425262728
  1. #!/usr/bin/env python3
  2. import sys
  3. import pathlib
  4. import onnx
  5. import codecs
  6. import pickle
  7. def get_name_and_shape(value_info:onnx.ValueInfoProto) -> tuple[str, tuple[int,...]]:
  8. shape = tuple([int(dim.dim_value) for dim in value_info.type.tensor_type.shape.dim])
  9. name = value_info.name
  10. return name, shape
  11. if __name__ == "__main__":
  12. model_path = pathlib.Path(sys.argv[1])
  13. model = onnx.load(str(model_path))
  14. i = [x.key for x in model.metadata_props].index('output_slices')
  15. output_slices = model.metadata_props[i].value
  16. metadata = {}
  17. metadata['output_slices'] = pickle.loads(codecs.decode(output_slices.encode(), "base64"))
  18. metadata['input_shapes'] = dict([get_name_and_shape(x) for x in model.graph.input])
  19. metadata['output_shapes'] = dict([get_name_and_shape(x) for x in model.graph.output])
  20. metadata_path = model_path.parent / (model_path.stem + '_metadata.pkl')
  21. with open(metadata_path, 'wb') as f:
  22. pickle.dump(metadata, f)
  23. print(f'saved metadata to {metadata_path}')