export_meta.py

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import types

import torch
import torch.nn as nn

from funasr.register import tables


def export_rebuild_model(model, **kwargs):
    model.device = kwargs.get("device")
    is_onnx = kwargs.get("type", "onnx") == "onnx"
    # encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    # model.encoder = encoder_class(model.encoder, onnx=is_onnx)

    from funasr.utils.torch_function import sequence_mask

    model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)

    # Rebind the export-specific hooks onto the model instance.
    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    model.export_name = types.MethodType(export_name, model)
    # NOTE: this string assignment overrides the export_name method bound just above.
    model.export_name = "model"

    return model
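

# Illustrative sketch only (not part of FunASR): how an exporter is typically expected to
# consume the hooks attached in export_rebuild_model above. The helper name
# _export_onnx_example, the output_path argument, and the opset version are assumptions
# for this example, not the library's API. kwargs must contain at least "max_seq_len".
def _export_onnx_example(model, output_path="model.onnx", **kwargs):
    model = export_rebuild_model(model, **kwargs)
    model.eval()
    dummy_inputs = model.export_dummy_inputs()
    torch.onnx.export(
        model,
        dummy_inputs,
        output_path,
        input_names=model.export_input_names(),
        output_names=model.export_output_names(),
        dynamic_axes=model.export_dynamic_axes(),
        opset_version=14,
        do_constant_folding=True,
    )
    return output_path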


def export_forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    language: torch.Tensor,
    textnorm: torch.Tensor,
    **kwargs,
):
    # During ONNX tracing no kwargs are passed, so fall back to the tensors' own device.
    device = kwargs.get("device", speech.device)
    speech = speech.to(device=device)
    speech_lengths = speech_lengths.to(device=device)

    # Embed the language / text-normalization ids and add a time axis: (B,) -> (B, 1, D).
    language_query = self.embed(language).unsqueeze(1).to(speech.device)
    textnorm_query = self.embed(textnorm).unsqueeze(1).to(speech.device)

    # Prepend the text-normalization query frame.
    speech = torch.cat((textnorm_query, speech), dim=1)
    speech_lengths += 1

    # Prepend the language query and the two event/emotion query frames.
    event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
        speech.size(0), 1, 1
    )
    input_query = torch.cat((language_query, event_emo_query), dim=1)
    speech = torch.cat((input_query, speech), dim=1)
    speech_lengths += 3

    # Encoder
    encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
    if isinstance(encoder_out, tuple):
        encoder_out = encoder_out[0]

    # CTC log-probabilities over the encoder output frames
    ctc_logits = self.ctc.log_softmax(encoder_out)

    return ctc_logits, encoder_out_lens


def export_dummy_inputs(self):
    # Dummy batch of two utterances used only to trace the graph during export.
    speech = torch.randn(2, 30, 560)
    speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
    language = torch.tensor([0, 0], dtype=torch.int32)
    textnorm = torch.tensor([15, 15], dtype=torch.int32)
    return (speech, speech_lengths, language, textnorm)


def export_input_names(self):
    return ["speech", "speech_lengths", "language", "textnorm"]


def export_output_names(self):
    return ["ctc_logits", "encoder_out_lens"]


def export_dynamic_axes(self):
    return {
        "speech": {0: "batch_size", 1: "feats_length"},
        "speech_lengths": {
            0: "batch_size",
        },
        # Key must match the exported output name ("ctc_logits") for the axes to apply.
        "ctc_logits": {0: "batch_size", 1: "logits_length"},
    }


def export_name(self):
    return "model.onnx"