Initial Draft

This commit is contained in:
Andrey Tkachenko 2020-08-17 22:09:19 +04:00
commit 2c98558d53
9 changed files with 541 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
/target
Cargo.lock

17
Cargo.toml Normal file
View File

@ -0,0 +1,17 @@
[package]
name = "vosk"
version = "0.1.0"
authors = ["Andrey Tkachenko <andreytkachenko64@gmail.com>"]
edition = "2018"
build = "build.rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
[build-dependencies]
bindgen = "0.54.1"
cc = { version = "1.0.58", features = ["parallel"] }
[dev-dependencies]
audrey = "0.2.0"

256
build.rs Normal file
View File

@ -0,0 +1,256 @@
use std::{env, fs};
use std::path::PathBuf;
use std::process::Command;
fn main() {
println!("cargo:rerun-if-changed=cbits/vosk.h");
println!("cargo:rerun-if-changed=build.rs");
let bindings = bindgen::Builder::default()
.generate_inline_functions(true)
.derive_default(false)
.header("cbits/vosk.h")
.clang_arg("-I./resources/vosk-api/src/")
.clang_arg("-I./resources/kaldi/src/")
.clang_arg("-I./resources/openfst/src/include")
.clang_arg("-std=c++14")
.clang_arg("-x")
.clang_arg("c++")
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
.opaque_type("std::.*")
.opaque_type("kaldi::.*")
.opaque_type("fst::.*")
.opaque_type("KaldiRecognizer")
.opaque_type("Model")
.opaque_type("SpkModel")
.whitelist_type("KaldiRecognizer")
.whitelist_type("Model")
.whitelist_type("SpkModel")
.rustified_non_exhaustive_enum("*")
.no_copy(".*")
.layout_tests(false)
.generate()
.expect("Unable to generate bindings");
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
cc::Build::new()
.warnings(false)
.static_flag(true)
.cpp(true)
.include("resources/openfst/src/include")
.file("resources/openfst/src/lib/compat.cc")
.file("resources/openfst/src/lib/encode.cc")
.file("resources/openfst/src/lib/fst-types.cc")
.file("resources/openfst/src/lib/fst.cc")
.file("resources/openfst/src/lib/mapped-file.cc")
.file("resources/openfst/src/lib/properties.cc")
.file("resources/openfst/src/lib/symbol-table-ops.cc")
.file("resources/openfst/src/lib/symbol-table.cc")
.file("resources/openfst/src/lib/util.cc")
.file("resources/openfst/src/lib/weight.cc")
.compile("libopenfst");
cc::Build::new()
.warnings(false)
.static_flag(true)
.cpp(true)
.include("resources/vosk-api/src")
.include("resources/kaldi/src/")
.include("resources/openfst/src/include")
.file("resources/vosk-api/src/kaldi_recognizer.cc")
.file("resources/vosk-api/src/model.cc")
.file("resources/vosk-api/src/spk_model.cc")
.compile("libvosk");
let out_dir = env::var("OUT_DIR").unwrap();
let contents = fs::read_to_string("resources/kaldi/src/lat/kaldi-lattice.cc").expect("Something went wrong reading the file");
let contents = contents.replace("printer.Print(&os, \"<unknown>\");", "printer.Print(os, \"<unknown>\");");
let kaldi_lattice = format!("{}/kaldi-lattice.cc", out_dir);
fs::write(&kaldi_lattice, contents).expect("Write file!");
Command::new("sh")
.arg("-c")
.arg("resources/kaldi/src/base/get_version.sh")
.status()
.expect("Failed get_version.sh!");
cc::Build::new()
.warnings(false)
.static_flag(true)
.cpp(true)
.define("HAVE_OPENBLAS", "true")
.include("resources/openfst/src/include")
.include("resources/kaldi/src")
// base
.file("resources/kaldi/src/base/io-funcs.cc")
.file("resources/kaldi/src/base/kaldi-error.cc")
.file("resources/kaldi/src/base/kaldi-math.cc")
.file("resources/kaldi/src/base/kaldi-utils.cc")
// .file("resources/kaldi/src/base/timer.cc")
// matrix
.file("resources/kaldi/src/matrix/kaldi-matrix.cc")
.file("resources/kaldi/src/matrix/kaldi-vector.cc")
.file("resources/kaldi/src/matrix/matrix-functions.cc")
.file("resources/kaldi/src/matrix/optimization.cc")
// cuda
.file("resources/kaldi/src/cudamatrix/cu-matrix.cc")
.file("resources/kaldi/src/cudamatrix/cu-allocator.cc")
.file("resources/kaldi/src/cudamatrix/cu-common.cc")
// .file("resources/kaldi/src/cudamatrix/cu-math.cc")
// fstext
.file("resources/kaldi/src/fstext/context-fst.cc")
.file("resources/kaldi/src/fstext/grammar-context-fst.cc")
.file("resources/kaldi/src/fstext/kaldi-fst-io.cc")
.file("resources/kaldi/src/fstext/push-special.cc")
// feat
// .file("resources/kaldi/src/feat/feature-fbank.cc")
// .file("resources/kaldi/src/feat/feature-functions.cc")
.file("resources/kaldi/src/feat/feature-mfcc.cc")
// .file("resources/kaldi/src/feat/feature-plp.cc")
// .file("resources/kaldi/src/feat/feature-spectrogram.cc")
// .file("resources/kaldi/src/feat/feature-window.cc")
// .file("resources/kaldi/src/feat/mel-computations.cc")
// .file("resources/kaldi/src/feat/online-feature.cc")
// .file("resources/kaldi/src/feat/pitch-functions.cc")
// .file("resources/kaldi/src/feat/resample.cc")
// .file("resources/kaldi/src/feat/signal.cc")
// .file("resources/kaldi/src/feat/wave-reader.cc")
// lm
// .file("resources/kaldi/src/lm/arpa-file-parser.cc")
// .file("resources/kaldi/src/lm/arpa-lm-compiler.cc")
.file("resources/kaldi/src/lm/const-arpa-lm.cc")
// .file("resources/kaldi/src/lm/kaldi-rnnlm.cc")
// .file("resources/kaldi/src/lm/mikolov-rnnlm-lib.cc")
// rnnlm
// .file("resources/kaldi/src/rnnlm/rnnlm-compute-state.cc")
// .file("resources/kaldi/src/rnnlm/rnnlm-core-compute.cc")
// .file("resources/kaldi/src/rnnlm/rnnlm-core-training.cc")
// .file("resources/kaldi/src/rnnlm/rnnlm-embedding-training.cc")
// .file("resources/kaldi/src/rnnlm/rnnlm-example-utils.cc")
// .file("resources/kaldi/src/rnnlm/rnnlm-example.cc")
// .file("resources/kaldi/src/rnnlm/rnnlm-lattice-rescoring.cc")
// .file("resources/kaldi/src/rnnlm/rnnlm-training.cc")
.file("resources/kaldi/src/rnnlm/rnnlm-utils.cc")
// .file("resources/kaldi/src/rnnlm/sampler.cc")
// .file("resources/kaldi/src/rnnlm/sampling-lm-estimate.cc")
// .file("resources/kaldi/src/rnnlm/sampling-lm.cc")
// decoder
// .file("resources/kaldi/src/decoder/decodable-matrix.cc")
// .file("resources/kaldi/src/decoder/decoder-wrappers.cc")
// .file("resources/kaldi/src/decoder/faster-decoder.cc")
// .file("resources/kaldi/src/decoder/grammar-fst.cc")
.file("resources/kaldi/src/decoder/lattice-faster-decoder.cc")
// .file("resources/kaldi/src/decoder/lattice-faster-online-decoder.cc")
// .file("resources/kaldi/src/decoder/lattice-incremental-decoder.cc")
// .file("resources/kaldi/src/decoder/lattice-incremental-online-decoder.cc")
// .file("resources/kaldi/src/decoder/lattice-simple-decoder.cc")
// .file("resources/kaldi/src/decoder/simple-decoder.cc")
// .file("resources/kaldi/src/decoder/training-graph-compiler.cc")
// nnet3
.file("resources/kaldi/src/nnet3/am-nnet-simple.cc")
// .file("resources/kaldi/src/nnet3/attention.cc")
// .file("resources/kaldi/src/nnet3/convolution.cc")
.file("resources/kaldi/src/nnet3/decodable-online-looped.cc")
.file("resources/kaldi/src/nnet3/decodable-simple-looped.cc")
// .file("resources/kaldi/src/nnet3/discriminative-supervision.cc")
// .file("resources/kaldi/src/nnet3/discriminative-training.cc")
// .file("resources/kaldi/src/nnet3/natural-gradient-online.cc")
.file("resources/kaldi/src/nnet3/nnet-am-decodable-simple.cc")
// .file("resources/kaldi/src/nnet3/nnet-analyze.cc")
// .file("resources/kaldi/src/nnet3/nnet-attention-component.cc")
// .file("resources/kaldi/src/nnet3/nnet-batch-compute.cc")
// .file("resources/kaldi/src/nnet3/nnet-chain-diagnostics.cc")
// .file("resources/kaldi/src/nnet3/nnet-chain-diagnostics2.cc")
// .file("resources/kaldi/src/nnet3/nnet-chain-example.cc")
// .file("resources/kaldi/src/nnet3/nnet-chain-training.cc")
// .file("resources/kaldi/src/nnet3/nnet-chain-training2.cc")
// .file("resources/kaldi/src/nnet3/nnet-combined-component.cc")
// .file("resources/kaldi/src/nnet3/nnet-common.cc")
// .file("resources/kaldi/src/nnet3/nnet-compile-looped.cc")
// .file("resources/kaldi/src/nnet3/nnet-compile-utils.cc")
// .file("resources/kaldi/src/nnet3/nnet-compile.cc")
// .file("resources/kaldi/src/nnet3/nnet-component-itf.cc")
// .file("resources/kaldi/src/nnet3/nnet-computation-graph.cc")
.file("resources/kaldi/src/nnet3/nnet-computation.cc")
.file("resources/kaldi/src/nnet3/nnet-compute.cc")
// .file("resources/kaldi/src/nnet3/nnet-convolutional-component.cc")
// .file("resources/kaldi/src/nnet3/nnet-descriptor.cc")
// .file("resources/kaldi/src/nnet3/nnet-diagnostics.cc")
// .file("resources/kaldi/src/nnet3/nnet-discriminative-diagnostics.cc")
// .file("resources/kaldi/src/nnet3/nnet-discriminative-example.cc")
// .file("resources/kaldi/src/nnet3/nnet-discriminative-training.cc")
// .file("resources/kaldi/src/nnet3/nnet-example-utils.cc")
// .file("resources/kaldi/src/nnet3/nnet-example.cc")
// .file("resources/kaldi/src/nnet3/nnet-general-component.cc")
// .file("resources/kaldi/src/nnet3/nnet-graph.cc")
.file("resources/kaldi/src/nnet3/nnet-nnet.cc")
// .file("resources/kaldi/src/nnet3/nnet-normalize-component.cc")
// .file("resources/kaldi/src/nnet3/nnet-optimize-utils.cc")
// .file("resources/kaldi/src/nnet3/nnet-optimize.cc")
.file("resources/kaldi/src/nnet3/nnet-parse.cc")
// .file("resources/kaldi/src/nnet3/nnet-simple-component.cc")
// .file("resources/kaldi/src/nnet3/nnet-tdnn-component.cc")
// .file("resources/kaldi/src/nnet3/nnet-training.cc")
.file("resources/kaldi/src/nnet3/nnet-utils.cc")
// lat
// .file("resources/kaldi/src/lat/compose-lattice-pruned.cc")
// .file("resources/kaldi/src/lat/confidence.cc")
// .file("resources/kaldi/src/lat/determinize-lattice-pruned.cc")
.file(&kaldi_lattice) // .file("resources/kaldi/src/lat/kaldi-lattice.cc")
.file("resources/kaldi/src/lat/lattice-functions.cc")
// .file("resources/kaldi/src/lat/minimize-lattice.cc")
// .file("resources/kaldi/src/lat/phone-align-lattice.cc")
// .file("resources/kaldi/src/lat/push-lattice.cc")
.file("resources/kaldi/src/lat/sausages.cc")
.file("resources/kaldi/src/lat/word-align-lattice-lexicon.cc")
.file("resources/kaldi/src/lat/word-align-lattice.cc")
// util
// .file("resources/kaldi/src/util/kaldi-holder.cc")
.file("resources/kaldi/src/util/kaldi-io.cc")
// .file("resources/kaldi/src/util/kaldi-semaphore.cc")
// .file("resources/kaldi/src/util/kaldi-table.cc")
// .file("resources/kaldi/src/util/kaldi-thread.cc")
.file("resources/kaldi/src/util/parse-options.cc")
.file("resources/kaldi/src/util/simple-io-funcs.cc")
// .file("resources/kaldi/src/util/simple-options.cc")
.file("resources/kaldi/src/util/text-utils.cc")
// online2
.file("resources/kaldi/src/online2/online-endpoint.cc")
.file("resources/kaldi/src/online2/online-feature-pipeline.cc")
// .file("resources/kaldi/src/online2/online-gmm-decodable.cc")
// .file("resources/kaldi/src/online2/online-gmm-decoding.cc")
// .file("resources/kaldi/src/online2/online-ivector-feature.cc")
// .file("resources/kaldi/src/online2/online-nnet2-decoding-threaded.cc")
// .file("resources/kaldi/src/online2/online-nnet2-decoding.cc")
// .file("resources/kaldi/src/online2/online-nnet2-feature-pipeline.cc")
.file("resources/kaldi/src/online2/online-nnet3-decoding.cc")
// .file("resources/kaldi/src/online2/online-nnet3-incremental-decoding.cc")
// .file("resources/kaldi/src/online2/online-nnet3-wake-word-faster-decoder.cc")
// .file("resources/kaldi/src/online2/online-speex-wrapper.cc")
.file("resources/kaldi/src/online2/online-timing.cc")
// .file("resources/kaldi/src/online2/onlinebin-util.cc")
.compile("libkaldi");
}

