# model.py
from typing import Iterable, Optional
import types
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch.cuda.amp import autocast

from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.ctc.ctc import CTC
from funasr.register import tables
from funasr.models.paraformer.search import Hypothesis

class SinusoidalPositionEncoder(torch.nn.Module):
    """Sinusoidal position encoding added to the input features."""

    def __init__(self, d_model=80, dropout_rate=0.1):
        super().__init__()

    def encode(
        self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
    ):
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        device = positions.device
        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
            depth / 2 - 1
        )
        inv_timescales = torch.exp(
            torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
        )
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
            inv_timescales, [1, 1, -1]
        )
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)

    def forward(self, x):
        batch_size, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding
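
# A minimal usage sketch (illustrative only, not part of the FunASR API): the
# encoder is stateless and simply adds sinusoidal positions to a
# (batch, time, dim) float tensor.
#
#     pe = SinusoidalPositionEncoder()
#     feats = torch.randn(2, 100, 80)   # (batch, time, feature_dim)
#     feats = pe(feats)                 # same shape, positions added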

class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
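
# A minimal usage sketch (illustrative; sizes are assumptions): the position-wise
# FFN maps (batch, time, idim) -> (batch, time, idim) through a wider hidden layer.
#
#     ffn = PositionwiseFeedForward(idim=256, hidden_units=2048, dropout_rate=0.1)
#     y = ffn(torch.randn(2, 50, 256))  # -> (2, 50, 256)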

class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer with an FSMN memory block (SANM).

    Args:
        n_head (int): The number of heads.
        in_feat (int): The input feature dimension.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN memory block.
        sanm_shfit (int): Asymmetric shift of the FSMN padding.
    """

    def __init__(
        self,
        n_head,
        in_feat,
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shfit=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    ):
        """Construct a MultiHeadedAttentionSANM object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        # self.linear_q = nn.Linear(n_feat, n_feat)
        # self.linear_k = nn.Linear(n_feat, n_feat)
        # self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # asymmetric padding for the FSMN convolution
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Project the input into query, key and value.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
            torch.Tensor: Value tensor before the head split (#batch, time, n_feat).
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = -float(
                "inf"
            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention plus the FSMN memory branch.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
                (#batch, time, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute scaled dot product attention for one streaming chunk.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            cache (dict): Cached keys/values ("k", "v") from previous chunks.
            chunk_size (list): Chunk configuration used for streaming decoding.
            look_back (int): Number of chunks kept in the cache (-1 keeps all).

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).
            dict: Updated cache.
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        if chunk_size is not None and look_back > 0 or look_back == -1:
            if cache is not None:
                k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
                v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)
                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
                if look_back != -1:
                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
            else:
                cache_tmp = {
                    "k": k_h[:, :, : -(chunk_size[2]), :],
                    "v": v_h[:, :, : -(chunk_size[2]), :],
                }
                cache = cache_tmp
        fsmn_memory = self.forward_fsmn(v, None)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, None)
        return att_outs + fsmn_memory, cache
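
# A minimal usage sketch (illustrative; sizes are assumptions): the SANM layer
# fuses scaled dot-product attention with the FSMN memory branch. The mask has
# shape (batch, 1, time), matching what SenseVoiceEncoderSmall passes in below.
#
#     attn = MultiHeadedAttentionSANM(
#         n_head=4, in_feat=256, n_feat=256, dropout_rate=0.0, kernel_size=11
#     )
#     y = attn(torch.randn(2, 50, 256), torch.ones(2, 1, 50))  # -> (2, 50, 256)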

class LayerNorm(nn.LayerNorm):
    """LayerNorm computed in float32 and cast back to the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)

def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()
    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
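
# Example (illustrative): sequence_mask turns lengths into a float padding mask,
# e.g. lengths [2, 4] with maxlen 4 give
#     [[1., 1., 0., 0.],
#      [1., 1., 1., 1.]]
# The encoder below consumes it as mask[:, None, :] -> (batch, 1, time).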

class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute encoded features for one streaming chunk.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            cache (dict): Attention cache from the previous chunks.
            chunk_size (list): Chunk configuration used for streaming decoding.
            look_back (int): Number of chunks kept in the cache (-1 keeps all).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            dict: Updated cache.
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.in_size == self.size:
            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
            x = residual + attn
        else:
            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)

        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)

        return x, cache
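
# A minimal usage sketch (illustrative; sizes are assumptions): one encoder layer
# wraps the SANM attention and the position-wise FFN. forward() returns a tuple,
# and callers keep only the first two entries (x, mask).
#
#     layer = EncoderLayerSANM(
#         in_size=512,
#         size=512,
#         self_attn=MultiHeadedAttentionSANM(4, 512, 512, 0.0, kernel_size=11),
#         feed_forward=PositionwiseFeedForward(512, 2048, 0.0),
#         dropout_rate=0.0,
#     )
#     x, mask = layer(torch.randn(2, 50, 512), torch.ones(2, 1, 50))[:2]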

@tables.register("encoder_classes", "SenseVoiceEncoderSmall")
class SenseVoiceEncoderSmall(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        tp_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        stochastic_depth_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        **kwargs,
    ):
        super().__init__()
        self._output_size = output_size

        self.embed = SinusoidalPositionEncoder()
        self.normalize_before = normalize_before

        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )

        self.encoders0 = nn.ModuleList(
            [
                EncoderLayerSANM(
                    input_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(1)
            ]
        )
        self.encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(num_blocks - 1)
            ]
        )
        self.tp_encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(tp_blocks)
            ]
        )

        self.after_norm = LayerNorm(output_size)
        self.tp_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ):
        """Embed positions in tensor."""
        masks = sequence_mask(ilens, device=ilens.device)[:, None, :]

        xs_pad *= self.output_size() ** 0.5

        xs_pad = self.embed(xs_pad)

        # forward encoder1
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.after_norm(xs_pad)

        # forward encoder2
        olens = masks.squeeze(1).sum(1).int()
        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.tp_norm(xs_pad)
        return xs_pad, olens
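
# A minimal usage sketch (illustrative; the sizes are assumptions, not the shipped
# SenseVoice config): the encoder maps padded features and their lengths to
# encoded frames plus output lengths.
#
#     encoder = SenseVoiceEncoderSmall(input_size=560, output_size=512, num_blocks=4, tp_blocks=4)
#     feats = torch.randn(2, 100, 560)
#     lens = torch.tensor([100, 80])
#     enc_out, enc_lens = encoder(feats, lens)  # enc_out: (2, 100, 512)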

@tables.register("model_classes", "SenseVoiceSmall")
class SenseVoiceSmall(nn.Module):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        ctc_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        length_normalized_loss: bool = False,
        **kwargs,
    ):
        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()

        if ctc_conf is None:
            ctc_conf = {}
        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.error_calculator = None

        self.ctc = ctc

        self.length_normalized_loss = length_normalized_loss
        self.encoder_output_size = encoder_output_size

        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}
        self.embed = torch.nn.Embedding(
            7 + len(self.lid_dict) + len(self.textnorm_dict), input_size
        )

        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

    @staticmethod
    def from_pretrained(model: str = None, **kwargs):
        from funasr import AutoModel

        model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)

        return model, kwargs

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Encoder + loss calculation (CTC plus rich-token CE).

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)

        loss_ctc, cer_ctc = None, None
        loss_rich, acc_rich = None, None
        stats = dict()

        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4
        )

        loss_rich, acc_rich = self._calc_rich_ce_loss(encoder_out[:, :4, :], text[:, :4])

        loss = loss_ctc
        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach()) if loss_ctc is not None else None
        stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None
        stats["acc_rich"] = acc_rich

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            ind: int
        """
        # Data augmentation
        if self.specaug is not None and self.training:
            speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            speech, speech_lengths = self.normalize(speech, speech_lengths)

        lids = torch.LongTensor(
            [
                [
                    self.lid_int_dict[int(lid)]
                    if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict
                    else 0
                ]
                for lid in text[:, 0]
            ]
        ).to(speech.device)
        language_query = self.embed(lids)

        styles = torch.LongTensor(
            [[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]
        ).to(speech.device)
        style_query = self.embed(styles)
        speech = torch.cat((style_query, speech), dim=1)
        speech_lengths += 1

        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
            speech.size(0), 1, 1
        )
        input_query = torch.cat((language_query, event_emo_query), dim=1)
        speech = torch.cat((input_query, speech), dim=1)
        speech_lengths += 3

        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)

        return encoder_out, encoder_out_lens
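
    # Note on the prompt layout built by encode(): one language query, two
    # event/emotion queries and one text-norm query are prepended to the fbank
    # features (4 extra frames in total, hence speech_lengths += 1 and += 3 above).
    # This is why forward() computes the CTC loss on encoder_out[:, 4:, :] and the
    # rich-token CE loss on encoder_out[:, :4, :].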

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def _calc_rich_ce_loss(
        self,
        encoder_out: torch.Tensor,
        ys_pad: torch.Tensor,
    ):
        decoder_out = self.ctc.ctc_lo(encoder_out)
        # Compute CE loss on the rich (language/event/emotion/text-norm) tokens
        loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous())
        acc_rich = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_pad.contiguous(),
            ignore_label=self.ignore_id,
        )

        return loss_rich, acc_rich

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = ["wav_file_tmp_name"],
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        language = kwargs.get("language", "auto")
        language_query = self.embed(
            torch.LongTensor(
                [[self.lid_dict[language] if language in self.lid_dict else 0]]
            ).to(speech.device)
        ).repeat(speech.size(0), 1, 1)

        use_itn = kwargs.get("use_itn", False)
        textnorm = kwargs.get("text_norm", None)
        if textnorm is None:
            textnorm = "withitn" if use_itn else "woitn"
        textnorm_query = self.embed(
            torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
        ).repeat(speech.size(0), 1, 1)
        speech = torch.cat((textnorm_query, speech), dim=1)
        speech_lengths += 1

        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
            speech.size(0), 1, 1
        )
        input_query = torch.cat((language_query, event_emo_query), dim=1)
        speech = torch.cat((input_query, speech), dim=1)
        speech_lengths += 3

        # Encoder
        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # Greedy CTC decoding over the encoder output
        ctc_logits = self.ctc.log_softmax(encoder_out)

        results = []
        b, n, d = encoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        if len(key) < b:
            key = key * b

        for i in range(b):
            x = ctc_logits[i, : encoder_out_lens[i].item(), :]
            yseq = x.argmax(dim=-1)
            yseq = torch.unique_consecutive(yseq, dim=-1)

            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer["1best_recog"]

            mask = yseq != self.blank_id
            token_int = yseq[mask].tolist()

            # Change integer-ids to tokens
            text = tokenizer.decode(token_int)

            result_i = {"key": key[i], "text": text}
            results.append(result_i)

            if ibest_writer is not None:
                ibest_writer["text"][key[i]] = text

        return results, meta_data
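
    # End-to-end usage sketch (illustrative; the model id, device and wav path are
    # assumptions): from_pretrained() builds the model plus the kwargs (frontend,
    # tokenizer, device) that inference() expects.
    #
    #     model, kwargs = SenseVoiceSmall.from_pretrained("iic/SenseVoiceSmall", device="cpu")
    #     model.eval()
    #     results, meta = model.inference(
    #         data_in="example.wav",
    #         language="auto",   # or "zh", "en", "yue", "ja", "ko"
    #         use_itn=False,
    #         **kwargs,
    #     )
    #     print(results[0]["text"])  # text with language/event/emotion/text-norm tags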

    def export(self, **kwargs):
        from .export_meta import export_rebuild_model

        if "max_seq_len" not in kwargs:
            kwargs["max_seq_len"] = 512
        models = export_rebuild_model(model=self, **kwargs)
        return models