Add run_mut

This commit is contained in:
Christopher Chalmers 2020-05-19 10:28:52 +01:00
parent a80e4536a2
commit bd968a44f5

View File

@ -278,18 +278,49 @@ impl Session {
Ok(OrtString { raw })
}
// pub fn run(
// &self,
// options: &RunOptions,
// input_names: InputNames,
// inputs: Inputs,
// output_names: OutputNames,
// ) -> Result<Vec<Value>>
// where InputNames: Iterator<Item=&str>,
// Inputs: Iterator<Item=&Value>,
// OutputNames: Iterator<Item=&str>,
// {
// }
pub fn run_mut(
&self,
options: &RunOptions,
input_names: &[&str],
inputs: &[&Value],
output_names: &[&str],
outputs: &[&mut Value],
) -> Result<()> {
assert_eq!(input_names.len(), inputs.len());
assert_eq!(output_names.len(), outputs.len());
let input_names = input_names
.iter()
.map(|n| CString::new(*n))
.collect::<std::result::Result<Vec<_>, ffi::NulError>>()?;
let input_names_ptrs = input_names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
let output_names = output_names
.iter()
.map(|n| CString::new(*n))
.collect::<std::result::Result<Vec<_>, ffi::NulError>>()?;
let output_names_ptrs = output_names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
let inputs = inputs
.iter()
.map(|v| v.raw as *const sys::Value)
.collect::<Vec<_>>();
let mut outputs = outputs.iter().map(|v| v.raw).collect::<Vec<_>>();
let output_size = output_names.len() as u64;
unsafe {
checked_call!(
Run,
self.raw,
options.raw,
input_names_ptrs.as_ptr(),
inputs.as_ptr(),
inputs.len() as u64,
output_names_ptrs.as_ptr(),
output_size,
outputs.as_mut_ptr()
)?;
Ok(())
}
}
pub fn run_raw(
&self,
@ -426,6 +457,7 @@ mod tests {
let input_tensor = Tensor::new(mem_info, input_shape, input_data)?;
// immutable version
let output = session.run_raw(&ro, &input_names, &[input_tensor.value()], &output_names)?;
let output_value = output.into_iter().next().unwrap();
@ -436,6 +468,26 @@ mod tests {
&[1. + 2. * 2., 3. + 2. * 4., 5. + 2. * 6.]
);
// mutable version
let mem_info = MemoryInfo::cpu_memory_info(AllocatorType::ArenaAllocator, MemType::Cpu)?;
let output_data: Vec<f32> = vec![0.0; 3];
let output_shape = vec![3, 1];
let mut output_tensor = Tensor::new(mem_info, output_shape, output_data)?;
session.run_mut(
&ro,
&input_names,
&[input_tensor.value()],
&output_names,
&[output_tensor.value_mut()],
)?;
assert_eq!(
&output_tensor[..],
&[1. + 2. * 2., 3. + 2. * 4., 5. + 2. * 6.]
);
Ok(())
}
}