Fix handling of status pointer

This commit is contained in:
Basile Henry 2020-05-14 11:14:15 +02:00
parent 4a41e3059d
commit 7e9c5ae322

View File

@ -1,4 +1,3 @@
use std::convert::TryFrom;
use std::ffi; use std::ffi;
use std::os::raw::c_char; use std::os::raw::c_char;
use std::ptr; use std::ptr;
@ -28,9 +27,9 @@ macro_rules! call {
macro_rules! checked_call { macro_rules! checked_call {
($name:ident, $($arg:expr),*) => {{ ($name:ident, $($arg:expr),*) => {{
let status = call!($name, $($arg),*); let status = call!($name, $($arg),*);
match OrtError::try_from(Status::from(status)) { match Status::new(status) {
Ok(err) => Err(Error::OrtError(err)), Some(status) => Err(Error::OrtError(status)),
Err(()) => Ok(()), None => Ok(()),
} }
}} }}
} }
@ -52,7 +51,6 @@ macro_rules! ort_type {
ort_type!(Session, ReleaseSession); ort_type!(Session, ReleaseSession);
ort_type!(SessionOptions, ReleaseSessionOptions); ort_type!(SessionOptions, ReleaseSessionOptions);
ort_type!(Env, ReleaseEnv); ort_type!(Env, ReleaseEnv);
ort_type!(Status, ReleaseStatus);
ort_type!(MemoryInfo, ReleaseMemoryInfo); ort_type!(MemoryInfo, ReleaseMemoryInfo);
ort_type!(Value, ReleaseValue); ort_type!(Value, ReleaseValue);
ort_type!(RunOptions, ReleaseRunOptions); ort_type!(RunOptions, ReleaseRunOptions);
@ -64,14 +62,14 @@ ort_type!(SequenceTypeInfo, ReleaseSequenceTypeInfo);
ort_type!(ModelMetadata, ReleaseModelMetadata); ort_type!(ModelMetadata, ReleaseModelMetadata);
ort_type!(ThreadingOptions, ReleaseThreadingOptions); ort_type!(ThreadingOptions, ReleaseThreadingOptions);
pub struct OrtError { pub struct Status {
pub error_code: ErrorCode, pub error_code: ErrorCode,
pub error_msg: String, pub error_msg: String,
} }
pub enum Error { pub enum Error {
NulStringError(ffi::NulError), NulStringError(ffi::NulError),
OrtError(OrtError), OrtError(Status),
} }
type Result<T> = std::result::Result<T, Error>; type Result<T> = std::result::Result<T, Error>;
@ -82,27 +80,21 @@ fn to_c_string(s: &str) -> Result<*const c_char> {
.as_ptr()) .as_ptr())
} }
impl From<*mut sys::Status> for Status { impl Status {
fn from(raw: *mut sys::Status) -> Status { fn new(raw: *mut sys::Status) -> Option<Status> {
Status { raw }
}
}
impl TryFrom<Status> for OrtError {
type Error = ();
fn try_from(status: Status) -> std::result::Result<OrtError, ()> {
unsafe { unsafe {
if status.raw.is_null() { if raw.is_null() {
return Err(()); return None;
} }
let error_code = call!(GetErrorCode, status.raw); let error_code = call!(GetErrorCode, raw);
let error_msg = ffi::CStr::from_ptr(call!(GetErrorMessage, status.raw)) let error_msg = ffi::CStr::from_ptr(call!(GetErrorMessage, raw))
.to_string_lossy() .to_string_lossy()
.into_owned(); .into_owned();
Ok(OrtError { call!(ReleaseStatus, raw);
Some(Status {
error_code, error_code,
error_msg, error_msg,
}) })