flatten_unflatten.cpp 742 B

12345678910111213141516171819202122232425
  1. /*
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. Copyright NVIDIA/apex
  4. This file is adapted from fused adam in NVIDIA/apex, commit a109f85
  5. */
  6. #include <torch/csrc/utils/tensor_flatten.h>
  7. #include <torch/extension.h>
  8. // https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h
  9. at::Tensor flatten(std::vector<at::Tensor> tensors)
  10. {
  11. return torch::utils::flatten_dense_tensors(tensors);
  12. }
  13. std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
  14. {
  15. return torch::utils::unflatten_dense_tensors(flat, tensors);
  16. }
  17. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
  18. {
  19. m.def("flatten", &flatten, "Flatten dense tensors");
  20. m.def("unflatten", &unflatten, "Unflatten dense tensors");
  21. }