3
cbits/vosk.h Normal file
View File

@ -0,0 +1,3 @@
#include "model.h"
#include "kaldi_recognizer.h"
#include "spk_model.h"

66
examples/demo.rs Normal file
View File

@ -0,0 +1,66 @@
use vosk::VoskModel;
use audrey::read::Reader;
use audrey::sample::interpolate::{Converter, Linear, Sinc};
use audrey::sample::signal::{from_iter, Signal};
use std::fs::File;
const SAMPLE_RATE: u32 = 16000;
pub fn main() {
let audio_file_path = std::env::args().nth(1)
.expect("Please specify an audio file to run STT on");
let audio_file = File::open(audio_file_path).unwrap();
let mut reader = Reader::new(audio_file).unwrap();
let desc = reader.description();
assert_eq!(1, desc.channel_count(),
"The channel count is required to be one, at least for now");
let audio_buf :Vec<_> = if desc.sample_rate() == SAMPLE_RATE {
reader.samples().map(|s| s.unwrap()).collect()
} else {
// We need to interpolate to the target sample rate
let interpolator = Linear::new([0i16], [0]);
let conv = Converter::from_hz_to_hz(
from_iter(reader.samples::<i16>().map(|s| [s.unwrap()])),
interpolator,
desc.sample_rate() as f64,
SAMPLE_RATE as f64);
conv.until_exhausted().map(|v| v[0]).collect()
};
let model = VoskModel::new("./models/en-small");
let sess = model.create_session(Default::default());
// audio_buf
// FILE *wavin;
// char buf[3200];
// int nread, final;
// VoskModel *model = vosk_model_new("model");
// VoskRecognizer *recognizer = vosk_recognizer_new(model, 16000.0);
// wavin = fopen("test.wav", "rb");
// fseek(wavin, 44, SEEK_SET);
// while (!feof(wavin)) {
// nread = fread(buf, 1, sizeof(buf), wavin);
// final = vosk_recognizer_accept_waveform(recognizer, buf, nread);
// if (final) {
// printf("%s\n", vosk_recognizer_result(recognizer));
// } else {
// printf("%s\n", vosk_recognizer_partial_result(recognizer));
// }
// }
// printf("%s\n", vosk_recognizer_final_result(recognizer));
// vosk_recognizer_free(recognizer);
// vosk_model_free(model);
// return 0;
}

