Accelerate fast_matmul with packing and tiling #767

Open

ThreeMonth03 wants to merge 1 commit into solvcon:master from ThreeMonth03:accel_matmul

Accelerate fast_matmul with packing and tiling#767
ThreeMonth03 wants to merge 1 commit into
solvcon:masterfrom
ThreeMonth03:accel_matmul

Conversation

Collaborator

@ThreeMonth03 ThreeMonth03 commented May 10, 2026

This pull request accelerates SimpleArray::matmul(); see issue #715.
To compare against the baseline SimpleArray::matmul(), I implemented a new API named SimpleArray::fast_matmul(), which is accelerated by the packing and tiling techniques.

The following chart and sheet record the measurement results. np corresponds to numpy.matmul(), naive_sa corresponds to the original SimpleArray::matmul(), and fast_sa_{tile_size_x}_{tile_size_y}_{tile_size_z} corresponds to SimpleArray::fast_matmul() with different tile sizes.

The observation is interesting: when the SimpleArray is big, SimpleArray::fast_matmul() with a small tile is faster, presumably because a smaller tile keeps the working set inside the cache. On the contrary, when the SimpleArray is small, fast_matmul() with a big tile is faster, since the whole problem already fits in cache and a larger tile means less loop overhead.

[Figure: per-call time of np, naive_sa, and fast_sa with different tile sizes, versus matrix size]

2D x 2D shape: (4, 4) x (4, 4), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.818E-01 | 1.000 |
| naive_sa | 2.404E-03 | 0.013 |
| fast_sa_16_16_16 | 2.296E-03 | 0.013 |
| fast_sa_32_32_32 | 1.875E-03 | 0.010 |
| fast_sa_64_64_64 | 1.821E-03 | 0.010 |

2D x 2D shape: (16, 16) x (16, 16), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 2.170E-01 | 1.000 |
| naive_sa | 4.079E-03 | 0.019 |
| fast_sa_16_16_16 | 4.004E-03 | 0.018 |
| fast_sa_32_32_32 | 3.783E-03 | 0.017 |
| fast_sa_64_64_64 | 3.692E-03 | 0.017 |

2D x 2D shape: (64, 64) x (64, 64), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 2.241E-02 | 1.000 |
| naive_sa | 2.438E-01 | 10.878 |
| fast_sa_16_16_16 | 8.967E-02 | 4.001 |
| fast_sa_32_32_32 | 1.058E-01 | 4.720 |
| fast_sa_64_64_64 | 1.391E-01 | 6.208 |

2D x 2D shape: (256, 256) x (256, 256), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 4.621E-02 | 1.000 |
| naive_sa | 2.444E+01 | 528.759 |
| fast_sa_16_16_16 | 5.206E+00 | 112.646 |
| fast_sa_32_32_32 | 6.301E+00 | 136.342 |
| fast_sa_64_64_64 | 8.446E+00 | 182.766 |

2D x 2D shape: (1024, 1024) x (1024, 1024), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 3.203E+00 | 1.000 |
| naive_sa | 2.065E+03 | 644.640 |
| fast_sa_16_16_16 | 3.308E+02 | 103.265 |
| fast_sa_32_32_32 | 4.138E+02 | 129.167 |
| fast_sa_64_64_64 | 5.546E+02 | 173.138 |

2D x 2D shape: (4, 4) x (4, 4), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 4.696E-03 | 1.000 |
| naive_sa | 1.654E-03 | 0.352 |
| fast_sa_16_16_16 | 2.396E-03 | 0.510 |
| fast_sa_32_32_32 | 1.837E-03 | 0.391 |
| fast_sa_64_64_64 | 1.763E-03 | 0.375 |

2D x 2D shape: (16, 16) x (16, 16), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.889E-01 | 1.000 |
| naive_sa | 4.654E-03 | 0.025 |
| fast_sa_16_16_16 | 4.771E-03 | 0.025 |
| fast_sa_32_32_32 | 4.017E-03 | 0.021 |
| fast_sa_64_64_64 | 3.825E-03 | 0.020 |

