// File: flatten_unflatten.cpp
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
  8. #include <torch/csrc/utils/tensor_flatten.h>
  9. #include <torch/extension.h>
// The wrappers below forward to the torch::utils helpers declared in:
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h
  11. at::Tensor flatten(std::vector<at::Tensor> tensors)
  12. {
  13. return torch::utils::flatten_dense_tensors(tensors);
  14. }
  15. std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
  16. {
  17. return torch::utils::unflatten_dense_tensors(flat, tensors);
  18. }
  19. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
  20. {
  21. m.def("flatten", &flatten, "Flatten dense tensors");
  22. m.def("unflatten", &unflatten, "Unflatten dense tensors");
  23. }