From df70bd17b596ee5b9143723af0f599a9e4848b3e Mon Sep 17 00:00:00 2001 From: Luca Barbato Date: Mon, 2 Jan 2017 02:43:49 +0000 Subject: [PATCH] cuda: Wrap the Device Management --- src/cuda/device.rs | 80 ++++++++++++++++++++++++++++++++++++++++++++++ src/cuda/mod.rs | 1 + src/lib.rs | 2 +- 3 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 src/cuda/device.rs create mode 100644 src/cuda/mod.rs diff --git a/src/cuda/device.rs b/src/cuda/device.rs new file mode 100644 index 0000000..41b8b3e --- /dev/null +++ b/src/cuda/device.rs @@ -0,0 +1,80 @@ +use std::os::raw::c_int; +use std::os::raw::c_char; + +use ffi::cuda::*; +use ffi::cuda::cudaError_enum::*; + +struct CuDevice { + device: CUdevice, +} + +macro_rules! wrap { + ($val:ident, $res:ident) => ( + if $res == CUDA_SUCCESS { + Ok($val) + } else { + Err($res) + } + ) +} + +impl CuDevice { + pub fn new(ordinal: c_int) -> Result { + let mut d = CuDevice { device: 0 }; + let res = unsafe { cuDeviceGet(&mut d.device as *mut i32, ordinal) }; + + wrap!(d, res) + } + + pub fn get_attribute(&self, attr: CUdevice_attribute) -> Result { + let mut pi = 0; + let res = unsafe { cuDeviceGetAttribute(&mut pi as *mut i32, attr, self.device) }; + + wrap!(pi, res) + } + + pub fn get_name(&self) -> Result { + let mut name = vec![0; 256]; + let res = unsafe { + cuDeviceGetName(name.as_mut_ptr() as *mut c_char, + name.len() as i32, + self.device) + }; + let val = String::from_utf8(name).unwrap(); + + wrap!(val, res) + } + + pub fn get_total_mem(&self) -> Result { + let mut val = 0; + let res = unsafe { cuDeviceTotalMem_v2(&mut val as *mut usize, self.device) }; + + wrap!(val, res) + } +} + +pub fn get_count() -> Result { + let mut val = 0; + let res = unsafe { cuDeviceGetCount(&mut val as *mut i32) }; + + wrap!(val, res) +} + +#[cfg(test)] +mod tests { + use ffi::cuda::cuInit; + use super::*; + + #[test] + fn device_enum() { + let _ = unsafe { cuInit(0) }; + + for i in 0..get_count().unwrap() { + let dev = CuDevice::new(i).unwrap(); + + println!("{} {}", + dev.get_name().unwrap(), + dev.get_total_mem().unwrap()); + } + } +} diff --git a/src/cuda/mod.rs b/src/cuda/mod.rs new file mode 100644 index 0000000..5458924 --- /dev/null +++ b/src/cuda/mod.rs @@ -0,0 +1 @@ +pub mod device; diff --git a/src/lib.rs b/src/lib.rs index 226b3ed..7856669 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ mod ffi; -// pub mod cuda; +pub mod cuda; #[cfg(test)] mod tests {