2D x 2D shape: (64, 64) x (64, 64), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 6.288E-03 | 1.000 |
| naive_sa | 2.476E-01 | 39.384 |
| fast_sa_16_16_16 | 9.311E-02 | 14.809 |
| fast_sa_32_32_32 | 1.195E-01 | 19.000 |
| fast_sa_64_64_64 | 1.641E-01 | 26.106 |

2D x 2D shape: (256, 256) x (256, 256), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.621E-01 | 1.000 |
| naive_sa | 2.486E+01 | 153.326 |
| fast_sa_16_16_16 | 5.521E+00 | 34.054 |
| fast_sa_32_32_32 | 7.124E+00 | 43.943 |
| fast_sa_64_64_64 | 1.032E+01 | 63.680 |

2D x 2D shape: (1024, 1024) x (1024, 1024), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.165E+01 | 1.000 |
| naive_sa | 2.142E+03 | 183.825 |
| fast_sa_16_16_16 | 4.227E+02 | 36.281 |
| fast_sa_32_32_32 | 4.961E+02 | 42.577 |
| fast_sa_64_64_64 | 6.815E+02 | 58.492 |

2D x 2D shape: (9, 9) x (9, 9), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 3.254E-03 | 1.000 |
| naive_sa | 1.842E-03 | 0.566 |
| fast_sa_16_16_16 | 2.791E-03 | 0.858 |
| fast_sa_32_32_32 | 2.450E-03 | 0.753 |
| fast_sa_64_64_64 | 2.437E-03 | 0.749 |

2D x 2D shape: (27, 27) x (27, 27), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.400E-01 | 1.000 |
| naive_sa | 1.388E-02 | 0.099 |
| fast_sa_16_16_16 | 1.286E-02 | 0.092 |
| fast_sa_32_32_32 | 1.163E-02 | 0.083 |
| fast_sa_64_64_64 | 1.157E-02 | 0.083 |

2D x 2D shape: (81, 81) x (81, 81), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 7.596E-03 | 1.000 |
| naive_sa | 5.659E-01 | 74.499 |
| fast_sa_16_16_16 | 1.868E-01 | 24.599 |
| fast_sa_32_32_32 | 2.133E-01 | 28.084 |
| fast_sa_64_64_64 | 2.624E-01 | 34.550 |

2D x 2D shape: (243, 243) x (243, 243), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 5.318E-02 | 1.000 |
| naive_sa | 2.011E+01 | 378.090 |
| fast_sa_16_16_16 | 4.678E+00 | 87.972 |
| fast_sa_32_32_32 | 5.561E+00 | 104.579 |
| fast_sa_64_64_64 | 7.337E+00 | 137.963 |

2D x 2D shape: (729, 729) x (729, 729), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.196E+00 | 1.000 |
| naive_sa | 7.124E+02 | 595.856 |
| fast_sa_16_16_16 | 1.247E+02 | 104.295 |
| fast_sa_32_32_32 | 1.486E+02 | 124.256 |
| fast_sa_64_64_64 | 1.983E+02 | 165.877 |

2D x 2D shape: (9, 9) x (9, 9), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 3.454E-03 | 1.000 |
| naive_sa | 2.008E-03 | 0.581 |
| fast_sa_16_16_16 | 2.617E-03 | 0.758 |
| fast_sa_32_32_32 | 2.187E-03 | 0.633 |
| fast_sa_64_64_64 | 2.187E-03 | 0.633 |

2D x 2D shape: (27, 27) x (27, 27), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.416E-01 | 1.000 |
| naive_sa | 1.424E-02 | 0.101 |
| fast_sa_16_16_16 | 1.165E-02 | 0.082 |
| fast_sa_32_32_32 | 1.119E-02 | 0.079 |
| fast_sa_64_64_64 | 1.110E-02 | 0.078 |

2D x 2D shape: (81, 81) x (81, 81), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.694E-02 | 1.000 |
| naive_sa | 5.645E-01 | 33.326 |
| fast_sa_16_16_16 | 2.009E-01 | 11.864 |
| fast_sa_32_32_32 | 2.347E-01 | 13.857 |
| fast_sa_64_64_64 | 3.144E-01 | 18.562 |

