diff --git a/kernels/optimized/blas/CPUBlas.cpp b/kernels/optimized/blas/CPUBlas.cpp index 51a4f1ca26b..047e43830a5 100644 --- a/kernels/optimized/blas/CPUBlas.cpp +++ b/kernels/optimized/blas/CPUBlas.cpp @@ -23,6 +23,86 @@ extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void #endif // ET_BUILD_FOR_APPLE #endif // ET_BUILD_WITH_BLAS +#if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE) +#if defined(__linux__) && defined(ET_USE_THREADPOOL) +#include +#endif // defined(__linux__) && defined(ET_USE_THREADPOOL) + +// Some host BLAS backends (notably MKL) parallelize gemm internally with their +// own OpenMP thread team. ExecuTorch already parallelizes operators across its +// own threadpool, and kernels such as the optimized SDPA call gemm from inside +// a pthreadpool worker thread. Letting the BLAS spin up a nested OpenMP team +// from that worker crashes on Linux x86 hosts (SEGV in __kmp_create_worker / +// KMP_UBER_GTID). +// +// When a NoThreadPoolGuard is active on this thread, force the BLAS +// single-threaded for the duration of the gemm call and restore afterwards. +// ExecuTorch enables the guard on its threadpool workers (the nested case); +// top-level callers may also enable it to force single-threaded execution. +// +// The whole mechanism is gated on defined(__linux__) && +// defined(ET_USE_THREADPOOL): +// - ET_USE_THREADPOOL: without ExecuTorch's threadpool there are no worker +// threads to nest from, so there is nothing to constrain -- and we avoid any +// dependency on the threadpool extension (NoThreadPoolGuard) in BLAS builds +// that do not link it. +// - __linux__: the nested-OpenMP crash is specific to the Linux x86 host +// iomp5/MKL BLAS path, and the "undefined weak symbol resolves to null" +// behavior the steering relies on is ELF-specific. On macOS/Mach-O a plain +// weak undefined symbol fails to link, and Windows is MSVC (no weak +// symbols), so on both the guard must compile to a no-op. 
// OpenMP runtime entry points used to cap the BLAS thread count.
//
// Per the OpenMP spec, the nthreads-var ICV that omp_set_num_threads writes
// and omp_get_max_threads reads is kept per thread (one copy per data
// environment), so every threadpool worker saves and restores its own value
// and no cross-thread synchronization is required.
//
// Both symbols are declared weak on purpose: that lets this file steer the
// OpenMP runtime the BLAS has already loaded (e.g. MKL's iomp5) WITHOUT
// building this translation unit with -fopenmp, which could link a second,
// conflicting OpenMP runtime. When no OpenMP-threaded BLAS is linked (e.g.
// OSS Eigen) the symbols stay null and the guard below is a no-op.
#if defined(__linux__) && defined(ET_USE_THREADPOOL)
extern "C" __attribute__((weak)) void omp_set_num_threads(int);
extern "C" __attribute__((weak)) int omp_get_max_threads(void);
#endif // defined(__linux__) && defined(ET_USE_THREADPOOL)

namespace {

// Scope guard that pins the OpenMP-backed BLAS to a single thread for as long
// as the object is alive, then puts the previous thread limit back.
//
// The limit is only lowered when all of the following hold:
//   * the build targets Linux with ET_USE_THREADPOOL defined,
//   * a NoThreadPoolGuard is enabled on the calling thread,
//   * both weak OpenMP symbols resolved to a real runtime, and
//   * the current limit reported by omp_get_max_threads() is above 1.
// In every other situation construction and destruction do nothing.
//
// The type is neither copyable nor movable, so each save is matched by exactly
// one restore.
class ScopedSingleThreadBlas {
 public:
  ScopedSingleThreadBlas(const ScopedSingleThreadBlas&) = delete;
  ScopedSingleThreadBlas(ScopedSingleThreadBlas&&) = delete;
  ScopedSingleThreadBlas& operator=(const ScopedSingleThreadBlas&) = delete;
  ScopedSingleThreadBlas& operator=(ScopedSingleThreadBlas&&) = delete;

  // Records the current OpenMP thread limit and drops it to 1 when the
  // conditions listed on the class are met.
  ScopedSingleThreadBlas() {
#if defined(__linux__) && defined(ET_USE_THREADPOOL)
    // Without an active NoThreadPoolGuard on this thread, gemm stays free to
    // use the threaded BLAS.
    const bool guard_active =
        ::executorch::extension::threadpool::NoThreadPoolGuard::is_enabled();
    // A null weak symbol means no OpenMP runtime is available to steer.
    const bool omp_available =
        omp_get_max_threads != nullptr && omp_set_num_threads != nullptr;
    if (guard_active && omp_available) {
      saved_thread_limit_ = omp_get_max_threads();
      must_restore_ = saved_thread_limit_ > 1;
      if (must_restore_) {
        omp_set_num_threads(1);
      }
    }
#endif // defined(__linux__) && defined(ET_USE_THREADPOOL)
  }

  // Puts back the thread limit captured by the constructor, but only if the
  // constructor actually lowered it.
  ~ScopedSingleThreadBlas() {
#if defined(__linux__) && defined(ET_USE_THREADPOOL)
    if (!must_restore_) {
      return;
    }
    omp_set_num_threads(saved_thread_limit_);
#endif // defined(__linux__) && defined(ET_USE_THREADPOOL)
  }

 private:
  // True only when the constructor called omp_set_num_threads(1).
  [[maybe_unused]] bool must_restore_ = false;
  // Thread limit observed before it was lowered; meaningful only when
  // must_restore_ is set.
  [[maybe_unused]] int saved_thread_limit_ = 1;
};

} // namespace
+ ScopedSingleThreadBlas single_thread_blas; int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; float alpha_ = alpha, beta_ = beta; char transa_ = to_blas(transa), transb_ = to_blas(transb); @@ -211,6 +297,9 @@ void gemm( complex *c, int64_t ldc) { normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); #if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE) + // See note above: avoid a nested OpenMP team when called from inside an + // ExecuTorch threadpool worker. + ScopedSingleThreadBlas single_thread_blas; complex alpha_ = alpha, beta_ = beta; int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; char transa_ = to_blas(transa), transb_ = to_blas(transb); @@ -247,6 +336,9 @@ void gemm( complex *c, int64_t ldc) { normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); #if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE) + // See note above: avoid a nested OpenMP team when called from inside an + // ExecuTorch threadpool worker. + ScopedSingleThreadBlas single_thread_blas; complex alpha_ = alpha, beta_ = beta; int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; char transa_ = to_blas(transa), transb_ = to_blas(transb);