Accelerate fast_matmul with packing and tiling #767
Conversation
We probably do not need to test arrays smaller than the tile size. An array smaller than a cache line will surely run fast.
- Discussion: Rename the function `fast_matmul()` to `tiled_matmul()`?
- Use a helper class (`modmesh::detail::HelperName`) to implement the function body of `fast_matmul()`
```cpp
.def("matmul", &wrapped_type::matmul)
.def("__matmul__", &wrapped_type::matmul)
.def(
    "fast_matmul",
```
How about calling it `tiled_matmul()`?
Good idea, I will change it, although we will try to optimize with more techniques later, like SIMD or the Strassen and Winograd algorithms.
In that case, I suggest staying with `fast_matmul()` because it will not be just tiling. The plan can be added in a comment in the C++ implementation.

This PR already has 300 lines of code and can focus on tiling. Other optimizations can come in later PRs.
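For context, the tiling technique under discussion can be sketched like this. This is a minimal illustration with a hypothetical free function `tiled_matmul` on plain row-major buffers, not the modmesh implementation; the blocked loops keep working sets of `a`, `b`, and `c` small enough to stay resident in cache.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Minimal loop-tiling sketch for C = A * B with row-major n x n matrices.
// The three outer loops walk tile-by-tile; the three inner loops work
// inside one tile so the touched blocks stay cache-resident.
std::vector<double> tiled_matmul(std::vector<double> const & a,
                                 std::vector<double> const & b,
                                 std::size_t n,
                                 std::size_t tile)
{
    std::vector<double> c(n * n, 0.0);
    for (std::size_t ii = 0; ii < n; ii += tile)
    {
        for (std::size_t kk = 0; kk < n; kk += tile)
        {
            for (std::size_t jj = 0; jj < n; jj += tile)
            {
                std::size_t const imax = std::min(ii + tile, n);
                std::size_t const kmax = std::min(kk + tile, n);
                std::size_t const jmax = std::min(jj + tile, n);
                for (std::size_t i = ii; i < imax; ++i)
                {
                    for (std::size_t k = kk; k < kmax; ++k)
                    {
                        double const aik = a[i * n + k];
                        for (std::size_t j = jj; j < jmax; ++j)
                        {
                            c[i * n + j] += aik * b[k * n + j];
                        }
                    }
                }
            }
        }
    }
    return c;
}
```

The loop order (i, k, j in the innermost triple) keeps the inner loop streaming through contiguous rows of `b` and `c`, which is usually the cache-friendly choice for row-major storage.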
OK, I will keep the original API name `fast_matmul()`.
```cpp
A SimpleArrayMixinCalculators<A, T>::fast_matmul(A const & other,
                                                 size_t tile_x,
                                                 size_t tile_y,
                                                 size_t tile_z) const // NOLINT(readability-function-cognitive-complexity)
```
This function is too long. Clang-tidy also thinks so (`NOLINT(readability-function-cognitive-complexity)`).

Could you please try to use a helper class (`modmesh::detail::HelperName`) to implement the function body of `fast_matmul()`?
Yes, I decided to create a new helper class. I will also wrap the original matrix multiplication operation in the helper class.
Force-pushed from 24cc347 to 452b888
ThreeMonth03 left a comment
- Discussion: Rename the function `fast_matmul()` to `tiled_matmul()`?
- Use a helper class (`modmesh::detail::HelperName`) to implement the function body of `fast_matmul()`
@yungyuc Please take a look. Thanks
```cpp
.def("matmul", &wrapped_type::matmul)
.def("__matmul__", &wrapped_type::matmul)
.def(
    "fast_matmul",
```
OK, I will keep the original API name `fast_matmul()`.
```cpp
template <typename A, typename T>
class SimpleArrayMatmulHelper
{
```
I implement the matrix multiplication in `SimpleArrayMatmulHelper`.
Add unit tests for `fast_matmul()`.
Is it correct that you compare the results between `matmul()` and `fast_matmul()` in all test functions that had `matmul()`?
Force-pushed from 7488350 to b7762a7
- Rename `SimpleArrayMatmulHelper::fast_matmul()` to `matmul_fast()` for consistency in the class.
- Discussion: Add appropriate member data to make the `SimpleArrayMatmulHelper` static functions non-static where they fit, if it helps maintainability.
- For `SimpleArrayMatmulHelper::matmul_mat_mat()`, make the result array a member datum so that the caller may choose to reuse it.
- Confirm: Did you compare the results between `matmul()` and `fast_matmul()` in all test functions that had `matmul()`?
```cpp
using shape_type = typename internal_types::shape_type;

static A matmul(A const & lhs, A const & rhs);
static A fast_matmul(A const & lhs,
```
`matmul_fast()` may be a more consistent name for this class than `fast_matmul()`. You have helper functions `matmul_*()` (e.g. `matmul_vec_vec()`) in this class too. Placing `fast` before `matmul` looks inconsistent.
```cpp
private:

    static std::string shape_str(A const & arr);
    static void check_dims(A const & lhs, A const & rhs);
```
All the static functions use `lhs` and `rhs`. Consider making them non-static and turning the arguments into member data `m_lhs` and `m_rhs`.

I am guessing it will save you some lines and make the code cleaner. But that is a guess. If it does not, I am OK with keeping the static implementation.
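As a rough sketch of what the suggested refactor could look like (hypothetical names `MatmulHelperSketch`, `Mat`, and `flop_count`, not the actual modmesh code): the constructor captures the operands once as `m_lhs`/`m_rhs`, so the member functions no longer need `lhs`/`rhs` parameters.

```cpp
#include <cstddef>

// Hypothetical stand-in for an array type exposing its dimensions.
struct Mat
{
    std::size_t rows;
    std::size_t cols;
};

// Sketch: hold the operands as member data so helper methods become
// non-static and stop passing lhs/rhs through every call.
template <typename A>
class MatmulHelperSketch
{
public:
    MatmulHelperSketch(A const & lhs, A const & rhs)
        : m_lhs(lhs)
        , m_rhs(rhs)
    {
    }

    // Non-static: the operands come from the members, not from arguments.
    std::size_t flop_count() const
    {
        return 2 * m_lhs.rows * m_lhs.cols * m_rhs.cols;
    }

private:
    A const & m_lhs;
    A const & m_rhs;
};
```

Storing references keeps the helper cheap to construct, at the cost of requiring the operands to outlive the helper object.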
```cpp
size_t const k = lhs.shape(1);
size_t const n = rhs.shape(1);
shape_type const result_shape{m, n};
A result(result_shape);
```
Make the result array a member datum so that the caller may choose to reuse it.
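A minimal sketch of the reuse idea (hypothetical `ResultCache` class, not the modmesh API): the helper owns the result buffer as a member datum and reallocates only when the requested element count changes, so repeated multiplications of same-shaped operands reuse the same storage.

```cpp
#include <cstddef>
#include <vector>

// Sketch: keep the result buffer as a member datum so repeated
// multiplications with the same output size reuse the allocation.
class ResultCache
{
public:
    // Returns a zero-filled buffer of m * n elements; reallocates and
    // refills only when the element count changes.
    std::vector<double> & result(std::size_t m, std::size_t n)
    {
        if (m_result.size() != m * n)
        {
            m_result.assign(m * n, 0.0);
        }
        return m_result;
    }

private:
    std::vector<double> m_result;
};
```

A real implementation would also need to decide who zeroes the buffer between calls and whether handing out a reference to internal storage is an acceptable API.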
Is it correct that you compare the results between `matmul()` and `fast_matmul()` in all test functions that had `matmul()`?
@yungyuc By the way, I notice that the CI runner devbuild / build_macos-26_Release (pull_request) fails due to the long compile time. Should I remove the static function? For example:

```cpp
void wrap_SimpleArray(pybind11::module & mod, char const * pyname)
{
    namespace py = pybind11;
    using value_type = typename A::value_type;
    py::class_<A>(mod, pyname)
        .def(
            "fast_matmul",
            [](A const & lhs,
               A const & rhs,
               size_t tile_x,
               size_t tile_y,
               size_t tile_z)
            {
                return detail::SimpleArrayMatmulHelper<A, value_type>::fast_matmul(
                    lhs,
                    rhs,
                    tile_x,
                    tile_y,
                    tile_z);
            },
            py::arg("rhs"),
            py::arg("tile_x") = 16,
            py::arg("tile_y") = 16,
            py::arg("tile_z") = 16);
}
```
Please check if the long compile time is intermittent or really related to your code change. You can create a PR in your fork to trigger a run.
Force-pushed from b7762a7 to f32d51a
Force-pushed from f32d51a to 253e478
This pull request accelerates `SimpleArray::matmul()`; see issue #715. In order to compare with the baseline `SimpleArray::matmul()`, I implement a new API named `SimpleArray::fast_matmul()`, which is accelerated by packing and tiling techniques. The following charts and sheets record the measurement results. `np` corresponds to `numpy.matmul()`, `naive_sa` corresponds to the original `SimpleArray::matmul()`, and `fast_np_{tile_size_x}_{tile_size_y}_{tile_size_z}` corresponds to `SimpleArray::fast_matmul()` with different tile sizes.

The observation is interesting: when the SimpleArray is big, `SimpleArray::fast_matmul()` with a small tile is faster. On the contrary, when the SimpleArray is small, `SimpleArray::fast_matmul()` with a big tile is faster.

[Benchmark charts: 2D x 2D matmul for shapes (4, 4) x (4, 4) through (1024, 1024) x (1024, 1024) and (9, 9) x (9, 9) through (729, 729) x (729, 729), dtypes `float32` and `float64`]

By the way, to validate the usefulness of packing, I also profiled `SimpleArray::fast_matmul()` without tiling. The following charts and sheets are the data. It is clear that packing makes the operation faster.

[Benchmark charts without tiling: same shapes and dtypes as above]
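For readers unfamiliar with the packing step measured above, here is a minimal sketch (hypothetical `pack_tile` function, not the modmesh implementation): before the inner multiply loops, a tile of the operand is copied into a small contiguous buffer, so the kernel reads sequential memory instead of strided rows of the full matrix.

```cpp
#include <cstddef>
#include <vector>

// Minimal packing sketch: copy a rows x cols tile of a row-major n x n
// matrix `b`, starting at (row0, col0), into a contiguous buffer.
// The multiply kernel can then stream through `packed` sequentially.
std::vector<double> pack_tile(std::vector<double> const & b,
                              std::size_t n,
                              std::size_t row0,
                              std::size_t col0,
                              std::size_t rows,
                              std::size_t cols)
{
    std::vector<double> packed(rows * cols);
    for (std::size_t r = 0; r < rows; ++r)
    {
        for (std::size_t c = 0; c < cols; ++c)
        {
            packed[r * cols + c] = b[(row0 + r) * n + (col0 + c)];
        }
    }
    return packed;
}
```

The copy costs O(rows * cols), but each packed tile is typically reused across many output rows, which is why packing can pay off despite the extra traffic.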