2D x 2D shape: (243, 243) x (243, 243), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.535E-01 | 1.000 |
| naive_sa | 2.021E+01 | 131.674 |
| fast_sa_16_16_16 | 4.912E+00 | 32.000 |
| fast_sa_32_32_32 | 6.214E+00 | 40.480 |
| fast_sa_64_64_64 | 8.631E+00 | 56.226 |

2D x 2D shape: (729, 729) x (729, 729), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 4.339E+00 | 1.000 |
| naive_sa | 7.178E+02 | 165.433 |
| fast_sa_16_16_16 | 1.333E+02 | 30.710 |
| fast_sa_32_32_32 | 1.689E+02 | 38.931 |
| fast_sa_64_64_64 | 2.377E+02 | 54.786 |

By the way, to validate the usefulness of packing alone, I also profiled SimpleArray::fast_matmul() with packing but without tiling. The following chart and sheet show the data. It is clear that packing by itself already makes the operation faster.
[Figure: per-call time of np, naive_sa, and packing-only fast_sa, versus matrix size]

2D x 2D shape: (4, 4) x (4, 4), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 2.244E-01 | 1.000 |
| naive_sa | 1.691E-03 | 0.008 |
| fast_sa | 1.338E-03 | 0.006 |

2D x 2D shape: (16, 16) x (16, 16), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 2.208E-01 | 1.000 |
| naive_sa | 4.083E-03 | 0.018 |
| fast_sa | 2.704E-03 | 0.012 |

2D x 2D shape: (64, 64) x (64, 64), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 2.905E-02 | 1.000 |
| naive_sa | 2.395E-01 | 8.245 |
| fast_sa | 1.325E-01 | 4.559 |

2D x 2D shape: (256, 256) x (256, 256), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 4.985E-02 | 1.000 |
| naive_sa | 2.443E+01 | 490.041 |
| fast_sa | 1.485E+01 | 297.920 |

2D x 2D shape: (1024, 1024) x (1024, 1024), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 3.297E+00 | 1.000 |
| naive_sa | 2.182E+03 | 661.923 |
| fast_sa | 1.473E+03 | 446.771 |

2D x 2D shape: (4, 4) x (4, 4), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 5.475E-03 | 1.000 |
| naive_sa | 2.304E-03 | 0.421 |
| fast_sa | 1.480E-03 | 0.270 |

2D x 2D shape: (16, 16) x (16, 16), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 3.051E-01 | 1.000 |
| naive_sa | 4.504E-03 | 0.015 |
| fast_sa | 2.817E-03 | 0.009 |

2D x 2D shape: (64, 64) x (64, 64), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 6.013E-03 | 1.000 |
| naive_sa | 2.360E-01 | 39.256 |
| fast_sa | 1.577E-01 | 26.221 |

2D x 2D shape: (256, 256) x (256, 256), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.607E-01 | 1.000 |
| naive_sa | 2.467E+01 | 153.484 |
| fast_sa | 1.461E+01 | 90.939 |

2D x 2D shape: (1024, 1024) x (1024, 1024), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.195E+01 | 1.000 |
| naive_sa | 2.179E+03 | 182.362 |
| fast_sa | 1.570E+03 | 131.418 |

2D x 2D shape: (9, 9) x (9, 9), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 5.587E-03 | 1.000 |
| naive_sa | 3.092E-03 | 0.553 |
| fast_sa | 3.008E-03 | 0.538 |

2D x 2D shape: (27, 27) x (27, 27), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 8.550E-02 | 1.000 |
| naive_sa | 2.400E-02 | 0.281 |
| fast_sa | 2.574E-02 | 0.301 |

2D x 2D shape: (81, 81) x (81, 81), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 4.090E-02 | 1.000 |
| naive_sa | 8.716E-01 | 21.310 |
| fast_sa | 5.251E-01 | 12.839 |

