
  1. """
  2. # api.py usage
  3. ` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" `
  4. ## 执行参数:
  5. `-s` - `SoVITS模型路径, 可在 config.py 中指定`
  6. `-g` - `GPT模型路径, 可在 config.py 中指定`
  7. 调用请求缺少参考音频时使用
  8. `-dr` - `默认参考音频路径`
  9. `-dt` - `默认参考音频文本`
  10. `-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语,"zh","en","ja","ko","yue"`
  11. `-d` - `推理设备, "cuda","cpu"`
  12. `-a` - `绑定地址, 默认"127.0.0.1"`
  13. `-p` - `绑定端口, 默认9880, 可在 config.py 中指定`
  14. `-fp` - `覆盖 config.py 使用全精度`
  15. `-hp` - `覆盖 config.py 使用半精度`
  16. `-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"`
  17. ·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"`
  18. ·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"`
  19. ·-cp` - `文本切分符号设定, 默认为空, 以",.,。"字符串的方式传入`
  20. `-hb` - `cnhubert路径`
  21. `-b` - `bert路径`
  22. ## 调用:
  23. ### 推理
  24. endpoint: `/`
  25. 使用执行参数指定的参考音频:
  26. GET:
  27. `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
  28. POST:
  29. ```json
  30. {
  31. "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
  32. "text_language": "zh"
  33. }
  34. ```
  35. 使用执行参数指定的参考音频并设定分割符号:
  36. GET:
  37. `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&cut_punc=,。`
  38. POST:
  39. ```json
  40. {
  41. "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
  42. "text_language": "zh",
  43. "cut_punc": ",。",
  44. }
  45. ```
  46. 手动指定当次推理所使用的参考音频:
  47. GET:
  48. `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
  49. POST:
  50. ```json
  51. {
  52. "refer_wav_path": "123.wav",
  53. "prompt_text": "一二三。",
  54. "prompt_language": "zh",
  55. "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
  56. "text_language": "zh"
  57. }
  58. ```
  59. RESP:
  60. 成功: 直接返回 wav 音频流, http code 200
  61. 失败: 返回包含错误信息的 json, http code 400
  62. 手动指定当次推理所使用的参考音频,并提供参数:
  63. GET:
  64. `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"`
  65. POST:
  66. ```json
  67. {
  68. "refer_wav_path": "123.wav",
  69. "prompt_text": "一二三。",
  70. "prompt_language": "zh",
  71. "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
  72. "text_language": "zh",
  73. "top_k": 20,
  74. "top_p": 0.6,
  75. "temperature": 0.6,
  76. "speed": 1,
  77. "inp_refs": ["456.wav","789.wav"]
  78. }
  79. ```
  80. RESP:
  81. 成功: 直接返回 wav 音频流, http code 200
  82. 失败: 返回包含错误信息的 json, http code 400
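
A minimal Python client sketch (assumes the `requests` package is installed and the server is running with the defaults above; `out.wav` is just an example output name):
```python
import requests

resp = requests.post(
    "http://127.0.0.1:9880",
    json={"text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", "text_language": "zh"},
)
if resp.status_code == 200:
    with open("out.wav", "wb") as f:
        f.write(resp.content)   # wav audio stream
else:
    print(resp.json())          # {"code": 400, "message": "..."}
```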

### Change the default reference audio
endpoint: `/change_refer`
Keys are the same as on the inference endpoint.
GET:
`http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh`
POST:
```json
{
    "refer_wav_path": "123.wav",
    "prompt_text": "一二三。",
    "prompt_language": "zh"
}
```
RESP:
Success: json, HTTP code 200
Failure: json, HTTP code 400

### Command control
endpoint: `/control`
command:
"restart": relaunch the service
"exit": stop the service
GET:
`http://127.0.0.1:9880/control?command=restart`
POST:
```json
{
    "command": "restart"
}
```
RESP: none
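
### Switch models
endpoint: `/set_model`
Reloads the GPT and SoVITS weights; the paths below are placeholders in the same style as the examples above.
GET:
`http://127.0.0.1:9880/set_model?gpt_model_path=123.ckpt&sovits_model_path=123.pth`
POST:
```json
{
    "gpt_model_path": "123.ckpt",
    "sovits_model_path": "123.pth"
}
```
RESP:
Success: json, HTTP code 200
Failure: json, HTTP code 400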
  113. """
import argparse
import os, re
import sys

now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))

import signal
import LangSegment
from time import time as ttime
import torch
import librosa
import soundfile as sf
from fastapi import FastAPI, Request, Query, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
from feature_extractor import cnhubert
from io import BytesIO
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from module.mel_processing import spectrogram_torch
from tools.my_utils import load_audio
import config as global_config
import logging.config  # dictConfig is used below, so import the submodule explicitly
import subprocess


class DefaultRefer:
    def __init__(self, path, text, language):
        self.path = path
        self.text = text
        self.language = language

    def is_ready(self) -> bool:
        return is_full(self.path, self.text, self.language)


def is_empty(*items):  # returns False if any item is non-empty
    for item in items:
        if item is not None and item != "":
            return False
    return True


def is_full(*items):  # returns False if any item is empty
    for item in items:
        if item is None or item == "":
            return False
    return True


class Speaker:
    def __init__(self, name, gpt, sovits, phones=None, bert=None, prompt=None):
        self.name = name
        self.sovits = sovits
        self.gpt = gpt
        self.phones = phones
        self.bert = bert
        self.prompt = prompt


speaker_list = {}


class Sovits:
    def __init__(self, vq_model, hps):
        self.vq_model = vq_model
        self.hps = hps


def get_sovits_weights(sovits_path):
    dict_s2 = torch.load(sovits_path, map_location="cpu")
    hps = dict_s2["config"]
    hps = DictToAttrRecursive(hps)
    hps.model.semantic_frame_rate = "25hz"
    if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
        hps.model.version = "v1"
    else:
        hps.model.version = "v2"
    logger.info(f"Model version: {hps.model.version}")
    model_params_dict = vars(hps.model)
    vq_model = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **model_params_dict
    )
    if "pretrained" not in sovits_path:
        del vq_model.enc_q
    if is_half == True:
        vq_model = vq_model.half().to(device)
    else:
        vq_model = vq_model.to(device)
    vq_model.eval()
    vq_model.load_state_dict(dict_s2["weight"], strict=False)
    sovits = Sovits(vq_model, hps)
    return sovits


class Gpt:
    def __init__(self, max_sec, t2s_model):
        self.max_sec = max_sec
        self.t2s_model = t2s_model


global hz
hz = 50


def get_gpt_weights(gpt_path):
    dict_s1 = torch.load(gpt_path, map_location="cpu")
    config = dict_s1["config"]
    max_sec = config["data"]["max_sec"]
    t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
    t2s_model.load_state_dict(dict_s1["weight"])
    if is_half == True:
        t2s_model = t2s_model.half()
    t2s_model = t2s_model.to(device)
    t2s_model.eval()
    total = sum([param.nelement() for param in t2s_model.parameters()])
    logger.info("Number of parameters: %.2fM" % (total / 1e6))
    gpt = Gpt(max_sec, t2s_model)
    return gpt


def change_gpt_sovits_weights(gpt_path, sovits_path):
    try:
        gpt = get_gpt_weights(gpt_path)
        sovits = get_sovits_weights(sovits_path)
    except Exception as e:
        return JSONResponse({"code": 400, "message": str(e)}, status_code=400)

    speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)


def get_bert_feature(text, word2ph):
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for i in inputs:
            inputs[i] = inputs[i].to(device)  # token ids are int64, so precision simply follows bert_model
        res = bert_model(**inputs, output_hidden_states=True)
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
    assert len(word2ph) == len(text)
    phone_level_feature = []
    for i in range(len(word2ph)):
        repeat_feature = res[i].repeat(word2ph[i], 1)
        phone_level_feature.append(repeat_feature)
    phone_level_feature = torch.cat(phone_level_feature, dim=0)
    # if(is_half==True):phone_level_feature=phone_level_feature.half()
    return phone_level_feature.T


def clean_text_inf(text, language, version):
    phones, word2ph, norm_text = clean_text(text, language, version)
    phones = cleaned_text_to_sequence(phones, version)
    return phones, word2ph, norm_text


def get_bert_inf(phones, word2ph, norm_text, language):
    language = language.replace("all_", "")
    if language == "zh":
        bert = get_bert_feature(norm_text, word2ph).to(device)  # .to(dtype)
    else:
        bert = torch.zeros(
            (1024, len(phones)),
            dtype=torch.float16 if is_half == True else torch.float32,
        ).to(device)
    return bert


from text import chinese


def get_phones_and_bert(text, language, version, final=False):
    if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
        language = language.replace("all_", "")
        if language == "en":
            LangSegment.setfilters(["en"])
            formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
        else:
            # CJK ideographs cannot be told apart, so trust the language the user specified
            formattext = text
        while "  " in formattext:
            formattext = formattext.replace("  ", " ")
        if language == "zh":
            if re.search(r'[A-Za-z]', formattext):
                formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                formattext = chinese.mix_text_normalize(formattext)
                return get_phones_and_bert(formattext, "zh", version)
            else:
                phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
                bert = get_bert_feature(norm_text, word2ph).to(device)
        elif language == "yue" and re.search(r'[A-Za-z]', formattext):
            formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
            formattext = chinese.mix_text_normalize(formattext)
            return get_phones_and_bert(formattext, "yue", version)
        else:
            phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
            bert = torch.zeros(
                (1024, len(phones)),
                dtype=torch.float16 if is_half == True else torch.float32,
            ).to(device)
    elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
        textlist = []
        langlist = []
        LangSegment.setfilters(["zh", "ja", "en", "ko"])
        if language == "auto":
            for tmp in LangSegment.getTexts(text):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "auto_yue":
            for tmp in LangSegment.getTexts(text):
                if tmp["lang"] == "zh":
                    tmp["lang"] = "yue"
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        else:
            for tmp in LangSegment.getTexts(text):
                if tmp["lang"] == "en":
                    langlist.append(tmp["lang"])
                else:
                    # CJK ideographs cannot be told apart, so trust the language the user specified
                    langlist.append(language)
                textlist.append(tmp["text"])
        phones_list = []
        bert_list = []
        norm_text_list = []
        for i in range(len(textlist)):
            lang = langlist[i]
            phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
            bert = get_bert_inf(phones, word2ph, norm_text, lang)
            phones_list.append(phones)
            norm_text_list.append(norm_text)
            bert_list.append(bert)
        bert = torch.cat(bert_list, dim=1)
        phones = sum(phones_list, [])
        norm_text = ''.join(norm_text_list)

    if not final and len(phones) < 6:
        return get_phones_and_bert("." + text, language, version, final=True)

    return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text
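
# Illustrative note (an assumption about LangSegment's output, not from the
# original comments): with language="zh" on mixed input, the text is split into
# per-language runs, e.g. "你好hello世界" -> zh:"你好", en:"hello", zh:"世界",
# and the per-run phoneme and BERT features are concatenated in order above.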


class DictToAttrRecursive(dict):
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")


def get_spepc(hps, filename):
    audio, _ = librosa.load(filename, sr=int(hps.data.sampling_rate))
    audio = torch.FloatTensor(audio)
    maxx = audio.abs().max()
    if maxx > 1:
        audio /= min(2, maxx)
    audio_norm = audio
    audio_norm = audio_norm.unsqueeze(0)
    spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
                             hps.data.win_length, center=False)
    return spec


def pack_audio(audio_bytes, data, rate):
    if media_type == "ogg":
        audio_bytes = pack_ogg(audio_bytes, data, rate)
    elif media_type == "aac":
        audio_bytes = pack_aac(audio_bytes, data, rate)
    else:
        # wav cannot be streamed, so buffer the raw PCM for now
        audio_bytes = pack_raw(audio_bytes, data, rate)
    return audio_bytes


def pack_ogg(audio_bytes, data, rate):
    # Author: AkagawaTsurunaki
    # Issue:
    #   A stack overflow probabilistically occurs when the function `sf_writef_short`
    #   of `libsndfile_64bit.dll` is called via the Python library `soundfile`.
    # Note:
    #   This is an issue with `libsndfile`, not this project itself.
    #   It happens when you generate a large audio tensor (about 499804 frames on my PC)
    #   and try to convert it to an ogg file.
    # Related:
    #   https://github.com/RVC-Boss/GPT-SoVITS/issues/1199
    #   https://github.com/libsndfile/libsndfile/issues/1023
    #   https://github.com/bastibe/python-soundfile/issues/396
    # Suggestion:
    #   Write the ogg from a thread with a larger stack, or split the audio into
    #   smaller segments to avoid the overflow.
    def handle_pack_ogg():
        with sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
            audio_file.write(data)

    import threading
    # See: https://docs.python.org/3/library/threading.html
    # The stack size of this thread is at least 32768.
    # If a stack overflow still occurs, just increase `stack_size`.
    # stack_size = n * 4096, where n should be a positive integer.
    # Here we chose n = 4096.
    stack_size = 4096 * 4096
    try:
        threading.stack_size(stack_size)
        pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
        pack_ogg_thread.start()
        pack_ogg_thread.join()
    except RuntimeError as e:
        # If changing the thread stack size is unsupported, a RuntimeError is raised.
        print("RuntimeError: {}".format(e))
        print("Changing the thread stack size is unsupported.")
    except ValueError as e:
        # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified.
        print("ValueError: {}".format(e))
        print("The specified stack size is invalid.")

    return audio_bytes


def pack_raw(audio_bytes, data, rate):
    audio_bytes.write(data.tobytes())
    return audio_bytes


def pack_wav(audio_bytes, rate):
    if is_int32:
        data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int32)
        wav_bytes = BytesIO()
        sf.write(wav_bytes, data, rate, format='WAV', subtype='PCM_32')
    else:
        data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int16)
        wav_bytes = BytesIO()
        sf.write(wav_bytes, data, rate, format='WAV')
    return wav_bytes


def pack_aac(audio_bytes, data, rate):
    if is_int32:
        pcm = 's32le'
        bit_rate = '256k'
    else:
        pcm = 's16le'
        bit_rate = '128k'
    process = subprocess.Popen([
        'ffmpeg',
        '-f', pcm,          # input: raw signed little-endian PCM (s16le / s32le)
        '-ar', str(rate),   # sampling rate
        '-ac', '1',         # mono
        '-i', 'pipe:0',     # read input from stdin
        '-c:a', 'aac',      # encode with AAC
        '-b:a', bit_rate,   # bit rate
        '-vn',              # no video
        '-f', 'adts',       # output an AAC ADTS stream
        'pipe:1'            # write output to stdout
    ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, _ = process.communicate(input=data.tobytes())
    audio_bytes.write(out)
    return audio_bytes


def read_clean_buffer(audio_bytes):
    audio_chunk = audio_bytes.getvalue()
    audio_bytes.truncate(0)
    audio_bytes.seek(0)
    return audio_bytes, audio_chunk


def cut_text(text, punc):
    punc_list = [p for p in punc if p in {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}]
    if len(punc_list) > 0:
        punds = r"[" + "".join(punc_list) + r"]"
        text = text.strip("\n")
        items = re.split(f"({punds})", text)
        mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
        # keep the text intact when it contains no punctuation, or none at the very end
        if len(items) % 2 == 1:
            mergeitems.append(items[-1])
        text = "\n".join(mergeitems)

    while "\n\n" in text:
        text = text.replace("\n\n", "\n")
    return text
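
# Usage sketch: each listed punctuation mark closes a segment, and a trailing
# unterminated segment is kept so no text is lost, e.g.
#   cut_text("一,二。三", ",。")  ->  "一,\n二。\n三"
# get_tts_wav() later splits on "\n" and synthesizes the segments one by one.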


def only_punc(text):
    return not any(t.isalnum() for t in text)  # isalnum() already covers isalpha()


splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }


def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k=15, top_p=0.6, temperature=0.6, speed=1, inp_refs=None, spk="default"):
    infer_sovits = speaker_list[spk].sovits
    vq_model = infer_sovits.vq_model
    hps = infer_sovits.hps

    infer_gpt = speaker_list[spk].gpt
    t2s_model = infer_gpt.t2s_model
    max_sec = infer_gpt.max_sec

    t0 = ttime()
    prompt_text = prompt_text.strip("\n")
    if prompt_text[-1] not in splits:
        prompt_text += "。" if prompt_language != "en" else "."
    text = text.strip("\n")
    dtype = torch.float16 if is_half == True else torch.float32
    zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
    with torch.no_grad():
        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
        wav16k = torch.from_numpy(wav16k)
        zero_wav_torch = torch.from_numpy(zero_wav)
        if is_half == True:
            wav16k = wav16k.half().to(device)
            zero_wav_torch = zero_wav_torch.half().to(device)
        else:
            wav16k = wav16k.to(device)
            zero_wav_torch = zero_wav_torch.to(device)
        wav16k = torch.cat([wav16k, zero_wav_torch])
        ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
        codes = vq_model.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
        prompt = prompt_semantic.unsqueeze(0).to(device)

        refers = []
        if inp_refs:
            for path in inp_refs:
                try:
                    refer = get_spepc(hps, path).to(dtype).to(device)
                    refers.append(refer)
                except Exception as e:
                    logger.error(e)
        if len(refers) == 0:
            refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]

    t1 = ttime()
    version = vq_model.version
    os.environ['version'] = version
    prompt_language = dict_language[prompt_language.lower()]
    text_language = dict_language[text_language.lower()]
    phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
    texts = text.split("\n")
    audio_bytes = BytesIO()
    for text in texts:
        # skip pure-punctuation segments to keep the reference audio from leaking into the output
        if only_punc(text):
            continue

        audio_opt = []
        if text[-1] not in splits:
            text += "。" if text_language != "en" else "."
        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
        bert = torch.cat([bert1, bert2], 1)

        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
        t2 = ttime()
        with torch.no_grad():
            pred_semantic, idx = t2s_model.model.infer_panel(
                all_phoneme_ids,
                all_phoneme_len,
                prompt,
                bert,
                # prompt_phone_len=ph_offset,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                early_stop_num=hz * max_sec)
            pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
        t3 = ttime()
        # try reconstruction without the prompt part
        audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
                                refers, speed=speed).detach().cpu().numpy()[0, 0]
        max_audio = np.abs(audio).max()
        if max_audio > 1:
            audio /= max_audio
        audio_opt.append(audio)
        audio_opt.append(zero_wav)
        t4 = ttime()
        if is_int32:
            audio_bytes = pack_audio(audio_bytes, (np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32), hps.data.sampling_rate)
        else:
            audio_bytes = pack_audio(audio_bytes, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16), hps.data.sampling_rate)
        # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
        if stream_mode == "normal":
            audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
            yield audio_chunk

    if not stream_mode == "normal":
        if media_type == "wav":
            audio_bytes = pack_wav(audio_bytes, hps.data.sampling_rate)
        yield audio_bytes.getvalue()


def handle_control(command):
    if command == "restart":
        os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
    elif command == "exit":
        os.kill(os.getpid(), signal.SIGTERM)
        exit(0)


def handle_change(path, text, language):
    if is_empty(path, text, language):
        return JSONResponse({"code": 400, "message": 'Provide at least one of the following parameters: "path", "text", "language"'}, status_code=400)

    if path is not None and path != "":
        default_refer.path = path
    if text is not None and text != "":
        default_refer.text = text
    if language is not None and language != "":
        default_refer.language = language

    logger.info(f"Current default reference audio path: {default_refer.path}")
    logger.info(f"Current default reference audio text: {default_refer.text}")
    logger.info(f"Current default reference audio language: {default_refer.language}")
    logger.info(f"is_ready: {default_refer.is_ready()}")

    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)


def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs):
    if (
            refer_wav_path == "" or refer_wav_path is None
            or prompt_text == "" or prompt_text is None
            or prompt_language == "" or prompt_language is None
    ):
        refer_wav_path, prompt_text, prompt_language = (
            default_refer.path,
            default_refer.text,
            default_refer.language,
        )
        if not default_refer.is_ready():
            return JSONResponse({"code": 400, "message": "No reference audio given and no default is configured"}, status_code=400)

    if cut_punc is None:
        text = cut_text(text, default_cut_punc)
    else:
        text = cut_text(text, cut_punc)

    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/" + media_type)


# --------------------------------
# Initialization
# --------------------------------
dict_language = {
    "中文": "all_zh",
    "粤语": "all_yue",
    "英文": "en",
    "日文": "all_ja",
    "韩文": "all_ko",
    "中英混合": "zh",
    "粤英混合": "yue",
    "日英混合": "ja",
    "韩英混合": "ko",
    "多语种混合": "auto",  # split multilingual input and detect each segment's language
    "多语种混合(粤语)": "auto_yue",
    "all_zh": "all_zh",
    "all_yue": "all_yue",
    "en": "en",
    "all_ja": "all_ja",
    "all_ko": "all_ko",
    "zh": "zh",
    "yue": "yue",
    "ja": "ja",
    "ko": "ko",
    "auto": "auto",
    "auto_yue": "auto_yue",
}

# logger
logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
logger = logging.getLogger('uvicorn')

# load config
g_config = global_config.Config()

# parse arguments
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS model path")
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT model path")
parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="default reference audio path")
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="default reference audio text")
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="default reference audio language")
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="override config.is_half to False, use full precision")
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="override config.is_half to True, use half precision")
# the precision flags are plain booleans, e.g. `python ./api.py -fp ...`
# yields full_precision==True, half_precision==False
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="streaming mode, close / normal / keepalive")
parser.add_argument("-mt", "--media_type", type=str, default="wav", help="audio encoding, wav / ogg / aac")
parser.add_argument("-st", "--sub_type", type=str, default="int16", help="audio sample type, int16 / int32")
parser.add_argument("-cp", "--cut_punc", type=str, default="", help="text-splitting punctuation, any of ,.;?!、,。?!;:…")
# e.g. split on the common sentence-ending marks with `python ./api.py -cp ".?!。?!"`
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="override config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="override config.bert_path")

args = parser.parse_args()
sovits_path = args.sovits_path
gpt_path = args.gpt_path
device = args.device
port = args.port
host = args.bind_addr
cnhubert_base_path = args.hubert_path
bert_path = args.bert_path
default_cut_punc = args.cut_punc

# apply the parsed arguments
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)

# model path checks
if sovits_path == "":
    sovits_path = g_config.pretrained_sovits_path
    logger.warning(f"No SoVITS model path given, falling back to: {sovits_path}")
if gpt_path == "":
    gpt_path = g_config.pretrained_gpt_path
    logger.warning(f"No GPT model path given, falling back to: {gpt_path}")

# default reference audio, used when the caller provides none (or an incomplete set)
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
    default_refer.path, default_refer.text, default_refer.language = "", "", ""
    logger.info("No default reference audio given")
else:
    logger.info(f"Default reference audio path: {default_refer.path}")
    logger.info(f"Default reference audio text: {default_refer.text}")
    logger.info(f"Default reference audio language: {default_refer.language}")

# precision
is_half = g_config.is_half
if args.full_precision:
    is_half = False
if args.half_precision:
    is_half = True
if args.full_precision and args.half_precision:
    is_half = g_config.is_half  # contradictory flags: fall back to the config value
logger.info(f"Half precision: {is_half}")

# streaming mode
if args.stream_mode.lower() in ["normal", "n"]:
    stream_mode = "normal"
    logger.info("Streaming mode enabled")
else:
    stream_mode = "close"

# audio encoding
if args.media_type.lower() in ["aac", "ogg"]:
    media_type = args.media_type.lower()
elif stream_mode == "close":
    media_type = "wav"
else:
    media_type = "ogg"
logger.info(f"Encoding format: {media_type}")

# audio sample type
if args.sub_type.lower() == 'int32':
    is_int32 = True
    logger.info("Sample type: int32")
else:
    is_int32 = False
    logger.info("Sample type: int16")

# initialize the models
cnhubert.cnhubert_base_path = cnhubert_base_path
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
ssl_model = cnhubert.get_model()
if is_half:
    bert_model = bert_model.half().to(device)
    ssl_model = ssl_model.half().to(device)
else:
    bert_model = bert_model.to(device)
    ssl_model = ssl_model.to(device)
change_gpt_sovits_weights(gpt_path=gpt_path, sovits_path=sovits_path)


# --------------------------------
# API
# --------------------------------
app = FastAPI()


@app.post("/set_model")
async def set_model(request: Request):
    json_post_raw = await request.json()
    return change_gpt_sovits_weights(
        gpt_path=json_post_raw.get("gpt_model_path"),
        sovits_path=json_post_raw.get("sovits_model_path")
    )


@app.get("/set_model")
async def set_model(
        gpt_model_path: str = None,
        sovits_model_path: str = None,
):
    return change_gpt_sovits_weights(gpt_path=gpt_model_path, sovits_path=sovits_model_path)


@app.post("/control")
async def control(request: Request):
    json_post_raw = await request.json()
    return handle_control(json_post_raw.get("command"))


@app.get("/control")
async def control(command: str = None):
    return handle_control(command)


@app.post("/change_refer")
async def change_refer(request: Request):
    json_post_raw = await request.json()
    return handle_change(
        json_post_raw.get("refer_wav_path"),
        json_post_raw.get("prompt_text"),
        json_post_raw.get("prompt_language")
    )


@app.get("/change_refer")
async def change_refer(
        refer_wav_path: str = None,
        prompt_text: str = None,
        prompt_language: str = None
):
    return handle_change(refer_wav_path, prompt_text, prompt_language)


@app.post("/")
async def tts_endpoint(request: Request):
    json_post_raw = await request.json()
    return handle(
        json_post_raw.get("refer_wav_path"),
        json_post_raw.get("prompt_text"),
        json_post_raw.get("prompt_language"),
        json_post_raw.get("text"),
        json_post_raw.get("text_language"),
        json_post_raw.get("cut_punc"),
        json_post_raw.get("top_k", 15),
        json_post_raw.get("top_p", 1.0),
        json_post_raw.get("temperature", 1.0),
        json_post_raw.get("speed", 1.0),
        json_post_raw.get("inp_refs", [])
    )


@app.get("/")
async def tts_endpoint(
        refer_wav_path: str = None,
        prompt_text: str = None,
        prompt_language: str = None,
        text: str = None,
        text_language: str = None,
        cut_punc: str = None,
        top_k: int = 15,
        top_p: float = 1.0,
        temperature: float = 1.0,
        speed: float = 1.0,
        inp_refs: list = Query(default=[])
):
    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs)


if __name__ == "__main__":
    uvicorn.run(app, host=host, port=port, workers=1)
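
# Quick smoke test (a sketch; assumes the defaults above, a default reference
# supplied via -dr/-dt/-dl, and "out.wav" as an example output name):
#   curl "http://127.0.0.1:9880/?text=先帝创业未半而中道崩殂。&text_language=zh" -o out.wav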