16
src/lib.rs Normal file
View File

@ -0,0 +1,16 @@
mod ffi {
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
#![allow(dead_code)]
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
mod model;
mod session;
mod speaker;
pub use model::VoskModel;
pub use session::{VoskSession, VoskSessionConfig, VoskSessionConfigBuilder};
pub use speaker::SpeakerModel;

56
src/model.rs Normal file
View File

@ -0,0 +1,56 @@
use std::path::Path;
use std::ffi::{CStr, CString};
use crate::ffi;
use crate::session::{VoskSession, VoskSessionConfig};
pub struct VoskModel {
pub(crate) inner: ffi::Model,
}
impl VoskModel {
pub fn new<P: AsRef<Path>>(root: P) -> Self {
let root = unsafe { CString::from_vec_unchecked(root.as_ref().to_string_lossy().as_bytes().to_vec()) };
Self {
inner: unsafe { ffi::Model::new(root.as_c_str().as_ptr()) },
}
}
#[inline]
pub fn create_session(&self, cfg: VoskSessionConfig) -> VoskSession {
VoskSession::new(&self.inner, cfg)
}
#[inline]
pub fn feed(&self, sess: &mut VoskSession, data: &[i16]) -> bool {
unsafe { ffi::KaldiRecognizer_AcceptWaveform1(&mut sess.inner, data.as_ptr(), data.len() as _) }
}
#[inline]
pub fn get_result(&self, sess: &mut VoskSession) -> String {
let cstr = unsafe { CStr::from_ptr(ffi::KaldiRecognizer_Result(&mut sess.inner)) };
cstr.to_string_lossy().to_string()
}
#[inline]
pub fn get_partial_result(&self, sess: &mut VoskSession) -> String {
let cstr = unsafe { CStr::from_ptr(ffi::KaldiRecognizer_PartialResult(&mut sess.inner)) };
cstr.to_string_lossy().to_string()
}
#[inline]
pub fn get_final_result(&self, sess: &mut VoskSession) -> String {
let cstr = unsafe { CStr::from_ptr(ffi::KaldiRecognizer_FinalResult(&mut sess.inner)) };
cstr.to_string_lossy().to_string()
}
}
impl Drop for VoskModel {
fn drop(&mut self) {
unsafe { self.inner.destruct() }
}
}

102
src/session.rs Normal file
View File

@ -0,0 +1,102 @@
use crate::ffi;
use std::ffi::CString;
use std::path::{PathBuf};
pub struct VoskSessionConfigBuilder {
spk_root: Option<CString>,
grammar: Option<CString>,
freq: f32,
}
impl VoskSessionConfigBuilder {
fn new() -> Self {
VoskSessionConfigBuilder {
spk_root: None,
grammar: None,
freq: 16000.0,
}
}
pub fn spk_root<P: Into<PathBuf>>(&mut self, root: P) -> &mut Self {
self.spk_root = Some(unsafe { CString::from_vec_unchecked(root.into().to_string_lossy().as_bytes().to_vec()) });
self
}
pub fn sampling_freq(&mut self, freq: f32) -> &mut Self {
self.freq = freq;
self
}
pub fn grammar<G: AsRef<str>>(&mut self, grammar: G) -> &mut Self {
self.grammar = Some(CString::new(grammar.as_ref()).unwrap());
self
}
pub fn finish(&mut self) -> VoskSessionConfig {
VoskSessionConfig {
spk_root: core::mem::take(&mut self.spk_root),
grammar: core::mem::take(&mut self.grammar),
freq: self.freq,
}
}
}
pub struct VoskSessionConfig {
spk_root: Option<CString>,
grammar: Option<CString>,
freq: f32,
}
impl Default for VoskSessionConfig {
fn default() -> Self {
Self {
spk_root: None,
grammar: None,
freq: 16000.0
}
}
}
impl VoskSessionConfig {
pub fn builder() -> VoskSessionConfigBuilder {
VoskSessionConfigBuilder::new()
}
#[inline]
pub fn set_spk_root(&mut self, root: CString) {
self.spk_root = Some(root);
}
#[inline]
pub fn set_grammar(&mut self, grammar: CString) {
self.grammar = Some(grammar);
}
#[inline]
pub fn set_freq(&mut self, freq: f32) {
self.freq = freq;
}
}
pub struct VoskSession {
pub(crate) inner: ffi::KaldiRecognizer
}
impl VoskSession {
pub(crate) fn new(model: *const ffi::Model, cfg: VoskSessionConfig) -> Self {
if let Some(_cfg) = &cfg.spk_root {
unimplemented!()
// VoskSession {
// inner: ffi::KaldiRecognizer::new1(model as *mut ffi::Model, cfg.freq)
// }
} else if let Some(grammar) = &cfg.grammar {
VoskSession {
inner: unsafe { ffi::KaldiRecognizer::new2(model as *mut ffi::Model, cfg.freq, grammar.as_c_str().as_ptr()) }
}
} else {
VoskSession {
inner: unsafe { ffi::KaldiRecognizer::new(model as *mut ffi::Model, cfg.freq) }
}
}
}
}

23
src/speaker.rs Normal file
View File

@ -0,0 +1,23 @@
use crate::ffi;
use std::path::Path;
use std::ffi::CString;
pub struct SpeakerModel {
pub(crate) inner: ffi::SpkModel
}
impl SpeakerModel {
pub fn new(root: &Path) -> Self {
let root = unsafe { CString::from_vec_unchecked(root.to_string_lossy().as_bytes().to_vec()) };
Self {
inner: unsafe { ffi::SpkModel::new(root.as_c_str().as_ptr()) }
}
}
}
impl Drop for SpeakerModel {
fn drop(&mut self) {
unsafe { self.inner.destruct() }
}
}