2D x 2D shape: (243, 243) x (243, 243), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.098E-01 | 1.000 |
| naive_sa | 2.485E+01 | 226.286 |
| fast_sa | 1.503E+01 | 136.813 |

2D x 2D shape: (729, 729) x (729, 729), dtype: float32

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.211E+00 | 1.000 |
| naive_sa | 7.165E+02 | 591.653 |
| fast_sa | 4.984E+02 | 411.572 |

2D x 2D shape: (9, 9) x (9, 9), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 4.358E-03 | 1.000 |
| naive_sa | 2.167E-03 | 0.497 |
| fast_sa | 1.546E-03 | 0.355 |

2D x 2D shape: (27, 27) x (27, 27), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.303E-01 | 1.000 |
| naive_sa | 1.437E-02 | 0.110 |
| fast_sa | 1.032E-02 | 0.079 |

2D x 2D shape: (81, 81) x (81, 81), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.648E-02 | 1.000 |
| naive_sa | 5.687E-01 | 34.510 |
| fast_sa | 3.455E-01 | 20.968 |

2D x 2D shape: (243, 243) x (243, 243), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 1.824E-01 | 1.000 |
| naive_sa | 2.037E+01 | 111.698 |
| fast_sa | 1.283E+01 | 70.352 |

2D x 2D shape: (729, 729) x (729, 729), dtype: float64

| func | per call (ms) | cmp to np |
| --- | --- | --- |
| np | 4.266E+00 | 1.000 |
| naive_sa | 7.688E+02 | 180.208 |
| fast_sa | 5.700E+02 | 133.619 |

@yungyuc yungyuc added the performance (Profiling, runtime, and memory consumption) and array (Multi-dimensional array implementation) labels May 10, 2026
Member

@yungyuc yungyuc left a comment


We probably do not need to test arrays smaller than the tile size. An array smaller than a cache line will surely run fast.

  • Discussion: Rename the function fast_matmul() to tiled_matmul()?
  • Use a helper class (modmesh::detail::HelperName) to implement the function body of fast_matmul()

```cpp
.def("matmul", &wrapped_type::matmul)
.def("__matmul__", &wrapped_type::matmul)
.def(
    "fast_matmul",
```
Member


How about calling it tiled_matmul()?

Collaborator Author


Good idea, I will change it, although we will later try to optimize with more techniques, like SIMD or the Strassen-Winograd algorithm.

Member


In that case, I suggest staying with fast_matmul() because it will not be just tiling. The plan can be added in a comment at the C++ implementation.

This PR already has 300 lines of code and can focus on tiling. Other optimizations can go into later PRs.

Collaborator Author


OK, I will keep the original API name fast_matmul().

Comment thread cpp/modmesh/buffer/SimpleArray.hpp Outdated
```cpp
A SimpleArrayMixinCalculators<A, T>::fast_matmul(A const & other,
                                                 size_t tile_x,
                                                 size_t tile_y,
                                                 size_t tile_z) const // NOLINT(readability-function-cognitive-complexity)
```
Member


This function is too long. Clang-tidy also thinks so (NOLINT(readability-function-cognitive-complexity)).

Could you please try to use a helper class (modmesh::detail::HelperName) to implement the function body of fast_matmul()?

Collaborator Author


Yes, I decided to create a new helper class. I will also wrap the original matrix multiplication operation in the helper class.

@yungyuc yungyuc moved this to In Progress in tensor operations May 10, 2026
@ThreeMonth03 ThreeMonth03 force-pushed the accel_matmul branch 4 times, most recently from 24cc347 to 452b888 Compare May 12, 2026 15:42
Collaborator Author

@ThreeMonth03 ThreeMonth03 left a comment


  • Discussion: Rename the function fast_matmul() to tiled_matmul()?
  • Use a helper class (modmesh::detail::HelperName) to implement the function body of fast_matmul()

@yungyuc Please take a look. Thanks

```cpp
.def("matmul", &wrapped_type::matmul)
.def("__matmul__", &wrapped_type::matmul)
.def(
    "fast_matmul",
```
Collaborator Author


OK, I will keep the original API name fast_matmul().

