Add run_mut
This commit is contained in:
parent
a80e4536a2
commit
bd968a44f5
76
src/lib.rs
76
src/lib.rs
@ -278,18 +278,49 @@ impl Session {
|
||||
Ok(OrtString { raw })
|
||||
}
|
||||
|
||||
// pub fn run(
|
||||
// &self,
|
||||
// options: &RunOptions,
|
||||
// input_names: InputNames,
|
||||
// inputs: Inputs,
|
||||
// output_names: OutputNames,
|
||||
// ) -> Result<Vec<Value>>
|
||||
// where InputNames: Iterator<Item=&str>,
|
||||
// Inputs: Iterator<Item=&Value>,
|
||||
// OutputNames: Iterator<Item=&str>,
|
||||
// {
|
||||
// }
|
||||
pub fn run_mut(
|
||||
&self,
|
||||
options: &RunOptions,
|
||||
input_names: &[&str],
|
||||
inputs: &[&Value],
|
||||
output_names: &[&str],
|
||||
outputs: &[&mut Value],
|
||||
) -> Result<()> {
|
||||
assert_eq!(input_names.len(), inputs.len());
|
||||
assert_eq!(output_names.len(), outputs.len());
|
||||
|
||||
let input_names = input_names
|
||||
.iter()
|
||||
.map(|n| CString::new(*n))
|
||||
.collect::<std::result::Result<Vec<_>, ffi::NulError>>()?;
|
||||
let input_names_ptrs = input_names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
|
||||
let output_names = output_names
|
||||
.iter()
|
||||
.map(|n| CString::new(*n))
|
||||
.collect::<std::result::Result<Vec<_>, ffi::NulError>>()?;
|
||||
let output_names_ptrs = output_names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
|
||||
let inputs = inputs
|
||||
.iter()
|
||||
.map(|v| v.raw as *const sys::Value)
|
||||
.collect::<Vec<_>>();
|
||||
let mut outputs = outputs.iter().map(|v| v.raw).collect::<Vec<_>>();
|
||||
let output_size = output_names.len() as u64;
|
||||
unsafe {
|
||||
checked_call!(
|
||||
Run,
|
||||
self.raw,
|
||||
options.raw,
|
||||
input_names_ptrs.as_ptr(),
|
||||
inputs.as_ptr(),
|
||||
inputs.len() as u64,
|
||||
output_names_ptrs.as_ptr(),
|
||||
output_size,
|
||||
outputs.as_mut_ptr()
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_raw(
|
||||
&self,
|
||||
@ -426,6 +457,7 @@ mod tests {
|
||||
|
||||
let input_tensor = Tensor::new(mem_info, input_shape, input_data)?;
|
||||
|
||||
// immutable version
|
||||
let output = session.run_raw(&ro, &input_names, &[input_tensor.value()], &output_names)?;
|
||||
|
||||
let output_value = output.into_iter().next().unwrap();
|
||||
@ -436,6 +468,26 @@ mod tests {
|
||||
&[1. + 2. * 2., 3. + 2. * 4., 5. + 2. * 6.]
|
||||
);
|
||||
|
||||
// mutable version
|
||||
let mem_info = MemoryInfo::cpu_memory_info(AllocatorType::ArenaAllocator, MemType::Cpu)?;
|
||||
let output_data: Vec<f32> = vec![0.0; 3];
|
||||
let output_shape = vec![3, 1];
|
||||
|
||||
let mut output_tensor = Tensor::new(mem_info, output_shape, output_data)?;
|
||||
|
||||
session.run_mut(
|
||||
&ro,
|
||||
&input_names,
|
||||
&[input_tensor.value()],
|
||||
&output_names,
|
||||
&[output_tensor.value_mut()],
|
||||
)?;
|
||||
|
||||
assert_eq!(
|
||||
&output_tensor[..],
|
||||
&[1. + 2. * 2., 3. + 2. * 4., 5. + 2. * 6.]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user