From 4fd738fa1804045e8f554477647d9254e68be186 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 5 Mar 2026 21:28:52 -0500 Subject: [PATCH 1/3] register lowering for conv_fwd_jvp_p and conv_bwd_jvp_p --- .../openequivariance/jax/jvp/conv_prim.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index ca91014..99185b0 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -161,6 +161,12 @@ def conv_fwd_jvp_abstract_eval( conv_fwd_jvp_p.def_impl(conv_fwd_jvp_impl) conv_fwd_jvp_p.def_abstract_eval(conv_fwd_jvp_abstract_eval) +mlir.register_lowering( + conv_fwd_jvp_p, mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False), platform="cuda" +) +mlir.register_lowering( + conv_fwd_jvp_p, mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False), platform="rocm" +) # ============================================================================== @@ -285,6 +291,12 @@ def conv_bwd_jvp_abstract_eval( conv_bwd_jvp_p.def_impl(conv_bwd_jvp_impl) conv_bwd_jvp_p.def_abstract_eval(conv_bwd_jvp_abstract_eval) +mlir.register_lowering( + conv_bwd_jvp_p, mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True), platform="cuda" +) +mlir.register_lowering( + conv_bwd_jvp_p, mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True), platform="rocm" +) # ============================================================================== From bad6b82f8208d6753cefdaccf9bedbefda8b351c Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 5 Mar 2026 22:48:16 -0500 Subject: [PATCH 2/3] prek -a --- .../openequivariance/jax/jvp/conv_prim.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index 99185b0..8d38646 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -162,10 +162,14 @@ def conv_fwd_jvp_abstract_eval( conv_fwd_jvp_p.def_impl(conv_fwd_jvp_impl) conv_fwd_jvp_p.def_abstract_eval(conv_fwd_jvp_abstract_eval) mlir.register_lowering( - conv_fwd_jvp_p, mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False), platform="cuda" + conv_fwd_jvp_p, + mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False), + platform="cuda", ) mlir.register_lowering( - conv_fwd_jvp_p, mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False), platform="rocm" + conv_fwd_jvp_p, + mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False), + platform="rocm", ) @@ -292,10 +296,14 @@ def conv_bwd_jvp_abstract_eval( conv_bwd_jvp_p.def_impl(conv_bwd_jvp_impl) conv_bwd_jvp_p.def_abstract_eval(conv_bwd_jvp_abstract_eval) mlir.register_lowering( - conv_bwd_jvp_p, mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True), platform="cuda" + conv_bwd_jvp_p, + mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True), + platform="cuda", ) mlir.register_lowering( - conv_bwd_jvp_p, mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True), platform="rocm" + conv_bwd_jvp_p, + mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True), + platform="rocm", ) From e7959616cd2e394ddc3eb78bd407c089bd28eb3c Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 5 Mar 2026 22:51:46 -0500 Subject: [PATCH 3/3] add tp too --- .../openequivariance/jax/jvp/tp_prim.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/jax/jvp/tp_prim.py b/openequivariance/openequivariance/jax/jvp/tp_prim.py index c31c3ec..3745f24 100644 --- a/openequivariance/openequivariance/jax/jvp/tp_prim.py +++ b/openequivariance/openequivariance/jax/jvp/tp_prim.py @@ -132,6 +132,16 @@ def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): tp_fwd_jvp_p.def_impl(tp_fwd_jvp_impl) tp_fwd_jvp_p.def_abstract_eval(tp_fwd_jvp_abstract_eval) +mlir.register_lowering( + tp_fwd_jvp_p, + mlir.lower_fun(tp_fwd_jvp_impl, multiple_results=False), + platform="cuda", +) +mlir.register_lowering( + tp_fwd_jvp_p, + mlir.lower_fun(tp_fwd_jvp_impl, multiple_results=False), + platform="rocm", +) # ============================================================================== @@ -225,7 +235,16 @@ def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): tp_bwd_jvp_p.def_impl(tp_bwd_jvp_impl) tp_bwd_jvp_p.def_abstract_eval(tp_bwd_jvp_abstract_eval) - +mlir.register_lowering( + tp_bwd_jvp_p, + mlir.lower_fun(tp_bwd_jvp_impl, multiple_results=True), + platform="cuda", +) +mlir.register_lowering( + tp_bwd_jvp_p, + mlir.lower_fun(tp_bwd_jvp_impl, multiple_results=True), + platform="rocm", +) # ============================================================================== # 9. Transpose Rule for Backward JVP