Comment on lines +134 to +136
```cpp
template <typename A, typename T>
class SimpleArrayMatmulHelper
{
```
Collaborator Author


I implemented the matrix multiplication in SimpleArrayMatmulHelper.

Comment thread tests/test_matrix.py
Collaborator Author


Added a unit test for fast_matmul().

Member


Is it correct that you compare the results between matmul() and fast_matmul() in all test functions that had matmul()?

@ThreeMonth03 ThreeMonth03 force-pushed the accel_matmul branch 2 times, most recently from 7488350 to b7762a7 Compare May 12, 2026 17:16
Member

@yungyuc yungyuc left a comment


  • Rename SimpleArrayMatmulHelper::fast_matmul() to matmul_fast() for consistency within the class.
  • Discussion: Add appropriate member data to make the SimpleArrayMatmulHelper static functions non-static where it helps maintainability.
  • For SimpleArrayMatmulHelper::matmul_mat_mat(), make the result array a member datum so that the caller may choose to reuse it.
  • Confirm: Did you compare the results between matmul() and fast_matmul() in all test functions that had matmul()?

Comment thread cpp/modmesh/buffer/SimpleArray.hpp Outdated
```cpp
using shape_type = typename internal_types::shape_type;

static A matmul(A const & lhs, A const & rhs);
static A fast_matmul(A const & lhs,
```
Member


matmul_fast() may be a more consistent name for this class than fast_matmul(). You have helper functions matmul_*() (e.g. matmul_vec_vec()) in this class too. Placing fast before matmul looks inconsistent.

Comment thread cpp/modmesh/buffer/SimpleArray.hpp Outdated
```cpp
private:

    static std::string shape_str(A const & arr);
    static void check_dims(A const & lhs, A const & rhs);
```
Member


All the static functions use lhs and rhs. Consider making them non-static and storing the arguments as member data m_lhs and m_rhs.

I am guessing it will save you some lines and make the code cleaner. But that is a guess. If it does not, I am OK with keeping the static implementation.

Comment thread cpp/modmesh/buffer/SimpleArray.hpp Outdated
```cpp
size_t const k = lhs.shape(1);
size_t const n = rhs.shape(1);
shape_type const result_shape{m, n};
A result(result_shape);
```
Member


Make the result array a member datum so that the caller may choose to reuse it.

Comment thread tests/test_matrix.py
Member


Is it correct that you compare the results between matmul(), and fast_matmul() in all test functions that had matmul()?

Collaborator Author

@yungyuc By the way, I notice that the CI runner devbuild / build_macos-26_Release (pull_request) fails due to long compile time. Should I remove the static functions of class SimpleArrayMatmulHelper from SimpleArray and keep them only in the pybind11 wrapper?

For example (adding the template parameter the snippet needs to compile):

```cpp
template <typename A>
void wrap_SimpleArray(pybind11::module & mod, char const * pyname)
{
    namespace py = pybind11;

    using value_type = typename A::value_type;

    py::class_<A>(mod, pyname)
        .def(
            "fast_matmul",
            [](A const & lhs,
               A const & rhs,
               size_t tile_x,
               size_t tile_y,
               size_t tile_z)
            {
                return detail::SimpleArrayMatmulHelper<A, value_type>::fast_matmul(
                    lhs,
                    rhs,
                    tile_x,
                    tile_y,
                    tile_z);
            },
            py::arg("rhs"),
            py::arg("tile_x") = 16,
            py::arg("tile_y") = 16,
            py::arg("tile_z") = 16);
}
```

Member

yungyuc commented May 13, 2026

> @yungyuc By the way, I notice that the CI runner devbuild / build_macos-26_Release (pull_request) fails due to long compile time. Should I remove the static functions of class SimpleArrayMatmulHelper from SimpleArray and keep them only in the pybind11 wrapper?

Please check if the long compile time is intermittent or really related to your code change. You can create a PR in your fork to trigger a run.


Labels

array (Multi-dimensional array implementation), performance (Profiling, runtime, and memory consumption)

Projects

Status: In Progress
