Skip to content

[FEAT][STUBGEN] Add Rust code generation backend#609

Open
Seven-Streams wants to merge 3 commits into
apache:main from
Seven-Streams:main-dev/2026-06-04/rust_stubgen
Open

[FEAT][STUBGEN] Add Rust code generation backend#609
Seven-Streams wants to merge 3 commits into
apache:main from
Seven-Streams:main-dev/2026-06-04/rust_stubgen

Conversation

@Seven-Streams

Copy link
Copy Markdown

This PR is based on #608.

Summary

This PR adds Rust support to the stub generator.

Rust bindings can now be generated through both CMake and the CLI:

  • In CMake, set STUB_TARGET=rust.
  • In the CLI, use tvm-ffi-stubgen <generated_dir> --target rust ....

This PR also adds documentation, examples, and tests for Rust stub generation.

Key Changes

  • Added a Rust backend in python/tvm_ffi/stubgen/rust_generator.
  • Added unit tests and string-level tests for Rust stub generation in test_stubgen.py.
  • Added documentation in docs/packaging/stubgen.rst.
  • Added Rust stubgen examples in examples/rust_stubgen.
  • Extended the CMake integration to support Rust stub generation.
  • Updated the Rust package to support generated bindings.

Example

Given the following C++ definition:

/*!
 * \brief Data object: a pair of 64-bit integers `a` and `b`.
 *
 * Declared under the type key "rust_stubgen.IntPair".
 */
class IntPairObj : public ffi::Object {
 public:
  /*! \brief First component. Zero unless a constructor or writer sets it. */
  int64_t a{0};
  /*! \brief Second component. Zero unless a constructor or writer sets it. */
  int64_t b{0};

  /*!
   * \brief Construct a pair whose components are both zero.
   *
   * The default member initializers above guarantee this even when the object
   * is default-initialized (e.g. `IntPairObj x;`), where plain `int64_t`
   * members would otherwise hold indeterminate values.
   */
  IntPairObj() = default;
  /*! \brief Construct a pair from explicit components. */
  explicit IntPairObj(int64_t a, int64_t b) : a(a), b(b) {}

  /*! \brief Sum of the two components. */
  int64_t sum() const { return a + b; }

  // All fields are writable, so the generated Rust wrapper gets `DerefMut`.
  static constexpr bool _type_mutable = true;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("rust_stubgen.IntPair", IntPairObj, ffi::Object);
};

/*!
 * \brief Reference wrapper for `IntPairObj`.
 *
 * The only constructor declared here allocates a fresh `IntPairObj`; the
 * remaining members come from the macro at the end of the class.
 */
class IntPair : public ffi::ObjectRef {
 public:
  /*! \brief Allocate an `IntPairObj` holding `(a, b)` and store it in `data_`. */
  explicit IntPair(int64_t a, int64_t b) { data_ = ffi::make_object<IntPairObj>(a, b); }

  // Supplies the ObjectRef boilerplate for this wrapper. The macro name indicates the
  // non-nullable flavor; its expansion is not visible here.
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IntPair, ffi::ObjectRef, IntPairObj);
};

// Registers the reflection metadata for `IntPairObj`: its constructor, its two
// read-write fields, its `sum` method, and a conversion type attribute.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;

  // One builder chain per object type. The trailing string on each `def*` call is
  // presumably the description attached to that entry -- confirm against the
  // reflection API.
  refl::ObjectDef<IntPairObj>()
      .def(refl::init<int64_t, int64_t>())
      .def_rw("a", &IntPairObj::a, "first component")
      .def_rw("b", &IntPairObj::b, "second component")
      .def("sum", &IntPairObj::sum, "a + b");

  // Lets an `AnyView` holding this object convert back into the `IntPair` ref,
  // which is what the generated Rust bindings rely on for object returns.
  refl::TypeAttrDef<IntPairObj>().def(refl::type_attr::kConvert,
                                      &refl::details::FFIConvertFromAnyViewToObjectRef<IntPair>);
}

The Rust stub generator produces Rust wrappers for the reflected objects and methods. For example, the generated bindings include:

// tvm-ffi-stubgen(begin): helpers
/// Resolve the runtime type index for `type_key`, memoizing the answer.
///
/// Resolved indices are kept in a function-local static map guarded by a mutex,
/// so `TVMFFITypeKeyToIndex` is only consulted when the key is not cached yet.
///
/// # Panics
///
/// Panics when `TVMFFITypeKeyToIndex` returns a non-zero status, or when the
/// cache mutex is poisoned.
fn lookup_type_index(type_key: &'static str) -> i32 {
    use std::collections::HashMap;
    use std::sync::{Mutex, OnceLock};

    static CACHE: OnceLock<Mutex<HashMap<&'static str, i32>>> = OnceLock::new();
    let table = CACHE.get_or_init(|| Mutex::new(HashMap::new()));

    // Fast path: the guard is released at the end of this statement, so the
    // FFI call below never runs while the lock is held.
    let cached = table.lock().unwrap().get(type_key).copied();
    if let Some(index) = cached {
        return index;
    }

    // SAFETY (assumed -- TODO confirm against `TVMFFIByteArray::from_str`): the
    // byte array refers to `type_key`, which is `'static` and outlives the call.
    let key_bytes = unsafe { tvm_ffi::tvm_ffi_sys::TVMFFIByteArray::from_str(type_key) };
    let mut index = 0;
    // SAFETY (assumed -- TODO confirm): both pointers refer to live locals for
    // the duration of the call.
    let status = unsafe { tvm_ffi::tvm_ffi_sys::TVMFFITypeKeyToIndex(&key_bytes, &mut index) };
    assert_eq!(status, 0, "type key `{type_key}` is not registered");

    table.lock().unwrap().insert(type_key, index);
    index
}

fn get_type_method(
    type_key: &'static str,
    method_name: &str,
) -> tvm_ffi::Result<tvm_ffi::Function> {
    let type_index = lookup_type_index(type_key);
    unsafe {
        let info = tvm_ffi::tvm_ffi_sys::TVMFFIGetTypeInfo(type_index);
        if info.is_null() {
            return Err(tvm_ffi::Error::new(
                tvm_ffi::TYPE_ERROR,
                &format!("no type info for `{type_key}`"),
                "",
            ));
        }
        let info = &*info;
        for i in 0..info.num_methods {
            let mi = &*info.methods.add(i as usize);
            if mi.name.as_str() == method_name {
                if !<tvm_ffi::Function as tvm_ffi::type_traits::AnyCompatible>::check_any_strict(
                    &mi.method,
                ) {
                    return Err(tvm_ffi::Error::new(
                        tvm_ffi::TYPE_ERROR,
                        &format!("method `{method_name}` on `{type_key}` is not a Function"),
                        "",
                    ));
                }
                return Ok(<tvm_ffi::Function as tvm_ffi::type_traits::AnyCompatible>::copy_from_any_view_after_check(&mi.method));
            }
        }
    }
    Err(tvm_ffi::Error::new(
        tvm_ffi::TYPE_ERROR,
        &format!("method `{method_name}` not found on `{type_key}`"),
        "",
    ))
}
// tvm-ffi-stubgen(end)

// tvm-ffi-stubgen(begin): object/rust_stubgen.IntPair
/// Rust-side declaration of the object registered as `rust_stubgen.IntPair`.
///
/// `#[repr(C)]` fixes the field order as written: the base `Object` first,
/// followed by `a` and `b`. This presumably has to match the layout of the C++
/// `IntPairObj`, so the fields must not be reordered.
#[repr(C)]
pub struct IntPairObj {
    /// Embedded base object; private to this module.
    base: Object,
    /// First component.
    pub a: i64,
    /// Second component.
    pub b: i64,
}

/// Implements `ObjectCore` for `IntPairObj`.
///
/// NOTE(review): this is an `unsafe impl`, and the invariants `ObjectCore`
/// requires are defined outside this view. Presumably they include `base`
/// being the leading field of a `#[repr(C)]` struct -- confirm.
unsafe impl ObjectCore for IntPairObj {
    /// Type key used to resolve this object's runtime type index.
    const TYPE_KEY: &'static str = "rust_stubgen.IntPair";

    /// Runtime type index for `Self::TYPE_KEY`, resolved through `lookup_type_index`.
    fn type_index() -> i32 {
        lookup_type_index(Self::TYPE_KEY)
    }

    /// Mutable access to the object header, delegated to the embedded `base` object.
    unsafe fn object_header_mut(this: &mut Self) -> &mut TVMFFIObject {
        Object::object_header_mut(&mut this.base)
    }
}

/// Handle type for [`IntPairObj`], wrapping a single `ObjectArc<IntPairObj>`.
///
/// `Clone` and `DeriveObjectRef` are derived; what `DeriveObjectRef` generates
/// is not visible here.
#[repr(C)]
#[derive(DeriveObjectRef, Clone)]
pub struct IntPair {
    /// Pointer to the underlying object.
    /// NOTE(review): presumably reference-counted, so clones share one object --
    /// confirm against `ObjectArc`.
    data: ObjectArc<IntPairObj>,
}

/// Gives shared access to the fields of the wrapped [`IntPairObj`].
impl Deref for IntPair {
    type Target = IntPairObj;

