happyme531 committed on
Commit 2b134bc · verified · 1 Parent(s): 01bc982

Attempt to fix the fp16 overflow issue

Files changed (4)
  1. README.md +11 -15
  2. convert_rknn.py +152 -24
  3. sense-voice-encoder.rknn +2 -2
  4. sensevoice_rknn.py +81 -75
README.md CHANGED
@@ -28,12 +28,10 @@ SenseVoice是具有音频理解能力的音频基础模型, 包括语音识别
28
  2. 安装依赖
29
 
30
  ```bash
31
- pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml numpy<2
32
  ```
33
 
34
- 你还需要手动安装rknn-toolkit2-lite2.
35
-
36
- 3. 运行
37
 
38
  ```bash
39
  python ./sensevoice_rknn.py --audio_file output.wav
@@ -47,7 +45,7 @@ ffmpeg -i input.mp3 -f wav -acodec pcm_s16le -ac 1 -ar 16000 output.wav
47
 
48
  ## RKNN模型转换
49
 
50
- 你需要提前安装rknn-toolkit2 v2.1.0或更高版本.
51
 
52
  1. 下载或转换onnx模型
53
 
@@ -58,13 +56,13 @@ ffmpeg -i input.mp3 -f wav -acodec pcm_s16le -ac 1 -ar 16000 output.wav
58
 
59
  2. 转换为rknn模型
60
  ```bash
61
- python convert_rknn.py
62
  ```
63
 
64
  ## 已知问题
65
 
66
- - RKNN2使用fp16推理时可能会出现溢出,导致结果为inf,可以尝试修改输入数据的缩放比例来解决.
67
- 在`sensevoice_rknn.py`中将`SPEECH_SCALE`设置为更小的值.
68
 
69
  ## 参考
70
  - [FunAudioLLM/SenseVoiceSmall](https://huggingface.co/FunAudioLLM/SenseVoiceSmall)
@@ -88,11 +86,9 @@ Currently, SenseVoice-small supports multilingual speech recognition, emotion re
88
  2. Install dependencies
89
 
90
  ```bash
91
- pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml numpy<2
92
  ```
93
 
94
- You also need to manually install rknn-toolkit2-lite2.
95
-
96
  3. Run
97
 
98
  ```bash
@@ -107,7 +103,7 @@ ffmpeg -i input.mp3 -f wav -acodec pcm_s16le -ac 1 -ar 16000 output.wav
107
 
108
  ## RKNN Model Conversion
109
 
110
- You need to install rknn-toolkit2 v2.1.0 or higher in advance.
111
 
112
  1. Download or convert the ONNX model
113
 
@@ -118,13 +114,13 @@ The model file should be named 'sense-voice-encoder.onnx' and placed in the same
118
 
119
  2. Convert to RKNN model
120
  ```bash
121
- python convert_rknn.py
122
  ```
123
 
124
  ## Known Issues
125
 
126
- - When using fp16 inference with RKNN2, overflow may occur, resulting in inf values. You can try modifying the scaling ratio of the input data to resolve this.
127
- Set `SPEECH_SCALE` to a smaller value in `sensevoice_rknn.py`.
128
 
129
  ## References
130
  - [FunAudioLLM/SenseVoiceSmall](https://huggingface.co/FunAudioLLM/SenseVoiceSmall)
 
28
  2. 安装依赖
29
 
30
  ```bash
31
+ pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml "numpy<2" rknn-toolkit-lite2
32
  ```
33
 
34
+ 1. 运行
 
 
35
 
36
  ```bash
37
  python ./sensevoice_rknn.py --audio_file output.wav
 
45
 
46
  ## RKNN模型转换
47
 
48
+ 你需要提前安装rknn-toolkit2, 测试可用的版本为2.3.3a25,可从https://console.zbox.filez.com/l/I00fc3 下载(密码为"rknn")
49
 
50
  1. 下载或转换onnx模型
51
 
 
56
 
57
  2. 转换为rknn模型
58
  ```bash
59
+ python convert_rknn.py ./sense-voice-encoder.onnx
60
  ```
61
 
62
  ## 已知问题
63
 
64
+ - ~~RKNN2使用fp16推理时可能会出现溢出,导致结果为inf,可以尝试修改输入数据的缩放比例来解决.
65
+ 在`sensevoice_rknn.py`中将`SPEECH_SCALE`设置为更小的值.~~ (现在应该已经通过模型内部插入缩放算子解决了)
66
 
67
  ## 参考
68
  - [FunAudioLLM/SenseVoiceSmall](https://huggingface.co/FunAudioLLM/SenseVoiceSmall)
 
86
  2. Install dependencies
87
 
88
  ```bash
89
+ pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml "numpy<2" rknn-toolkit-lite2
90
  ```
91
 
 
 
92
  3. Run
93
 
94
  ```bash
 
103
 
104
  ## RKNN Model Conversion
105
 
106
+ You need to install rknn-toolkit2 in advance. The tested working version is 2.3.3a25, which can be downloaded from https://console.zbox.filez.com/l/I00fc3 (password: "rknn").
107
 
108
  1. Download or convert the ONNX model
109
 
 
114
 
115
  2. Convert to RKNN model
116
  ```bash
117
+ python convert_rknn.py ./sense-voice-encoder.onnx
118
  ```
119
 
120
  ## Known Issues
121
 
122
+ - ~~When using fp16 inference with RKNN2, overflow may occur, resulting in inf values. You can try modifying the scaling ratio of the input data to resolve this.
123
+ Set `SPEECH_SCALE` to a smaller value in `sensevoice_rknn.py`.~~ (This issue should now be resolved by inserting scaling operators inside the model.)
124
 
125
  ## References
126
  - [FunAudioLLM/SenseVoiceSmall](https://huggingface.co/FunAudioLLM/SenseVoiceSmall)
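The Known Issues note above comes down to fp16's limited range: finite values top out around 65504, so large intermediate activations become inf. A minimal numpy illustration of the failure mode and of why scaling values down avoids it (a standalone sketch, not part of this repo; the numbers are arbitrary):

```python
import numpy as np

x = np.float32(70000.0)      # comfortably representable in fp32
print(np.float16(x))         # inf  -- fp16 overflows above ~65504
print(np.float16(x / 2.0))   # ~35000 -- dividing first keeps the value in range
```

The Div nodes that convert_rknn.py (below) now inserts around encoder layer 48 apply the same idea inside the model, with the matching feed-forward bias rescaled to compensate.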
convert_rknn.py CHANGED
@@ -2,6 +2,8 @@
2
  # coding: utf-8
3
 
4
  import os
 
 
5
  from rknn.api import RKNN
6
  from math import exp
7
  from sys import exit
@@ -18,28 +20,158 @@ os.chdir(os.path.dirname(os.path.abspath(__file__)))
18
 
19
  speech_length = 171
20
 
21
- def convert_encoder():
22
  rknn = RKNN(verbose=True)
23
 
24
- ONNX_MODEL=f"sense-voice-encoder.onnx"
25
- RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
26
- DATASET="dataset.txt"
27
- QUANTIZE=False
29
  # A big surprise right at the start: RKNN errors out on this subgraph during its very first constant-folding pass, so it has to be extracted and run separately first,
30
  # and this subgraph's output is then saved and fed to RKNN
31
- onnx.utils.extract_model(ONNX_MODEL, "extract_model.onnx", ['speech_lengths'], ['/make_pad_mask/Cast_2_output_0'])
32
- sess = ort.InferenceSession("extract_model.onnx", providers=['CPUExecutionProvider'])
 
33
  extract_result = sess.run(None, {"speech_lengths": np.array([speech_length], dtype=np.int64)})[0]
 
34
 
35
- # Remove the redundant transpose at the end of the model; inference time drops from 365ms to 350ms
 
36
  ret = onnx_edit(model = ONNX_MODEL,
37
- export_path = ONNX_MODEL.replace(".onnx", "_edited.onnx"),
38
  # # 1, len, 25055 -> 1, 25055, 1, len # 这个是坏的, 我真服了,
39
- # outputs_transform = {'encoder_out': 'a,b,c->a,c,1,b'},
40
- outputs_transform = {'encoder_out': 'a,b,c->a,c,b'},
41
  )
42
- ONNX_MODEL = ONNX_MODEL.replace(".onnx", "_edited.onnx")
 
 
 
43
 
44
  # pre-process config
45
  print('--> Config model')
@@ -48,8 +180,9 @@ def convert_encoder():
48
 
49
  # Load ONNX model
50
  print("--> Loading model")
 
51
  ret = rknn.load_onnx(
52
- model=ONNX_MODEL,
53
  inputs=["speech", "/make_pad_mask/Cast_2_output_0"],
54
  input_size_list=[[1, speech_length, 560], [extract_result.shape[0], extract_result.shape[1]]],
55
  input_initial_val=[None, extract_result],
@@ -60,6 +193,7 @@ def convert_encoder():
60
  print('Load model failed!')
61
  exit(ret)
62
  print('done')
 
63
 
64
  # Build model
65
  print('--> Building model')
@@ -76,20 +210,14 @@ def convert_encoder():
76
  print('Export RKNN model failed!')
77
  exit(ret)
78
  print('done')
 
 
79
 
80
- # usage: python convert_rknn.py encoder|all
81
 
82
  if __name__ == "__main__":
83
  parser = argparse.ArgumentParser()
84
- parser.add_argument("model", type=str, help="model to convert", choices=["encoder", "all"], nargs='?')
85
  args = parser.parse_args()
86
- if args.model is None:
87
- args.model = "all"
88
 
89
- if args.model == "encoder":
90
- convert_encoder()
91
- elif args.model == "all":
92
- convert_encoder()
93
- else:
94
- print(f"Unknown model: {args.model}")
95
- exit(1)
 
2
  # coding: utf-8
3
 
4
  import os
5
+ import re
6
+ from typing import Optional, Set
7
  from rknn.api import RKNN
8
  from math import exp
9
  from sys import exit
 
20
 
21
  speech_length = 171
22
 
23
+ def _remove_file(path: str, *, keep: Optional[Set[str]] = None) -> None:
24
+ if not path:
25
+ return
26
+ keep_paths: Set[str] = {os.path.abspath(item) for item in keep} if keep else set()
27
+ normalized = os.path.abspath(path)
28
+ if keep_paths and normalized in keep_paths:
29
+ return
30
+ if not os.path.exists(normalized):
31
+ return
32
+ try:
33
+ os.remove(normalized)
34
+ print(f'cleaned temp model: {normalized}')
35
+ except OSError as err:
36
+ print(f'warning: failed to remove {normalized}: {err}')
37
+
38
+ def _with_suffix(path: str, suffix: str) -> str:
39
+ stem, ext = os.path.splitext(path)
40
+ return f"{stem}{suffix}{ext}"
41
+
42
+ def _sanitize_name(name: str) -> str:
43
+ return re.sub(r'[^0-9A-Za-z_]', '_', name)
44
+
45
+ def _insert_div_node(model: onnx.ModelProto, tensor_name: str, divisor: float = 16.0) -> bool:
46
+ graph = model.graph
47
+
48
+ for node in graph.node:
49
+ if node.op_type == 'Div' and tensor_name in node.output:
50
+ return False
51
+
52
+ producer_index = None
53
+ output_index = None
54
+ for idx, node in enumerate(graph.node):
55
+ for out_idx, output in enumerate(node.output):
56
+ if output == tensor_name:
57
+ producer_index = idx
58
+ output_index = out_idx
59
+ producer_node = node
60
+ break
61
+ if producer_index is not None:
62
+ break
63
+
64
+ if producer_index is None:
65
+ raise RuntimeError(f"Producer node for tensor {tensor_name} not found.")
66
+
67
+ pre_div_output = f"{tensor_name}_pre_div"
68
+ producer_node.output[output_index] = pre_div_output
69
+
70
+ sanitized = _sanitize_name(tensor_name)
71
+ const_output = f"{sanitized}_div_const"
72
+ const_node_name = f"{sanitized}_DivConst"
73
+ div_node_name = f"{sanitized}_Div"
74
+
75
+ const_tensor = onnx.helper.make_tensor(
76
+ name=f"{const_node_name}_value",
77
+ data_type=onnx.TensorProto.FLOAT,
78
+ dims=[],
79
+ vals=[divisor],
80
+ )
81
+
82
+ const_node = onnx.helper.make_node(
83
+ 'Constant',
84
+ inputs=[],
85
+ outputs=[const_output],
86
+ value=const_tensor,
87
+ name=const_node_name,
88
+ )
89
+
90
+ div_node = onnx.helper.make_node(
91
+ 'Div',
92
+ inputs=[pre_div_output, const_output],
93
+ outputs=[tensor_name],
94
+ name=div_node_name,
95
+ )
96
+
97
+ graph.node.insert(producer_index + 1, const_node)
98
+ graph.node.insert(producer_index + 2, div_node)
99
+ return True
100
+
101
+ def _scale_initializer(model: onnx.ModelProto, initializer_name: str, divisor: float = 16.0) -> bool:
102
+ for idx, initializer in enumerate(model.graph.initializer):
103
+ if initializer.name == initializer_name:
104
+ data = onh.to_array(initializer).astype(np.float32, copy=False)
105
+ scaled = data / divisor
106
+ model.graph.initializer[idx].CopyFrom(onh.from_array(scaled, name=initializer_name))
107
+ return True
108
+ return False
109
+
110
+ def convert_encoder(model_path: str):
111
  rknn = RKNN(verbose=True)
112
 
113
+ ONNX_MODEL = os.path.abspath(model_path)
114
+ if not os.path.isfile(ONNX_MODEL):
115
+ print(f'Model file not found: {model_path}')
116
+ exit(1)
117
+ if not ONNX_MODEL.lower().endswith('.onnx'):
118
+ print(f'Model file must be an ONNX file: {model_path}')
119
+ exit(1)
120
+
121
+ RKNN_MODEL = os.path.splitext(ONNX_MODEL)[0] + ".rknn"
122
+ DATASET = "dataset.txt"
123
+ QUANTIZE = False
124
+ original_model = ONNX_MODEL
125
+ preserve_files: Set[str] = {original_model}
126
+
127
 
128
+ print('--> Patching model to avoid overflow issue')
129
+ base_model = onnx.load(ONNX_MODEL)
130
+ modified = False
131
+ for layer_idx in range(48, 49):
132
+ for target in [
133
+ f'/encoders.{layer_idx}/feed_forward/activation/Relu_output_0',
134
+ f'/encoders.{layer_idx}/norm2/Cast_output_0',
135
+ ]:
136
+ modified |= _insert_div_node(base_model, target, divisor=2.0)
137
+ bias_scaled = False
138
+ if modified:
139
+ for layer_idx in range(48, 49):
140
+ bias_scaled |= _scale_initializer(base_model, f'model.encoders.{layer_idx}.feed_forward.w_2.bias', divisor=2.0)
141
+ div_model_path = _with_suffix(ONNX_MODEL, "_div")
142
+ onnx.save(base_model, div_model_path)
143
+ if os.path.exists(div_model_path):
144
+ previous_model = ONNX_MODEL
145
+ ONNX_MODEL = div_model_path
146
+ _remove_file(previous_model, keep=preserve_files)
147
+ if modified:
148
+ if bias_scaled:
149
+ print('done (created div-adjusted model and scaled bias)')
150
+ else:
151
+ print('done (created div-adjusted model; bias initializer not found)')
152
+ else:
153
+ print('done (div nodes already present)')
154
+
155
  # A big surprise right at the start: RKNN errors out on this subgraph during its very first constant-folding pass, so it has to be extracted and run separately first,
156
  # and this subgraph's output is then saved and fed to RKNN
157
+ extract_model_path = os.path.join(os.getcwd(), "extract_model.onnx")
158
+ onnx.utils.extract_model(ONNX_MODEL, extract_model_path, ['speech_lengths'], ['/make_pad_mask/Cast_2_output_0'])
159
+ sess = ort.InferenceSession(extract_model_path, providers=['CPUExecutionProvider'])
160
  extract_result = sess.run(None, {"speech_lengths": np.array([speech_length], dtype=np.int64)})[0]
161
+ _remove_file(extract_model_path)
162
 
163
+ # Remove the redundant transpose at the end of the model; inference time drops from 365ms to 259ms
164
+ edited_model_path = _with_suffix(ONNX_MODEL, "_edited")
165
  ret = onnx_edit(model = ONNX_MODEL,
166
+ export_path = edited_model_path,
167
  # # 1, len, 25055 -> 1, 25055, 1, len  # this one is broken, unbelievable,
168
+ outputs_transform = {'encoder_out': 'a,b,c->a,c,1,b'},
169
+ # outputs_transform = {'encoder_out': 'a,b,c->a,c,b'},
170
  )
171
+ if os.path.exists(edited_model_path):
172
+ previous_model = ONNX_MODEL
173
+ ONNX_MODEL = edited_model_path
174
+ _remove_file(previous_model, keep=preserve_files)
175
 
176
  # pre-process config
177
  print('--> Config model')
 
180
 
181
  # Load ONNX model
182
  print("--> Loading model")
183
+ current_model_path = ONNX_MODEL
184
  ret = rknn.load_onnx(
185
+ model=current_model_path,
186
  inputs=["speech", "/make_pad_mask/Cast_2_output_0"],
187
  input_size_list=[[1, speech_length, 560], [extract_result.shape[0], extract_result.shape[1]]],
188
  input_initial_val=[None, extract_result],
 
193
  print('Load model failed!')
194
  exit(ret)
195
  print('done')
196
+ _remove_file(current_model_path, keep=preserve_files)
197
 
198
  # Build model
199
  print('--> Building model')
 
210
  print('Export RKNN model failed!')
211
  exit(ret)
212
  print('done')
213
+ # Accuracy analysis (optional)
214
+ # rknn.accuracy_analysis(inputs=["input_content.npy"], target="rk3588", device_id=None)
215
 
216
+ # usage: python convert_rknn.py path/to/model.onnx
217
 
218
  if __name__ == "__main__":
219
  parser = argparse.ArgumentParser()
220
+ parser.add_argument("model_path", type=str, help="path to source ONNX model")
221
  args = parser.parse_args()
 
 
222
 
223
+ convert_encoder(args.model_path)
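A hedged sanity check for the overflow patch above (not part of the repo): before converting, the div-patched copy can be compared against the original encoder in onnxruntime to confirm that the inserted Div nodes plus the rescaled bias leave `encoder_out` essentially unchanged. This assumes you keep `sense-voice-encoder_div.onnx` around (e.g. by temporarily skipping its `_remove_file` cleanup) and uses the input/output names and shapes that convert_rknn.py itself uses.

```python
# Hedged sketch: verify the div-patched ONNX model against the original.
# File name follows convert_rknn.py's "_div" suffix; "speech", "speech_lengths"
# and "encoder_out" are the tensor names the script already relies on.
import numpy as np
import onnxruntime as ort

speech_length = 171
feed = {
    "speech": np.random.randn(1, speech_length, 560).astype(np.float32),
    "speech_lengths": np.array([speech_length], dtype=np.int64),
}

def run(path):
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    return sess.run(["encoder_out"], feed)[0]

ref = run("sense-voice-encoder.onnx")
patched = run("sense-voice-encoder_div.onnx")
print("max abs diff:", np.abs(ref - patched).max())  # should be near zero (float rounding only)
```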
 
 
 
 
 
 
sense-voice-encoder.rknn CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8db70c1a8d4887e35dff55ab0f5d8da283d32359bd1599ece51eb81f99a6f468
3
- size 485687354
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:198e9c843ed0ffd8500f3b045e47ef5cbc9a4fc36401f6d51c853dacd8129d98
3
+ size 484090175
sensevoice_rknn.py CHANGED
@@ -28,7 +28,7 @@ from rknnlite.api.rknn_lite import RKNNLite
28
 
29
  RKNN_INPUT_LEN = 171
30
 
31
- SPEECH_SCALE = 1/2 # fp16 inference: if intermediate results get too large they can overflow to inf, so scale the input down a bit
32
 
33
  class VadOrtInferRuntimeSession:
34
  def __init__(self, config, root_dir: Path):
@@ -72,7 +72,7 @@ class VadOrtInferRuntimeSession:
72
  )
73
 
74
  def __call__(
75
- self, input_content: List[Union[np.ndarray, np.ndarray]]
76
  ) -> np.ndarray:
77
  if isinstance(input_content, list):
78
  input_dict = {
@@ -97,10 +97,10 @@ class VadOrtInferRuntimeSession:
97
  ):
98
  return [v.name for v in self.session.get_outputs()]
99
 
100
- def get_character_list(self, key: str = "character"):
101
  return self.meta_dict[key].splitlines()
102
 
103
- def have_key(self, key: str = "character") -> bool:
104
  self.meta_dict = self.session.get_modelmeta().custom_metadata_map
105
  if key in self.meta_dict.keys():
106
  return True
@@ -196,10 +196,10 @@ class OrtInferRuntimeSession:
196
  ):
197
  return [v.name for v in self.session.get_outputs()]
198
 
199
- def get_character_list(self, key: str = "character"):
200
  return self.meta_dict[key].splitlines()
201
 
202
- def have_key(self, key: str = "character") -> bool:
203
  self.meta_dict = self.session.get_modelmeta().custom_metadata_map
204
  if key in self.meta_dict.keys():
205
  return True
@@ -250,7 +250,7 @@ class SenseVoiceInferenceSession:
250
  self.sp = spm.SentencePieceProcessor()
251
  self.sp.load(bpe_model_file)
252
 
253
- def __call__(self, speech, language: int, use_itn: bool) -> np.ndarray:
254
  language_query = self.embedding[[[language]]]
255
 
256
  # 14 means with itn, 15 means without itn
@@ -274,6 +274,7 @@ class SenseVoiceInferenceSession:
274
  input_content = np.pad(input_content, ((0, 0), (0, RKNN_INPUT_LEN - input_content.shape[1]), (0, 0)))
275
  print("padded shape:", input_content.shape)
276
  start_time = time.time()
 
277
  encoder_out = self.encoder.inference(inputs=[input_content])[0]
278
  end_time = time.time()
279
  print(f"encoder inference time: {end_time - start_time:.2f} seconds")
@@ -308,14 +309,14 @@ class WavFrontend:
308
 
309
  def __init__(
310
  self,
311
- cmvn_file: str = None,
312
- fs: int = 16000,
313
- window: str = "hamming",
314
- n_mels: int = 80,
315
- frame_length: int = 25,
316
- frame_shift: int = 10,
317
- lfr_m: int = 7,
318
- lfr_n: int = 6,
319
  dither: float = 0,
320
  **kwargs,
321
  ) -> None:
@@ -367,7 +368,7 @@ class WavFrontend:
367
  feat_len = np.array(feat.shape[0]).astype(np.int32)
368
  return feat, feat_len
369
 
370
- def load_audio(self, filename: str) -> Tuple[np.ndarray, int]:
371
  data, sample_rate = sf.read(
372
  filename,
373
  always_2d=True,
@@ -383,7 +384,7 @@ class WavFrontend:
383
  return samples, sample_rate
384
 
385
  @staticmethod
386
- def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
387
  LFR_inputs = []
388
 
389
  T = inputs.shape[0]
@@ -417,7 +418,7 @@ class WavFrontend:
417
  inputs = (inputs + means) * vars
418
  return inputs
419
 
420
- def get_features(self, inputs: Union[str, np.ndarray]) -> Tuple[np.ndarray, int]:
421
  if isinstance(inputs, str):
422
  inputs, _ = self.load_audio(inputs)
423
 
@@ -504,35 +505,35 @@ class VadDetectMode(Enum):
504
  class VADXOptions:
505
  def __init__(
506
  self,
507
- sample_rate: int = 16000,
508
- detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
509
- snr_mode: int = 0,
510
- max_end_silence_time: int = 800,
511
- max_start_silence_time: int = 3000,
512
  do_start_point_detection: bool = True,
513
  do_end_point_detection: bool = True,
514
- window_size_ms: int = 200,
515
- sil_to_speech_time_thres: int = 150,
516
- speech_to_sil_time_thres: int = 150,
517
  speech_2_noise_ratio: float = 1.0,
518
- do_extend: int = 1,
519
- lookback_time_start_point: int = 200,
520
- lookahead_time_end_point: int = 100,
521
- max_single_segment_time: int = 60000,
522
- nn_eval_block_size: int = 8,
523
- dcd_block_size: int = 4,
524
- snr_thres: int = -100.0,
525
- noise_frame_num_used_for_snr: int = 100,
526
- decibel_thres: int = -100.0,
527
  speech_noise_thres: float = 0.6,
528
  fe_prior_thres: float = 1e-4,
529
- silence_pdf_num: int = 1,
530
  sil_pdf_ids: List[int] = [0],
531
  speech_noise_thresh_low: float = -0.1,
532
  speech_noise_thresh_high: float = 0.3,
533
  output_frame_probs: bool = False,
534
- frame_in_ms: int = 10,
535
- frame_length_ms: int = 25,
536
  ):
537
  self.sample_rate = sample_rate
538
  self.detect_mode = detect_mode
@@ -595,10 +596,10 @@ class E2EVadFrameProb(object):
595
  class WindowDetector(object):
596
  def __init__(
597
  self,
598
- window_size_ms: int,
599
- sil_to_speech_time: int,
600
- speech_to_sil_time: int,
601
- frame_size_ms: int,
602
  ):
603
  self.window_size_ms = window_size_ms
604
  self.sil_to_speech_time = sil_to_speech_time
@@ -633,7 +634,7 @@ class WindowDetector(object):
633
  return int(self.win_size_frame)
634
 
635
  def detect_one_frame(
636
- self, frameState: FrameState, frame_count: int
637
  ) -> AudioChangeState:
638
  cur_frame_state = FrameState.kFrameStateSil
639
  if frameState == FrameState.kFrameStateSpeech:
@@ -773,7 +774,7 @@ class E2EVadModel:
773
 
774
  return scores[1:]
775
 
776
- def pop_data_buf_till_frame(self, frame_idx: int) -> None: # need check again
777
  while self.data_buf_start_frame < frame_idx:
778
  if self.data_buf_size >= int(
779
  self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
@@ -787,8 +788,8 @@ class E2EVadModel:
787
 
788
  def pop_data_to_output_buf(
789
  self,
790
- start_frm: int,
791
- frm_cnt: int,
792
  first_frm_is_start_point: bool,
793
  last_frm_is_end_point: bool,
794
  end_point_is_sent_end: bool,
@@ -849,18 +850,18 @@ class E2EVadModel:
849
  if last_frm_is_end_point:
850
  cur_seg.contain_seg_end_point = True
851
 
852
- def on_silence_detected(self, valid_frame: int):
853
  self.lastest_confirmed_silence_frame = valid_frame
854
  if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
855
  self.pop_data_buf_till_frame(valid_frame)
856
  # silence_detected_callback_
857
  # pass
858
 
859
- def on_voice_detected(self, valid_frame: int) -> None:
860
  self.latest_confirmed_speech_frame = valid_frame
861
  self.pop_data_to_output_buf(valid_frame, 1, False, False, False)
862
 
863
- def on_voice_start(self, start_frame: int, fake_result: bool = False) -> None:
864
  if self.vad_opts.do_start_point_detection:
865
  pass
866
  if self.confirmed_start_frame != -1:
@@ -878,7 +879,7 @@ class E2EVadModel:
878
  )
879
 
880
  def on_voice_end(
881
- self, end_frame: int, fake_result: bool, is_last_frame: bool
882
  ) -> None:
883
  for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
884
  self.on_voice_detected(t)
@@ -896,7 +897,7 @@ class E2EVadModel:
896
  self.number_end_time_detected += 1
897
 
898
  def maybe_on_voice_end_last_frame(
899
- self, is_final_frame: bool, cur_frm_idx: int
900
  ) -> None:
901
  if is_final_frame:
902
  self.on_voice_end(cur_frm_idx, False, True)
@@ -913,7 +914,7 @@ class E2EVadModel:
913
  )
914
  return vad_latency
915
 
916
- def get_frame_state(self, t: int) -> FrameState:
917
  frame_state = FrameState.kFrameStateInvalid
918
  cur_decibel = self.decibel[t - self.decibel_offset]
919
  cur_snr = cur_decibel - self.noise_average_decibel
@@ -1010,7 +1011,7 @@ class E2EVadModel:
1010
  waveform: np.ndarray,
1011
  in_cache: list = None,
1012
  is_final: bool = False,
1013
- max_end_sil: int = 800,
1014
  ) -> Tuple[List[List[List[int]]], Dict[str, np.ndarray]]:
1015
  feats = [feats]
1016
  if in_cache is None:
@@ -1059,7 +1060,7 @@ class E2EVadModel:
1059
  waveform: np.ndarray,
1060
  in_cache: list = None,
1061
  is_final: bool = False,
1062
- max_end_sil: int = 800,
1063
  ):
1064
  feats = [feats]
1065
  states = []
@@ -1116,7 +1117,7 @@ class E2EVadModel:
1116
  return 0
1117
 
1118
  def detect_one_frame(
1119
- self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool
1120
  ) -> None:
1121
  tmp_cur_frm_state = FrameState.kFrameStateInvalid
1122
  if cur_frm_state == FrameState.kFrameStateSpeech:
@@ -1267,7 +1268,7 @@ class E2EVadModel:
1267
 
1268
 
1269
  class FSMNVad(object):
1270
- def __init__(self, config_dir: str):
1271
  config_dir = Path(config_dir)
1272
  self.config = read_yaml(config_dir / "fsmn-config.yaml")
1273
  self.frontend = WavFrontend(
@@ -1332,16 +1333,16 @@ languages = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeec
1332
  formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
1333
  logging.basicConfig(format=formatter, level=logging.INFO)
1334
 
1335
- def main():
1336
  arg_parser = argparse.ArgumentParser(description="Sense Voice")
1337
  arg_parser.add_argument("-a", "--audio_file", required=True, type=str, help="Model")
1338
  download_model_path = os.path.dirname(__file__)
1339
  arg_parser.add_argument(
1340
  "-dp",
1341
- "--download_path",
1342
  default=download_model_path,
1343
  type=str,
1344
- help="dir path of resource downloaded",
1345
  )
1346
  arg_parser.add_argument("-d", "--device", default=-1, type=int, help="Device")
1347
  arg_parser.add_argument(
@@ -1351,35 +1352,36 @@ def main():
1351
  "-l",
1352
  "--language",
1353
  choices=languages.keys(),
1354
- default="auto",
1355
  type=str,
1356
- help="Language",
1357
  )
1358
  arg_parser.add_argument("--use_itn", action="store_true", help="Use ITN")
1359
- args = arg_parser.parse_args()
1360
 
1361
- front = WavFrontend(os.path.join(download_model_path, "am.mvn"))
 
1362
 
1363
  model = SenseVoiceInferenceSession(
1364
- os.path.join(download_model_path, "embedding.npy"),
1365
  os.path.join(
1366
- download_model_path,
1367
  "sense-voice-encoder.rknn",
1368
  ),
1369
- os.path.join(download_model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"),
1370
- args.device,
1371
- args.num_threads,
1372
  )
1373
  waveform, _sample_rate = sf.read(
1374
- args.audio_file,
1375
  dtype="float32",
1376
  always_2d=True
1377
  )
1378
 
1379
- logging.info(f"Audio {args.audio_file} is {len(waveform) / _sample_rate} seconds, {waveform.shape[1]} channel")
1380
- # load vad model
1381
  start = time.time()
1382
- vad = FSMNVad(download_model_path)
1383
  for channel_id, channel_data in enumerate(waveform.T):
1384
  segments = vad.segments_offline(channel_data)
1385
  results = ""
@@ -1387,16 +1389,20 @@ def main():
1387
  audio_feats = front.get_features(channel_data[part[0] * 16 : part[1] * 16])
1388
  asr_result = model(
1389
  audio_feats[None, ...],
1390
- language=languages[args.language],
1391
- use_itn=args.use_itn,
1392
  )
1393
  logging.info(f"[Channel {channel_id}] [{part[0] / 1000}s - {part[1] / 1000}s] {asr_result}")
 
 
1394
  vad.vad.all_reset_detection()
1395
  decoding_time = time.time() - start
1396
  logging.info(f"Decoder audio takes {decoding_time} seconds")
1397
  logging.info(f"The RTF is {decoding_time/(waveform.shape[1] * len(waveform) / _sample_rate)}.")
 
1398
 
1399
 
1400
  if __name__ == "__main__":
1401
- main()
 
1402
 
 
28
 
29
  RKNN_INPUT_LEN = 171
30
 
31
+ SPEECH_SCALE = 1
32
 
33
  class VadOrtInferRuntimeSession:
34
  def __init__(self, config, root_dir: Path):
 
72
  )
73
 
74
  def __call__(
75
+ self, input_content
76
  ) -> np.ndarray:
77
  if isinstance(input_content, list):
78
  input_dict = {
 
97
  ):
98
  return [v.name for v in self.session.get_outputs()]
99
 
100
+ def get_character_list(self, key = "character"):
101
  return self.meta_dict[key].splitlines()
102
 
103
+ def have_key(self, key = "character") -> bool:
104
  self.meta_dict = self.session.get_modelmeta().custom_metadata_map
105
  if key in self.meta_dict.keys():
106
  return True
 
196
  ):
197
  return [v.name for v in self.session.get_outputs()]
198
 
199
+ def get_character_list(self, key = "character"):
200
  return self.meta_dict[key].splitlines()
201
 
202
+ def have_key(self, key = "character") -> bool:
203
  self.meta_dict = self.session.get_modelmeta().custom_metadata_map
204
  if key in self.meta_dict.keys():
205
  return True
 
250
  self.sp = spm.SentencePieceProcessor()
251
  self.sp.load(bpe_model_file)
252
 
253
+ def __call__(self, speech, language, use_itn: bool) -> np.ndarray:
254
  language_query = self.embedding[[[language]]]
255
 
256
  # 14 means with itn, 15 means without itn
 
274
  input_content = np.pad(input_content, ((0, 0), (0, RKNN_INPUT_LEN - input_content.shape[1]), (0, 0)))
275
  print("padded shape:", input_content.shape)
276
  start_time = time.time()
277
+ np.save("input_content.npy",input_content)
278
  encoder_out = self.encoder.inference(inputs=[input_content])[0]
279
  end_time = time.time()
280
  print(f"encoder inference time: {end_time - start_time:.2f} seconds")
 
309
 
310
  def __init__(
311
  self,
312
+ cmvn_file = None,
313
+ fs = 16000,
314
+ window = "hamming",
315
+ n_mels = 80,
316
+ frame_length = 25,
317
+ frame_shift = 10,
318
+ lfr_m = 7,
319
+ lfr_n = 6,
320
  dither: float = 0,
321
  **kwargs,
322
  ) -> None:
 
368
  feat_len = np.array(feat.shape[0]).astype(np.int32)
369
  return feat, feat_len
370
 
371
+ def load_audio(self, filename) -> Tuple[np.ndarray, int]:
372
  data, sample_rate = sf.read(
373
  filename,
374
  always_2d=True,
 
384
  return samples, sample_rate
385
 
386
  @staticmethod
387
+ def apply_lfr(inputs: np.ndarray, lfr_m, lfr_n) -> np.ndarray:
388
  LFR_inputs = []
389
 
390
  T = inputs.shape[0]
 
418
  inputs = (inputs + means) * vars
419
  return inputs
420
 
421
+ def get_features(self, inputs: Union[str, np.ndarray]):
422
  if isinstance(inputs, str):
423
  inputs, _ = self.load_audio(inputs)
424
 
 
505
  class VADXOptions:
506
  def __init__(
507
  self,
508
+ sample_rate = 16000,
509
+ detect_mode = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
510
+ snr_mode = 0,
511
+ max_end_silence_time = 800,
512
+ max_start_silence_time = 3000,
513
  do_start_point_detection: bool = True,
514
  do_end_point_detection: bool = True,
515
+ window_size_ms = 200,
516
+ sil_to_speech_time_thres = 150,
517
+ speech_to_sil_time_thres = 150,
518
  speech_2_noise_ratio: float = 1.0,
519
+ do_extend = 1,
520
+ lookback_time_start_point = 200,
521
+ lookahead_time_end_point = 100,
522
+ max_single_segment_time = 60000,
523
+ nn_eval_block_size = 8,
524
+ dcd_block_size = 4,
525
+ snr_thres = -100.0,
526
+ noise_frame_num_used_for_snr = 100,
527
+ decibel_thres = -100.0,
528
  speech_noise_thres: float = 0.6,
529
  fe_prior_thres: float = 1e-4,
530
+ silence_pdf_num = 1,
531
  sil_pdf_ids: List[int] = [0],
532
  speech_noise_thresh_low: float = -0.1,
533
  speech_noise_thresh_high: float = 0.3,
534
  output_frame_probs: bool = False,
535
+ frame_in_ms = 10,
536
+ frame_length_ms = 25,
537
  ):
538
  self.sample_rate = sample_rate
539
  self.detect_mode = detect_mode
 
596
  class WindowDetector(object):
597
  def __init__(
598
  self,
599
+ window_size_ms,
600
+ sil_to_speech_time,
601
+ speech_to_sil_time,
602
+ frame_size_ms,
603
  ):
604
  self.window_size_ms = window_size_ms
605
  self.sil_to_speech_time = sil_to_speech_time
 
634
  return int(self.win_size_frame)
635
 
636
  def detect_one_frame(
637
+ self, frameState: FrameState, frame_count
638
  ) -> AudioChangeState:
639
  cur_frame_state = FrameState.kFrameStateSil
640
  if frameState == FrameState.kFrameStateSpeech:
 
774
 
775
  return scores[1:]
776
 
777
+ def pop_data_buf_till_frame(self, frame_idx) -> None: # need check again
778
  while self.data_buf_start_frame < frame_idx:
779
  if self.data_buf_size >= int(
780
  self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
 
788
 
789
  def pop_data_to_output_buf(
790
  self,
791
+ start_frm,
792
+ frm_cnt,
793
  first_frm_is_start_point: bool,
794
  last_frm_is_end_point: bool,
795
  end_point_is_sent_end: bool,
 
850
  if last_frm_is_end_point:
851
  cur_seg.contain_seg_end_point = True
852
 
853
+ def on_silence_detected(self, valid_frame):
854
  self.lastest_confirmed_silence_frame = valid_frame
855
  if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
856
  self.pop_data_buf_till_frame(valid_frame)
857
  # silence_detected_callback_
858
  # pass
859
 
860
+ def on_voice_detected(self, valid_frame) -> None:
861
  self.latest_confirmed_speech_frame = valid_frame
862
  self.pop_data_to_output_buf(valid_frame, 1, False, False, False)
863
 
864
+ def on_voice_start(self, start_frame, fake_result: bool = False) -> None:
865
  if self.vad_opts.do_start_point_detection:
866
  pass
867
  if self.confirmed_start_frame != -1:
 
879
  )
880
 
881
  def on_voice_end(
882
+ self, end_frame, fake_result: bool, is_last_frame: bool
883
  ) -> None:
884
  for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
885
  self.on_voice_detected(t)
 
897
  self.number_end_time_detected += 1
898
 
899
  def maybe_on_voice_end_last_frame(
900
+ self, is_final_frame: bool, cur_frm_idx
901
  ) -> None:
902
  if is_final_frame:
903
  self.on_voice_end(cur_frm_idx, False, True)
 
914
  )
915
  return vad_latency
916
 
917
+ def get_frame_state(self, t) -> FrameState:
918
  frame_state = FrameState.kFrameStateInvalid
919
  cur_decibel = self.decibel[t - self.decibel_offset]
920
  cur_snr = cur_decibel - self.noise_average_decibel
 
1011
  waveform: np.ndarray,
1012
  in_cache: list = None,
1013
  is_final: bool = False,
1014
+ max_end_sil = 800,
1015
  ) -> Tuple[List[List[List[int]]], Dict[str, np.ndarray]]:
1016
  feats = [feats]
1017
  if in_cache is None:
 
1060
  waveform: np.ndarray,
1061
  in_cache: list = None,
1062
  is_final: bool = False,
1063
+ max_end_sil = 800,
1064
  ):
1065
  feats = [feats]
1066
  states = []
 
1117
  return 0
1118
 
1119
  def detect_one_frame(
1120
+ self, cur_frm_state: FrameState, cur_frm_idx, is_final_frame: bool
1121
  ) -> None:
1122
  tmp_cur_frm_state = FrameState.kFrameStateInvalid
1123
  if cur_frm_state == FrameState.kFrameStateSpeech:
 
1268
 
1269
 
1270
  class FSMNVad(object):
1271
+ def __init__(self, config_dir):
1272
  config_dir = Path(config_dir)
1273
  self.config = read_yaml(config_dir / "fsmn-config.yaml")
1274
  self.frontend = WavFrontend(
 
1333
  formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
1334
  logging.basicConfig(format=formatter, level=logging.INFO)
1335
 
1336
+ def parse_args():
1337
  arg_parser = argparse.ArgumentParser(description="Sense Voice")
1338
  arg_parser.add_argument("-a", "--audio_file", required=True, type=str, help="Model")
1339
  download_model_path = os.path.dirname(__file__)
1340
  arg_parser.add_argument(
1341
  "-dp",
1342
+ "--download_path",
1343
  default=download_model_path,
1344
  type=str,
1345
+ help="dir path of resource downloaded"
1346
  )
1347
  arg_parser.add_argument("-d", "--device", default=-1, type=int, help="Device")
1348
  arg_parser.add_argument(
 
1352
  "-l",
1353
  "--language",
1354
  choices=languages.keys(),
1355
+ default="auto",
1356
  type=str,
1357
+ help="Language"
1358
  )
1359
  arg_parser.add_argument("--use_itn", action="store_true", help="Use ITN")
1360
+ return arg_parser.parse_args()
1361
 
1362
+ def main(audio_file, download_path, device, num_threads, language, use_itn):
1363
+ front = WavFrontend(os.path.join(download_path, "am.mvn"))
1364
 
1365
  model = SenseVoiceInferenceSession(
1366
+ os.path.join(download_path, "embedding.npy"),
1367
  os.path.join(
1368
+ download_path,
1369
  "sense-voice-encoder.rknn",
1370
  ),
1371
+ os.path.join(download_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"),
1372
+ device,
1373
+ num_threads,
1374
  )
1375
  waveform, _sample_rate = sf.read(
1376
+ audio_file,
1377
  dtype="float32",
1378
  always_2d=True
1379
  )
1380
 
1381
+ logging.info(f"Audio {audio_file} is {len(waveform) / _sample_rate} seconds, {waveform.shape[1]} channel")
1382
+ # load vad model
1383
  start = time.time()
1384
+ vad = FSMNVad(download_path)
1385
  for channel_id, channel_data in enumerate(waveform.T):
1386
  segments = vad.segments_offline(channel_data)
1387
  results = ""
 
1389
  audio_feats = front.get_features(channel_data[part[0] * 16 : part[1] * 16])
1390
  asr_result = model(
1391
  audio_feats[None, ...],
1392
+ language=languages[language],
1393
+ use_itn=use_itn,
1394
  )
1395
  logging.info(f"[Channel {channel_id}] [{part[0] / 1000}s - {part[1] / 1000}s] {asr_result}")
1396
+ results += asr_result
1397
+ logging.info(f"Results: {results}")
1398
  vad.vad.all_reset_detection()
1399
  decoding_time = time.time() - start
1400
  logging.info(f"Decoder audio takes {decoding_time} seconds")
1401
  logging.info(f"The RTF is {decoding_time/(waveform.shape[1] * len(waveform) / _sample_rate)}.")
1402
+ return results
1403
 
1404
 
1405
  if __name__ == "__main__":
1406
+ args = parse_args()
1407
+ main(args.audio_file, args.download_path, args.device, args.num_threads, args.language, args.use_itn)
1408
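Since main() no longer parses arguments itself, the recognizer can also be driven from Python rather than the CLI. A hedged usage sketch: the argument values below are placeholders (the thread count in particular is an assumption), while main()'s signature, device=-1 and language="auto" come from the changes above.

```python
# Hedged sketch: call the refactored entry point directly from Python.
from sensevoice_rknn import main

text = main(
    audio_file="output.wav",   # 16 kHz mono wav, e.g. produced by the ffmpeg command in the README
    download_path=".",         # directory holding sense-voice-encoder.rknn, embedding.npy, am.mvn, the bpe model, ...
    device=-1,
    num_threads=4,             # placeholder; pick whatever parse_args() defaults to on your setup
    language="auto",
    use_itn=False,
)
print(text)
```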