Add available_providers method

This commit is contained in:
Andrey Tkachenko 2021-01-22 14:08:14 +04:00
parent 8041b81588
commit 4f6be5efb9
2 changed files with 20 additions and 0 deletions

View File

@ -25,6 +25,8 @@ fn main() -> Result<()> {
let so = SessionOptions::new()?; let so = SessionOptions::new()?;
println!("Available Providers: {:?}", so.available_providers());
for path in &opt.onnx { for path in &opt.onnx {
println!("model {:?}", path); println!("model {:?}", path);
let session = match Session::new(&env, path, &so) { let session = match Session::new(&env, path, &so) {

View File

@ -161,6 +161,24 @@ impl SessionOptions {
Ok(SessionOptions { raw }) Ok(SessionOptions { raw })
} }
pub fn available_providers(&self) -> Vec<String> {
let mut providers_array: *mut *mut i8 = std::ptr::null_mut();
let mut providers_len: i32 = 0;
call!(@unsafe @expect GetAvailableProviders, &mut providers_array, &mut providers_len);
let slice = unsafe { std::slice::from_raw_parts(providers_array, providers_len as usize) };
let mut res = Vec::new();
for i in slice {
res.push(unsafe {CStr::from_ptr(*i)}.to_string_lossy().to_string());
}
call!(@unsafe @expect ReleaseAvailableProviders, providers_array, providers_len);
res
}
pub fn add_cpu(&self, use_arena: bool) { pub fn add_cpu(&self, use_arena: bool) {
let so = self.raw; let so = self.raw;
let status = unsafe { let status = unsafe {