Attempt to fix the fp16 overflow issue
- README.md +11 -15
- convert_rknn.py +152 -24
- sense-voice-encoder.rknn +2 -2
- sensevoice_rknn.py +81 -75
README.md
CHANGED
@@ -28,12 +28,10 @@ SenseVoice is an audio foundation model with audio understanding capabilities, including speech recognition
 2. Install dependencies
 
 ```bash
-pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml numpy<2
+pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml "numpy<2" rknn-toolkit-lite2
 ```
 
-
-
-3. Run
+1. Run
 
 ```bash
 python ./sensevoice_rknn.py --audio_file output.wav
@@ -47,7 +45,7 @@ ffmpeg -i input.mp3 -f wav -acodec pcm_s16le -ac 1 -ar 16000 output.wav
 
 ## RKNN Model Conversion
 
-You need to install rknn-toolkit2 in advance
+You need to install rknn-toolkit2 in advance. The tested working version is 2.3.3a25, which can be downloaded from https://console.zbox.filez.com/l/I00fc3 (password: "rknn").
 
 1. Download or convert the ONNX model
 
@@ -58,13 +56,13 @@ ffmpeg -i input.mp3 -f wav -acodec pcm_s16le -ac 1 -ar 16000 output.wav
 
 2. Convert to the RKNN model
 ```bash
-python convert_rknn.py
+python convert_rknn.py ./sense-voice-encoder.onnx
 ```
 
 ## Known Issues
 
-- When using fp16 inference with RKNN2, overflow may occur, resulting in inf values. You can try modifying the scaling ratio of the input data to resolve this.
-Set `SPEECH_SCALE` to a smaller value in `sensevoice_rknn.py`.
+- ~~When using fp16 inference with RKNN2, overflow may occur, resulting in inf values. You can try modifying the scaling ratio of the input data to resolve this.
+Set `SPEECH_SCALE` to a smaller value in `sensevoice_rknn.py`.~~ (This should now be resolved by inserting scaling operators inside the model.)
 
 ## References
 - [FunAudioLLM/SenseVoiceSmall](https://huggingface.co/FunAudioLLM/SenseVoiceSmall)
@@ -88,11 +86,9 @@ Currently, SenseVoice-small supports multilingual speech recognition, emotion re
 2. Install dependencies
 
 ```bash
-pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml numpy<2
+pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml "numpy<2" rknn-toolkit-lite2
 ```
 
-You also need to manually install rknn-toolkit2-lite2.
-
 3. Run
 
 ```bash
@@ -107,7 +103,7 @@ ffmpeg -i input.mp3 -f wav -acodec pcm_s16le -ac 1 -ar 16000 output.wav
 
 ## RKNN Model Conversion
 
-You need to install rknn-toolkit2
+You need to install rknn-toolkit2 in advance. The tested working version is 2.3.3a25, which can be downloaded from https://console.zbox.filez.com/l/I00fc3 (password: "rknn").
 
 1. Download or convert the ONNX model
 
@@ -118,13 +114,13 @@ The model file should be named 'sense-voice-encoder.onnx' and placed in the same
 
 2. Convert to RKNN model
 ```bash
-python convert_rknn.py
+python convert_rknn.py ./sense-voice-encoder.onnx
 ```
 
 ## Known Issues
 
-- When using fp16 inference with RKNN2, overflow may occur, resulting in inf values. You can try modifying the scaling ratio of the input data to resolve this.
-Set `SPEECH_SCALE` to a smaller value in `sensevoice_rknn.py`.
+- ~~When using fp16 inference with RKNN2, overflow may occur, resulting in inf values. You can try modifying the scaling ratio of the input data to resolve this.
+Set `SPEECH_SCALE` to a smaller value in `sensevoice_rknn.py`.~~ (This issue should now be resolved by inserting scaling operators inside the model.)
 
 ## References
 - [FunAudioLLM/SenseVoiceSmall](https://huggingface.co/FunAudioLLM/SenseVoiceSmall)
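The Known Issues entry above comes down to float16's limited range: anything above roughly 65504 saturates to inf. A minimal numpy sketch (not part of the repository) of the failure mode and of why pre-scaling, whether via `SPEECH_SCALE` or via Div operators inside the graph, avoids it:

```python
import numpy as np

x = np.array([70000.0, 1200.0], dtype=np.float32)

# Casting straight to fp16 overflows: anything above ~65504 becomes inf.
print(x.astype(np.float16))                                   # [  inf 1200.]

# Pre-scaling keeps the values finite; the factor can be undone afterwards,
# at the cost of ordinary fp16 rounding error.
print((x / 16).astype(np.float16).astype(np.float32) * 16)    # [70016.  1200.]
```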
convert_rknn.py
CHANGED
@@ -2,6 +2,8 @@
 # coding: utf-8
 
 import os
+import re
+from typing import Optional, Set
 from rknn.api import RKNN
 from math import exp
 from sys import exit
@@ -18,28 +20,158 @@ os.chdir(os.path.dirname(os.path.abspath(__file__)))
 
 speech_length = 171
 
-def convert_encoder():
+def _remove_file(path: str, *, keep: Optional[Set[str]] = None) -> None:
+    if not path:
+        return
+    keep_paths: Set[str] = {os.path.abspath(item) for item in keep} if keep else set()
+    normalized = os.path.abspath(path)
+    if keep_paths and normalized in keep_paths:
+        return
+    if not os.path.exists(normalized):
+        return
+    try:
+        os.remove(normalized)
+        print(f'cleaned temp model: {normalized}')
+    except OSError as err:
+        print(f'warning: failed to remove {normalized}: {err}')
+
+def _with_suffix(path: str, suffix: str) -> str:
+    stem, ext = os.path.splitext(path)
+    return f"{stem}{suffix}{ext}"
+
+def _sanitize_name(name: str) -> str:
+    return re.sub(r'[^0-9A-Za-z_]', '_', name)
+
+def _insert_div_node(model: onnx.ModelProto, tensor_name: str, divisor: float = 16.0) -> bool:
+    graph = model.graph
+
+    for node in graph.node:
+        if node.op_type == 'Div' and tensor_name in node.output:
+            return False
+
+    producer_index = None
+    output_index = None
+    for idx, node in enumerate(graph.node):
+        for out_idx, output in enumerate(node.output):
+            if output == tensor_name:
+                producer_index = idx
+                output_index = out_idx
+                producer_node = node
+                break
+        if producer_index is not None:
+            break
+
+    if producer_index is None:
+        raise RuntimeError(f"Producer node for tensor {tensor_name} not found.")
+
+    pre_div_output = f"{tensor_name}_pre_div"
+    producer_node.output[output_index] = pre_div_output
+
+    sanitized = _sanitize_name(tensor_name)
+    const_output = f"{sanitized}_div_const"
+    const_node_name = f"{sanitized}_DivConst"
+    div_node_name = f"{sanitized}_Div"
+
+    const_tensor = onnx.helper.make_tensor(
+        name=f"{const_node_name}_value",
+        data_type=onnx.TensorProto.FLOAT,
+        dims=[],
+        vals=[divisor],
+    )
+
+    const_node = onnx.helper.make_node(
+        'Constant',
+        inputs=[],
+        outputs=[const_output],
+        value=const_tensor,
+        name=const_node_name,
+    )
+
+    div_node = onnx.helper.make_node(
+        'Div',
+        inputs=[pre_div_output, const_output],
+        outputs=[tensor_name],
+        name=div_node_name,
+    )
+
+    graph.node.insert(producer_index + 1, const_node)
+    graph.node.insert(producer_index + 2, div_node)
+    return True
+
+def _scale_initializer(model: onnx.ModelProto, initializer_name: str, divisor: float = 16.0) -> bool:
+    for idx, initializer in enumerate(model.graph.initializer):
+        if initializer.name == initializer_name:
+            data = onh.to_array(initializer).astype(np.float32, copy=False)
+            scaled = data / divisor
+            model.graph.initializer[idx].CopyFrom(onh.from_array(scaled, name=initializer_name))
+            return True
+    return False
+
+def convert_encoder(model_path: str):
     rknn = RKNN(verbose=True)
 
-    ONNX_MODEL=
-
-
-
+    ONNX_MODEL = os.path.abspath(model_path)
+    if not os.path.isfile(ONNX_MODEL):
+        print(f'Model file not found: {model_path}')
+        exit(1)
+    if not ONNX_MODEL.lower().endswith('.onnx'):
+        print(f'Model file must be an ONNX file: {model_path}')
+        exit(1)
+
+    RKNN_MODEL = os.path.splitext(ONNX_MODEL)[0] + ".rknn"
+    DATASET = "dataset.txt"
+    QUANTIZE = False
+    original_model = ONNX_MODEL
+    preserve_files: Set[str] = {original_model}
+
 
+    print('--> Patching model to avoid overflow issue')
+    base_model = onnx.load(ONNX_MODEL)
+    modified = False
+    for layer_idx in range(48, 49):
+        for target in [
+            f'/encoders.{layer_idx}/feed_forward/activation/Relu_output_0',
+            f'/encoders.{layer_idx}/norm2/Cast_output_0',
+        ]:
+            modified |= _insert_div_node(base_model, target, divisor=2.0)
+    bias_scaled = False
+    if modified:
+        for layer_idx in range(48, 49):
+            bias_scaled |= _scale_initializer(base_model, f'model.encoders.{layer_idx}.feed_forward.w_2.bias', divisor=2.0)
+        div_model_path = _with_suffix(ONNX_MODEL, "_div")
+        onnx.save(base_model, div_model_path)
+        if os.path.exists(div_model_path):
+            previous_model = ONNX_MODEL
+            ONNX_MODEL = div_model_path
+            _remove_file(previous_model, keep=preserve_files)
+    if modified:
+        if bias_scaled:
+            print('done (created div-adjusted model and scaled bias)')
+        else:
+            print('done (created div-adjusted model; bias initializer not found)')
+    else:
+        print('done (div nodes already present)')
+
     # A nasty surprise right at the start: RKNN errors out inside this subgraph during its first constant-folding pass, so the subgraph has to be pulled out and run separately first
     # and its output is then saved and fed to rknn
-
-
+    extract_model_path = os.path.join(os.getcwd(), "extract_model.onnx")
+    onnx.utils.extract_model(ONNX_MODEL, extract_model_path, ['speech_lengths'], ['/make_pad_mask/Cast_2_output_0'])
+    sess = ort.InferenceSession(extract_model_path, providers=['CPUExecutionProvider'])
     extract_result = sess.run(None, {"speech_lengths": np.array([speech_length], dtype=np.int64)})[0]
+    _remove_file(extract_model_path)
 
-    # Remove the redundant transpose at the end of the model; speed improves from 365ms to
+    # Remove the redundant transpose at the end of the model; speed improves from 365ms to 259ms
+    edited_model_path = _with_suffix(ONNX_MODEL, "_edited")
     ret = onnx_edit(model = ONNX_MODEL,
-        export_path =
+        export_path = edited_model_path,
     # # 1, len, 25055 -> 1, 25055, 1, len # this one is broken, I give up,
-
-        outputs_transform = {'encoder_out': 'a,b,c->a,c,b'},
+        outputs_transform = {'encoder_out': 'a,b,c->a,c,1,b'},
+        # outputs_transform = {'encoder_out': 'a,b,c->a,c,b'},
     )
-
+    if os.path.exists(edited_model_path):
+        previous_model = ONNX_MODEL
+        ONNX_MODEL = edited_model_path
+        _remove_file(previous_model, keep=preserve_files)
 
     # pre-process config
     print('--> Config model')
@@ -48,8 +180,9 @@ def convert_encoder():
 
     # Load ONNX model
     print("--> Loading model")
+    current_model_path = ONNX_MODEL
     ret = rknn.load_onnx(
-        model=
+        model=current_model_path,
         inputs=["speech", "/make_pad_mask/Cast_2_output_0"],
         input_size_list=[[1, speech_length, 560], [extract_result.shape[0], extract_result.shape[1]]],
         input_initial_val=[None, extract_result],
@@ -60,6 +193,7 @@
         print('Load model failed!')
         exit(ret)
     print('done')
+    _remove_file(current_model_path, keep=preserve_files)
 
     # Build model
     print('--> Building model')
@@ -76,20 +210,14 @@
         print('Export RKNN model failed!')
        exit(ret)
     print('done')
+    # Accuracy analysis (optional)
+    # rknn.accuracy_analysis(inputs=["input_content.npy"], target="rk3588", device_id=None)
 
-# usage: python convert_rknn.py
+# usage: python convert_rknn.py path/to/model.onnx
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("
+    parser.add_argument("model_path", type=str, help="path to source ONNX model")
     args = parser.parse_args()
-    if args.model is None:
-        args.model = "all"
 
-
-        convert_encoder()
-    elif args.model == "all":
-        convert_encoder()
-    else:
-        print(f"Unknown model: {args.model}")
-        exit(1)
+    convert_encoder(args.model_path)
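The overflow patch added to `convert_rknn.py` inserts a `Div` by 2 after the layer-48 feed-forward `Relu` output and halves the matching `w_2` bias. As I read it, this relies on the standard linear-layer identity W(x/s) + b/s = (Wx + b)/s: every intermediate activation becomes s times smaller (safely below the fp16 maximum) while the layer output only changes by a known constant factor. A small stand-alone numpy sketch with made-up shapes and values, not code from this repository:

```python
import numpy as np

rng = np.random.default_rng(0)
relu_out = rng.uniform(0, 200, size=(4, 8)).astype(np.float32)   # stand-in for Relu_output_0
W2 = rng.normal(size=(8, 3)).astype(np.float32)                  # stand-in for feed_forward.w_2 weight
b2 = rng.normal(size=(3,)).astype(np.float32)                    # stand-in for feed_forward.w_2 bias

s = 2.0
original = relu_out @ W2 + b2              # un-patched feed-forward output
patched = (relu_out / s) @ W2 + (b2 / s)   # Div node on the activation + scaled bias

# The patched branch equals the original divided by s, so no information is lost,
# but every intermediate tensor is s times smaller and further from the fp16 limit.
print(np.allclose(patched, original / s))  # True
```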
sense-voice-encoder.rknn
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:198e9c843ed0ffd8500f3b045e47ef5cbc9a4fc36401f6d51c853dacd8129d98
+size 484090175
sensevoice_rknn.py
CHANGED
@@ -28,7 +28,7 @@ from rknnlite.api.rknn_lite import RKNNLite
 
 RKNN_INPUT_LEN = 171
 
-SPEECH_SCALE = 1
+SPEECH_SCALE = 1
 
 class VadOrtInferRuntimeSession:
     def __init__(self, config, root_dir: Path):
@@ -72,7 +72,7 @@ class VadOrtInferRuntimeSession:
         )
 
     def __call__(
-        self, input_content
+        self, input_content
     ) -> np.ndarray:
         if isinstance(input_content, list):
             input_dict = {
@@ -97,10 +97,10 @@ class VadOrtInferRuntimeSession:
     ):
         return [v.name for v in self.session.get_outputs()]
 
-    def get_character_list(self, key
+    def get_character_list(self, key = "character"):
         return self.meta_dict[key].splitlines()
 
-    def have_key(self, key
+    def have_key(self, key = "character") -> bool:
         self.meta_dict = self.session.get_modelmeta().custom_metadata_map
         if key in self.meta_dict.keys():
             return True
@@ -196,10 +196,10 @@ class OrtInferRuntimeSession:
     ):
         return [v.name for v in self.session.get_outputs()]
 
-    def get_character_list(self, key
+    def get_character_list(self, key = "character"):
         return self.meta_dict[key].splitlines()
 
-    def have_key(self, key
+    def have_key(self, key = "character") -> bool:
         self.meta_dict = self.session.get_modelmeta().custom_metadata_map
         if key in self.meta_dict.keys():
             return True
@@ -250,7 +250,7 @@ class SenseVoiceInferenceSession:
         self.sp = spm.SentencePieceProcessor()
         self.sp.load(bpe_model_file)
 
-    def __call__(self, speech, language
+    def __call__(self, speech, language, use_itn: bool) -> np.ndarray:
         language_query = self.embedding[[[language]]]
 
         # 14 means with itn, 15 means without itn
@@ -274,6 +274,7 @@ class SenseVoiceInferenceSession:
         input_content = np.pad(input_content, ((0, 0), (0, RKNN_INPUT_LEN - input_content.shape[1]), (0, 0)))
         print("padded shape:", input_content.shape)
         start_time = time.time()
+        np.save("input_content.npy",input_content)
         encoder_out = self.encoder.inference(inputs=[input_content])[0]
         end_time = time.time()
         print(f"encoder inference time: {end_time - start_time:.2f} seconds")
@@ -308,14 +309,14 @@ class WavFrontend:
 
     def __init__(
         self,
-        cmvn_file
-        fs
-        window
-        n_mels
-        frame_length
-        frame_shift
-        lfr_m
-        lfr_n
+        cmvn_file = None,
+        fs = 16000,
+        window = "hamming",
+        n_mels = 80,
+        frame_length = 25,
+        frame_shift = 10,
+        lfr_m = 7,
+        lfr_n = 6,
         dither: float = 0,
         **kwargs,
     ) -> None:
@@ -367,7 +368,7 @@ class WavFrontend:
         feat_len = np.array(feat.shape[0]).astype(np.int32)
         return feat, feat_len
 
-    def load_audio(self, filename
+    def load_audio(self, filename) -> Tuple[np.ndarray, int]:
         data, sample_rate = sf.read(
             filename,
             always_2d=True,
@@ -383,7 +384,7 @@ class WavFrontend:
         return samples, sample_rate
 
     @staticmethod
-    def apply_lfr(inputs: np.ndarray, lfr_m
+    def apply_lfr(inputs: np.ndarray, lfr_m, lfr_n) -> np.ndarray:
         LFR_inputs = []
 
         T = inputs.shape[0]
@@ -417,7 +418,7 @@ class WavFrontend:
         inputs = (inputs + means) * vars
         return inputs
 
-    def get_features(self, inputs: Union[str, np.ndarray])
+    def get_features(self, inputs: Union[str, np.ndarray]):
         if isinstance(inputs, str):
             inputs, _ = self.load_audio(inputs)
 
@@ -504,35 +505,35 @@ class VadDetectMode(Enum):
 class VADXOptions:
     def __init__(
         self,
-        sample_rate
-        detect_mode
-        snr_mode
-        max_end_silence_time
-        max_start_silence_time
+        sample_rate = 16000,
+        detect_mode = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+        snr_mode = 0,
+        max_end_silence_time = 800,
+        max_start_silence_time = 3000,
         do_start_point_detection: bool = True,
         do_end_point_detection: bool = True,
-        window_size_ms
-        sil_to_speech_time_thres
-        speech_to_sil_time_thres
+        window_size_ms = 200,
+        sil_to_speech_time_thres = 150,
+        speech_to_sil_time_thres = 150,
         speech_2_noise_ratio: float = 1.0,
-        do_extend
-        lookback_time_start_point
-        lookahead_time_end_point
-        max_single_segment_time
-        nn_eval_block_size
-        dcd_block_size
-        snr_thres
-        noise_frame_num_used_for_snr
-        decibel_thres
+        do_extend = 1,
+        lookback_time_start_point = 200,
+        lookahead_time_end_point = 100,
+        max_single_segment_time = 60000,
+        nn_eval_block_size = 8,
+        dcd_block_size = 4,
+        snr_thres = -100.0,
+        noise_frame_num_used_for_snr = 100,
+        decibel_thres = -100.0,
         speech_noise_thres: float = 0.6,
         fe_prior_thres: float = 1e-4,
-        silence_pdf_num
+        silence_pdf_num = 1,
         sil_pdf_ids: List[int] = [0],
         speech_noise_thresh_low: float = -0.1,
         speech_noise_thresh_high: float = 0.3,
         output_frame_probs: bool = False,
-        frame_in_ms
-        frame_length_ms
+        frame_in_ms = 10,
+        frame_length_ms = 25,
     ):
         self.sample_rate = sample_rate
         self.detect_mode = detect_mode
@@ -595,10 +596,10 @@ class E2EVadFrameProb(object):
 class WindowDetector(object):
     def __init__(
         self,
-        window_size_ms
-        sil_to_speech_time
-        speech_to_sil_time
-        frame_size_ms
+        window_size_ms,
+        sil_to_speech_time,
+        speech_to_sil_time,
+        frame_size_ms,
     ):
         self.window_size_ms = window_size_ms
         self.sil_to_speech_time = sil_to_speech_time
@@ -633,7 +634,7 @@ class WindowDetector(object):
         return int(self.win_size_frame)
 
     def detect_one_frame(
-        self, frameState: FrameState, frame_count
+        self, frameState: FrameState, frame_count
     ) -> AudioChangeState:
         cur_frame_state = FrameState.kFrameStateSil
         if frameState == FrameState.kFrameStateSpeech:
@@ -773,7 +774,7 @@ class E2EVadModel:
 
         return scores[1:]
 
-    def pop_data_buf_till_frame(self, frame_idx
+    def pop_data_buf_till_frame(self, frame_idx) -> None:  # need check again
         while self.data_buf_start_frame < frame_idx:
             if self.data_buf_size >= int(
                 self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
@@ -787,8 +788,8 @@ class E2EVadModel:
 
     def pop_data_to_output_buf(
         self,
-        start_frm
-        frm_cnt
+        start_frm,
+        frm_cnt,
         first_frm_is_start_point: bool,
         last_frm_is_end_point: bool,
         end_point_is_sent_end: bool,
@@ -849,18 +850,18 @@ class E2EVadModel:
         if last_frm_is_end_point:
             cur_seg.contain_seg_end_point = True
 
-    def on_silence_detected(self, valid_frame
+    def on_silence_detected(self, valid_frame):
         self.lastest_confirmed_silence_frame = valid_frame
         if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
             self.pop_data_buf_till_frame(valid_frame)
         # silence_detected_callback_
         # pass
 
-    def on_voice_detected(self, valid_frame
+    def on_voice_detected(self, valid_frame) -> None:
         self.latest_confirmed_speech_frame = valid_frame
         self.pop_data_to_output_buf(valid_frame, 1, False, False, False)
 
-    def on_voice_start(self, start_frame
+    def on_voice_start(self, start_frame, fake_result: bool = False) -> None:
         if self.vad_opts.do_start_point_detection:
             pass
         if self.confirmed_start_frame != -1:
@@ -878,7 +879,7 @@ class E2EVadModel:
         )
 
     def on_voice_end(
-        self, end_frame
+        self, end_frame, fake_result: bool, is_last_frame: bool
     ) -> None:
         for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
             self.on_voice_detected(t)
@@ -896,7 +897,7 @@ class E2EVadModel:
         self.number_end_time_detected += 1
 
     def maybe_on_voice_end_last_frame(
-        self, is_final_frame: bool, cur_frm_idx
+        self, is_final_frame: bool, cur_frm_idx
    ) -> None:
         if is_final_frame:
             self.on_voice_end(cur_frm_idx, False, True)
@@ -913,7 +914,7 @@ class E2EVadModel:
         )
         return vad_latency
 
-    def get_frame_state(self, t
+    def get_frame_state(self, t) -> FrameState:
         frame_state = FrameState.kFrameStateInvalid
         cur_decibel = self.decibel[t - self.decibel_offset]
         cur_snr = cur_decibel - self.noise_average_decibel
@@ -1010,7 +1011,7 @@ class E2EVadModel:
         waveform: np.ndarray,
         in_cache: list = None,
         is_final: bool = False,
-        max_end_sil
+        max_end_sil = 800,
     ) -> Tuple[List[List[List[int]]], Dict[str, np.ndarray]]:
         feats = [feats]
         if in_cache is None:
@@ -1059,7 +1060,7 @@ class E2EVadModel:
         waveform: np.ndarray,
         in_cache: list = None,
         is_final: bool = False,
-        max_end_sil
+        max_end_sil = 800,
     ):
         feats = [feats]
         states = []
@@ -1116,7 +1117,7 @@ class E2EVadModel:
         return 0
 
     def detect_one_frame(
-        self, cur_frm_state: FrameState, cur_frm_idx
+        self, cur_frm_state: FrameState, cur_frm_idx, is_final_frame: bool
     ) -> None:
         tmp_cur_frm_state = FrameState.kFrameStateInvalid
         if cur_frm_state == FrameState.kFrameStateSpeech:
@@ -1267,7 +1268,7 @@ class E2EVadModel:
 
 
 class FSMNVad(object):
-    def __init__(self, config_dir
+    def __init__(self, config_dir):
         config_dir = Path(config_dir)
         self.config = read_yaml(config_dir / "fsmn-config.yaml")
         self.frontend = WavFrontend(
@@ -1332,16 +1333,16 @@ languages = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeec
 formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 logging.basicConfig(format=formatter, level=logging.INFO)
 
-def main():
+def parse_args():
     arg_parser = argparse.ArgumentParser(description="Sense Voice")
     arg_parser.add_argument("-a", "--audio_file", required=True, type=str, help="Model")
     download_model_path = os.path.dirname(__file__)
     arg_parser.add_argument(
         "-dp",
-        "--download_path",
+        "--download_path",
         default=download_model_path,
         type=str,
-        help="dir path of resource downloaded"
+        help="dir path of resource downloaded"
     )
     arg_parser.add_argument("-d", "--device", default=-1, type=int, help="Device")
     arg_parser.add_argument(
@@ -1351,35 +1352,36 @@ def main():
         "-l",
         "--language",
         choices=languages.keys(),
-        default="auto",
+        default="auto",
         type=str,
-        help="Language"
+        help="Language"
     )
     arg_parser.add_argument("--use_itn", action="store_true", help="Use ITN")
-
+    return arg_parser.parse_args()
 
-
+def main(audio_file, download_path, device, num_threads, language, use_itn):
+    front = WavFrontend(os.path.join(download_path, "am.mvn"))
 
     model = SenseVoiceInferenceSession(
-        os.path.join(
+        os.path.join(download_path, "embedding.npy"),
         os.path.join(
-
+            download_path,
             "sense-voice-encoder.rknn",
         ),
-        os.path.join(
-
-
+        os.path.join(download_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"),
+        device,
+        num_threads,
     )
     waveform, _sample_rate = sf.read(
-
+        audio_file,
         dtype="float32",
         always_2d=True
     )
 
-    logging.info(f"Audio {
-    # load vad model
+    logging.info(f"Audio {audio_file} is {len(waveform) / _sample_rate} seconds, {waveform.shape[1]} channel")
+    # load vad model
     start = time.time()
-    vad = FSMNVad(
+    vad = FSMNVad(download_path)
     for channel_id, channel_data in enumerate(waveform.T):
         segments = vad.segments_offline(channel_data)
         results = ""
@@ -1387,16 +1389,20 @@ def main():
             audio_feats = front.get_features(channel_data[part[0] * 16 : part[1] * 16])
             asr_result = model(
                 audio_feats[None, ...],
-                language=languages[
-                use_itn=
+                language=languages[language],
+                use_itn=use_itn,
            )
             logging.info(f"[Channel {channel_id}] [{part[0] / 1000}s - {part[1] / 1000}s] {asr_result}")
+            results += asr_result
+        logging.info(f"Results: {results}")
         vad.vad.all_reset_detection()
     decoding_time = time.time() - start
     logging.info(f"Decoder audio takes {decoding_time} seconds")
     logging.info(f"The RTF is {decoding_time/(waveform.shape[1] * len(waveform) / _sample_rate)}.")
+    return results
 
 
 if __name__ == "__main__":
-
+    args = parse_args()
+    main(args.audio_file, args.download_path, args.device, args.num_threads, args.language, args.use_itn)
 
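The entry-point refactor in `sensevoice_rknn.py` splits argument parsing into `parse_args()` and has `main()` take plain parameters and return the recognized text, so the script can presumably also be driven from other Python code. A hypothetical usage sketch (the module import path and resource locations are assumptions, not something this commit documents):

```python
# Hypothetical example: import the script as a module and call the refactored main().
# Assumes the resource files (am.mvn, embedding.npy, sense-voice-encoder.rknn,
# chn_jpn_yue_eng_ko_spectok.bpe.model) sit in the same directory as the script.
from sensevoice_rknn import main

text = main(
    audio_file="output.wav",  # 16 kHz mono wav, as described in the README
    download_path=".",
    device=-1,
    num_threads=4,
    language="auto",
    use_itn=False,
)
print(text)
```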
|