// onnxruntime-rs/build.rs
//
// 73 lines, 2.1 KiB, Rust (captured from a repository "Raw / Normal View / History" page).
//
// blame: 2020-05-13 22:47:26 +04:00
use bindgen::callbacks::{EnumVariantValue, ParseCallbacks};
use heck::CamelCase;
// blame: 2020-05-09 20:03:04 +04:00
use std::env;
use std::path::PathBuf;
fn main() {
2020-05-10 22:38:23 +04:00
println!("cargo:rerun-if-env-changed=ONNXRUNTIME_LIB_DIR");
println!("cargo:rerun-if-env-changed=ONNXRUNTIME_INCLUDE_DIR");
match std::env::var("ONNXRUNTIME_LIB_DIR") {
Ok(path) => println!("cargo:rustc-link-search={}", path),
Err(_) => (),
};
let mut clang_args = String::new();
match std::env::var("ONNXRUNTIME_INCLUDE_DIR") {
Ok(path) => {
clang_args = format!("-I{}", path);
2020-05-13 00:28:27 +04:00
}
2020-05-10 22:38:23 +04:00
Err(_) => (),
};
2020-05-10 15:08:30 +04:00
println!("cargo:rustc-link-lib=onnxruntime");
2020-05-09 20:03:04 +04:00
let bindings = bindgen::Builder::default()
2020-05-10 15:08:30 +04:00
.header("cbits/ort.h")
2020-05-10 22:38:23 +04:00
.clang_arg(clang_args)
2020-05-09 20:03:04 +04:00
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
2020-05-10 15:20:23 +04:00
.whitelist_function("OrtGetApiBase")
.whitelist_var("ORT_.*")
.whitelist_recursively(true)
2020-05-13 22:47:26 +04:00
.blacklist_type("__int64_t")
.blacklist_type("__uint32_t")
.rustified_non_exhaustive_enum("*")
.parse_callbacks(Box::new(CustomEnums))
2020-05-10 15:08:30 +04:00
.layout_tests(false)
2020-05-09 20:03:04 +04:00
.generate()
.expect("Unable to generate bindings");
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
}
2020-05-13 22:47:26 +04:00
/// bindgen callbacks that give generated enum variants idiomatic Rust names.
///
/// Each variant name is converted to CamelCase. Then the CamelCased enum name, followed by
/// a leading `Ort`, is stripped as a prefix. For example, a variant
/// `ORT_LOGGING_LEVEL_VERBOSE` of enum `OrtLoggingLevel` would become `Verbose`.
#[derive(Debug)]
struct CustomEnums;

impl CustomEnums {
    /// Removes `prefix` from the start of `name`. The name is returned unchanged when it
    /// does not start with `prefix`, or when stripping would leave an empty string or one
    /// beginning with a digit (neither is a valid Rust identifier).
    fn strip_identifier_prefix(name: String, prefix: &str) -> String {
        if name.starts_with(prefix) {
            let rest = &name[prefix.len()..];
            let is_valid_start = rest.chars().next().map_or(false, |c| !c.is_ascii_digit());
            if is_valid_start {
                return rest.to_string();
            }
        }
        name
    }
}

impl ParseCallbacks for CustomEnums {
    fn enum_variant_name(
        &self,
        enum_name: Option<&str>,
        variant_name: &str,
        _variant_value: EnumVariantValue,
    ) -> Option<String> {
        let mut variant_name = variant_name.to_camel_case();
        if let Some(enum_name) = enum_name {
            // The reported enum name may carry a leading `enum ` keyword; drop it before
            // CamelCasing so it can be compared against the variant's prefix.
            let enum_name = if enum_name.starts_with("enum ") {
                &enum_name["enum ".len()..]
            } else {
                enum_name
            };
            // Strip only the prefix; `String::replace` would also remove later occurrences.
            variant_name = Self::strip_identifier_prefix(variant_name, &enum_name.to_camel_case());
        }
        Some(Self::strip_identifier_prefix(variant_name, "Ort"))
    }

    fn include_file(&self, filename: &str) {
        // Same directive `bindgen::CargoCallbacks` emits. NOTE(review): needed when this
        // callback replaces CargoCallbacks (single-callback bindgen versions); a duplicate
        // directive in newer versions is harmless.
        println!("cargo:rerun-if-changed={}", filename);
    }
}