finetune.sh

# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

workspace=$(pwd)

# which GPU(s) to use for training or finetuning
export CUDA_VISIBLE_DEVICES="0,1"
gpu_num=$(echo "$CUDA_VISIBLE_DEVICES" | awk -F "," '{print NF}')
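# e.g. CUDA_VISIBLE_DEVICES="0,1" gives gpu_num=2 (awk counts the comma-separated fields)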
# model_name from the model hub, or model_dir as a local path
## option 1: download the model automatically
model_name_or_model_dir="iic/SenseVoiceCTC"

## option 2: download the model via git
#local_path_root=${workspace}/modelscope_models
#mkdir -p ${local_path_root}/${model_name_or_model_dir}
#git clone https://www.modelscope.cn/${model_name_or_model_dir}.git ${local_path_root}/${model_name_or_model_dir}
#model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}

# data files: training and validation sets in jsonl format
train_data=${workspace}/data/train_example.jsonl
val_data=${workspace}/data/val_example.jsonl
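# each jsonl line describes one utterance; the fields below follow FunASR's usual
# scp2jsonl convention, shown as an illustrative sketch (key and paths are hypothetical):
# {"key": "utt_0001", "source": "/data/wav/utt_0001.wav", "source_len": 2016,
#  "target": "transcript text", "target_len": 9}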
# experiment output dir
output_dir="./outputs"
log_file="${output_dir}/log.txt"
deepspeed_config=${workspace}/deepspeed_conf/ds_stage1.json
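# note: this DeepSpeed config only takes effect when ++train_conf.use_deepspeed=true
# is passed below; with use_deepspeed=false the path is effectively unused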
mkdir -p ${output_dir}
echo "log_file: ${log_file}"
DISTRIBUTED_ARGS="
  --nnodes ${WORLD_SIZE:-1} \
  --nproc_per_node $gpu_num \
  --node_rank ${RANK:-0} \
  --master_addr ${MASTER_ADDR:-127.0.0.1} \
  --master_port ${MASTER_PORT:-26669}
"
echo $DISTRIBUTED_ARGS

# funasr trainer path: train_ds.py ships next to the funasr CLI entrypoint
train_tool=$(dirname "$(which funasr)")/train_ds.py
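# a few of the hydra-style ++ overrides below, as commonly interpreted in FunASR
# (verify against the docs of your installed version):
#   batch_type="token" with batch_size=6000: batches are packed up to ~6000 tokens
#   validate_interval / save_checkpoint_interval: counted in training steps
#   keep_nbest_models=20, avg_nbest_model=10: keep the 20 best checkpoints and
#   average the top 10 for the final model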
torchrun $DISTRIBUTED_ARGS \
  ${train_tool} \
  ++model="${model_name_or_model_dir}" \
  ++trust_remote_code=true \
  ++train_data_set_list="${train_data}" \
  ++valid_data_set_list="${val_data}" \
  ++dataset_conf.data_split_num=1 \
  ++dataset_conf.batch_sampler="BatchSampler" \
  ++dataset_conf.batch_size=6000 \
  ++dataset_conf.sort_size=1024 \
  ++dataset_conf.batch_type="token" \
  ++dataset_conf.num_workers=4 \
  ++train_conf.max_epoch=50 \
  ++train_conf.log_interval=1 \
  ++train_conf.resume=true \
  ++train_conf.validate_interval=2000 \
  ++train_conf.save_checkpoint_interval=2000 \
  ++train_conf.keep_nbest_models=20 \
  ++train_conf.avg_nbest_model=10 \
  ++train_conf.use_deepspeed=false \
  ++train_conf.deepspeed_config=${deepspeed_config} \
  ++optim_conf.lr=0.0002 \
  ++output_dir="${output_dir}" &> ${log_file}
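# usage: bash finetune.sh
# all trainer output is redirected to ${log_file}; follow progress with
#   tail -f outputs/log.txt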