    /// Borrow the object held in `data`.
    fn deref(&self) -> &Self::Target {
        let object: &IntPairObj = &self.data;
        object
    }
}

/// Gives mutable access to the fields of the wrapped [`IntPairObj`].
///
/// NOTE(review): `IntPair` derives `Clone`. If clones share the same underlying
/// object, two handles could each hand out a `&mut IntPairObj` to it -- confirm
/// what `ObjectArc`'s `DerefMut` guarantees.
impl DerefMut for IntPair {
    /// Mutably borrow the object held in `data`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let object: &mut IntPairObj = &mut self.data;
        object
    }
}

impl IntPair {
    /// Create a new `IntPair` through the reflected `__ffi_init__` constructor.
    ///
    /// `_0` and `_1` are passed positionally, in that order, as the packed
    /// call arguments.
    ///
    /// # Errors
    ///
    /// Fails when the constructor cannot be looked up, when the packed call
    /// fails, or when the returned value does not convert into `IntPair`.
    pub fn new(_0: i64, _1: i64) -> Result<Self> {
        let ctor = get_type_method(IntPairObj::TYPE_KEY, "__ffi_init__")?;
        Ok(ctor.call_packed(&[AnyView::from(&_0), AnyView::from(&_1)])?.try_into()?)
    }

    /// Call the reflected `sum` method on this object.
    ///
    /// Takes `&self`: the call only passes a shared view of the handle, so an
    /// exclusive borrow is not needed and the method also works on immutable
    /// bindings and shared references.
    ///
    /// # Errors
    ///
    /// Fails when the method cannot be looked up, when the packed call fails,
    /// or when the result does not convert into `i64`.
    pub fn sum(&self) -> Result<i64> {
        let f = get_type_method(IntPairObj::TYPE_KEY, "sum")?;
        Ok(f.call_packed(&[AnyView::from(self)])?.try_into()?)
    }
}

The generated helper functions are responsible for runtime type lookup and method dispatch, allowing the generated object wrappers to invoke reflected methods through the TVM FFI runtime.

Testing

  • Added unit tests and string-level tests in test_stubgen.py.
  • End-to-end tests are attached in the PR comments.

