Fix handling of status pointer

This commit is contained in:
Basile Henry 2020-05-14 11:14:15 +02:00
parent 4a41e3059d
commit 7e9c5ae322

View File

@ -1,4 +1,3 @@
use std::convert::TryFrom;
use std::ffi;
use std::os::raw::c_char;
use std::ptr;
@ -28,9 +27,9 @@ macro_rules! call {
macro_rules! checked_call {
($name:ident, $($arg:expr),*) => {{
let status = call!($name, $($arg),*);
match OrtError::try_from(Status::from(status)) {
Ok(err) => Err(Error::OrtError(err)),
Err(()) => Ok(()),
match Status::new(status) {
Some(status) => Err(Error::OrtError(status)),
None => Ok(()),
}
}}
}
@ -52,7 +51,6 @@ macro_rules! ort_type {
ort_type!(Session, ReleaseSession);
ort_type!(SessionOptions, ReleaseSessionOptions);
ort_type!(Env, ReleaseEnv);
ort_type!(Status, ReleaseStatus);
ort_type!(MemoryInfo, ReleaseMemoryInfo);
ort_type!(Value, ReleaseValue);
ort_type!(RunOptions, ReleaseRunOptions);
@ -64,14 +62,14 @@ ort_type!(SequenceTypeInfo, ReleaseSequenceTypeInfo);
ort_type!(ModelMetadata, ReleaseModelMetadata);
ort_type!(ThreadingOptions, ReleaseThreadingOptions);
pub struct OrtError {
pub struct Status {
pub error_code: ErrorCode,
pub error_msg: String,
}
pub enum Error {
NulStringError(ffi::NulError),
OrtError(OrtError),
OrtError(Status),
}
type Result<T> = std::result::Result<T, Error>;
@ -82,27 +80,21 @@ fn to_c_string(s: &str) -> Result<*const c_char> {
.as_ptr())
}
impl From<*mut sys::Status> for Status {
fn from(raw: *mut sys::Status) -> Status {
Status { raw }
}
}
impl TryFrom<Status> for OrtError {
type Error = ();
fn try_from(status: Status) -> std::result::Result<OrtError, ()> {
impl Status {
fn new(raw: *mut sys::Status) -> Option<Status> {
unsafe {
if status.raw.is_null() {
return Err(());
if raw.is_null() {
return None;
}
let error_code = call!(GetErrorCode, status.raw);
let error_msg = ffi::CStr::from_ptr(call!(GetErrorMessage, status.raw))
let error_code = call!(GetErrorCode, raw);
let error_msg = ffi::CStr::from_ptr(call!(GetErrorMessage, raw))
.to_string_lossy()
.into_owned();
Ok(OrtError {
call!(ReleaseStatus, raw);
Some(Status {
error_code,
error_msg,
})