e2e_test.tar.gz

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a Rust code generator backend to tvm-ffi-stubgen, enabling the generation of Rust object bindings from C++ reflection metadata. It restructures the tool with a pluggable generator architecture, updates CMake and CLI configurations, adds comprehensive documentation, and provides an end-to-end example. Additionally, the tvm-ffi Rust runtime is updated to support passing container types as arguments. The review feedback highlights several critical improvements: using rpath linker flags in build.rs for runtime library discovery, handling forbidden Rust keywords (like self and super) by appending an underscore since they cannot be raw identifiers, adding a null check to inc_ref_raw_object to prevent undefined behavior, and caching resolved FFI functions with std::sync::OnceLock to eliminate performance bottlenecks during method invocations.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +29 to +44
/// Emit a `cargo:rustc-env` directive that appends `lib_dir` to the
/// library-search variable of the target OS.
///
/// The variable is `PATH` on Windows (separator `;`), `DYLD_LIBRARY_PATH` on
/// macOS and `LD_LIBRARY_PATH` on Linux (separator `:`). Nothing is printed
/// for any other target OS or when `CARGO_CFG_TARGET_OS` cannot be read.
fn update_runtime_library_env(lib_dir: &str) {
    let target_os = env::var("CARGO_CFG_TARGET_OS");
    let (os_env_var, separator) = match target_os.as_deref() {
        Ok("windows") => ("PATH", ";"),
        Ok("macos") => ("DYLD_LIBRARY_PATH", ":"),
        Ok("linux") => ("LD_LIBRARY_PATH", ":"),
        _ => return,
    };
    // An unset, unreadable, or empty variable yields `lib_dir` on its own.
    let new_val = match env::var(os_env_var) {
        Ok(current_val) if !current_val.is_empty() => {
            format!("{current_val}{separator}{lib_dir}")
        }
        _ => lib_dir.to_string(),
    };
    println!("cargo:rustc-env={os_env_var}={new_val}");
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using cargo:rustc-env only sets the environment variable for the compiler (and makes it available via the env! macro) during the compilation of the current crate. It does not set or update the environment variable for the runtime execution of the compiled binary (e.g., when running cargo run --example demo).

To make the dynamic library discoverable at runtime without requiring the user to manually set LD_LIBRARY_PATH/DYLD_LIBRARY_PATH, you should pass rpath linker flags instead.

/// Emit an rpath linker argument for `lib_dir` on Linux and macOS targets.
///
/// Nothing is printed for any other target OS or when `CARGO_CFG_TARGET_OS`
/// cannot be read.
fn update_runtime_library_env(lib_dir: &str) {
    let target_os = env::var("CARGO_CFG_TARGET_OS");
    // Both platforms take the same linker flag, so they share one arm.
    if let Ok("linux" | "macos") = target_os.as_deref() {
        println!("cargo:rustc-link-arg=-Wl,-rpath,{lib_dir}");
    }
}

Comment on lines +236 to +240
def _rust_ident(name: str) -> str:
    """Return ``name`` in a form usable as a Rust identifier.

    Names listed in ``C.RUST_KEYWORDS`` are raw-escaped as ``r#name``, except
    those also listed in ``C.RUST_RAW_IDENT_FORBIDDEN``, which are returned
    unchanged, as is every non-keyword name.
    """
    needs_raw_escape = name in C.RUST_KEYWORDS and name not in C.RUST_RAW_IDENT_FORBIDDEN
    return f"r#{name}" if needs_raw_escape else name

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If a C++ field or parameter is named after one of the forbidden Rust keywords (such as self, Self, super, or crate), returning the name as-is will result in invalid Rust code (e.g., pub self: i64), which fails to compile because these keywords cannot be used as raw identifiers. We should handle these forbidden keywords by appending an underscore (e.g., self_) or using another renaming scheme to ensure the generated code is valid.

Suggested change
def _rust_ident(name: str) -> str:
"""Make ``name`` a usable Rust identifier (raw-escape keywords)."""
if name in C.RUST_KEYWORDS and name not in C.RUST_RAW_IDENT_FORBIDDEN:
return f"r#{name}"
return name
def _rust_ident(name: str) -> str:
"""Make ``name`` a usable Rust identifier (raw-escape keywords)."""
if name in C.RUST_RAW_IDENT_FORBIDDEN:
return f"{name}_"
if name in C.RUST_KEYWORDS:
return f"r#{name}"
return name

Comment on lines +121 to +124
/// Forward `handle` to `unsafe_::inc_ref` (by its name, a reference-count increment).
///
/// # Safety
///
/// NOTE(review): no null or validity check is done here. Presumably `handle`
/// must point to a live `TVMFFIObject` -- confirm against `unsafe_::inc_ref`.
#[inline]
pub unsafe fn inc_ref_raw_object(handle: *mut TVMFFIObject) {
unsafe_::inc_ref(handle)
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If handle is null, dereferencing it inside unsafe_::inc_ref will cause undefined behavior (segmentation fault). To make this API more robust and prevent crashes when dealing with potentially null raw pointers (e.g., from nullable object references), we should add a null check before calling inc_ref.

Suggested change
#[inline]
pub unsafe fn inc_ref_raw_object(handle: *mut TVMFFIObject) {
unsafe_::inc_ref(handle)
}
#[inline]
pub unsafe fn inc_ref_raw_object(handle: *mut TVMFFIObject) {
if !handle.is_null() {
unsafe_::inc_ref(handle)
}
}

if method.is_member or params:
_use(self.imports, "tvm_ffi::AnyView")
packed = _packed_args_expr(params, method.is_member)
getter = f' let f = get_type_method({self.obj_struct}::TYPE_KEY, "{ffi_name}")?;'

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling get_type_method on every single method invocation is a major performance bottleneck because it performs a global Mutex lock, a C FFI call (TVMFFIGetTypeInfo), and a string comparison loop over all methods of the type. Since the reflected methods of a registered type do not change at runtime, we can cache the resolved tvm_ffi::Function using std::sync::OnceLock inside the generated method to make subsequent calls extremely fast.

Suggested change
getter = f' let f = get_type_method({self.obj_struct}::TYPE_KEY, "{ffi_name}")?;'
getter = (
f' static F: std::sync::OnceLock<tvm_ffi::Function> = std::sync::OnceLock::new();\n'
f' let f = F.get_or_init(|| get_type_method({self.obj_struct}::TYPE_KEY, "{ffi_name}").unwrap());'
)

@tqchen

tqchen commented Jun 6, 2026

Copy link
Copy Markdown
Member

thanks @Seven-Streams some quick notes:

  • I think we should be able to look up type_index = lookup_type_index(type_key); per type, which makes it faster than the global hash map
  • We might be able to update https://github.com/apache/tvm-ffi/blob/main/rust/tvm-ffi-macros/src/object_macros.rs to generate some of the type-index fetch boilerplate
  • The same remark likely applies to the ctor, since we could have a global static OnceLock for the ctor per type if needed
    • Because we are in Rust, we likely don't need the FFI ctor, which is slower; instead, directly constructing the object via the Rust API and allocation would be preferred

Signed-off-by: yuchuan <yuchuan.7streams@gmail.com>
Signed-off-by: yuchuan <yuchuan.7streams@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants