Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions kernels/optimized/blas/CPUBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,86 @@ extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void
#endif // ET_BUILD_FOR_APPLE
#endif // ET_BUILD_WITH_BLAS

#if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE)
#if defined(__linux__) && defined(ET_USE_THREADPOOL)
#include <executorch/extension/threadpool/threadpool_guard.h>
#endif // defined(__linux__) && defined(ET_USE_THREADPOOL)

// Some host BLAS backends (notably MKL) parallelize gemm internally with their
// own OpenMP thread team. ExecuTorch already parallelizes operators across its
// own threadpool, and kernels such as the optimized SDPA call gemm from inside
// a pthreadpool worker thread. Letting the BLAS spin up a nested OpenMP team
// from that worker crashes on Linux x86 hosts (SEGV in __kmp_create_worker /
// KMP_UBER_GTID).
//
// When a NoThreadPoolGuard is active on this thread, force the BLAS
// single-threaded for the duration of the gemm call and restore afterwards.
// ExecuTorch enables the guard on its threadpool workers (the nested case);
// top-level callers may also enable it to force single-threaded execution.
//
// The whole mechanism is gated on defined(__linux__) &&
// defined(ET_USE_THREADPOOL):
// - ET_USE_THREADPOOL: without ExecuTorch's threadpool there are no worker
// threads to nest from, so there is nothing to constrain -- and we avoid any
// dependency on the threadpool extension (NoThreadPoolGuard) in BLAS builds
// that do not link it.
// - __linux__: the nested-OpenMP crash is specific to the Linux x86 host
// iomp5/MKL BLAS path, and the "undefined weak symbol resolves to null"
// behavior the steering relies on is ELF-specific. On macOS/Mach-O a plain
// weak undefined symbol fails to link, and Windows is MSVC (no weak
// symbols), so on both the guard must compile to a no-op.
//
// The OpenMP nthreads-var ICV written by omp_set_num_threads / read by
// omp_get_max_threads is per-thread (one copy per data environment per the
// OpenMP spec), so each threadpool worker captures and restores its own value;
// no cross-thread synchronization is needed. The symbols are declared weak so
// this steers the OpenMP runtime the BLAS has already loaded (e.g. MKL's iomp5)
// WITHOUT compiling this translation unit with -fopenmp (which could link a
// second, conflicting OpenMP runtime); when no OpenMP-threaded BLAS is linked
// they stay null and the guard is a no-op (e.g. OSS Eigen).
#if defined(__linux__) && defined(ET_USE_THREADPOOL)
// Weak references to the OpenMP runtime's thread-count controls. Declaring them
// weak means this translation unit does not itself have to be built with
// -fopenmp: if no OpenMP runtime is linked into the process the addresses
// resolve to null, so every caller must null-check both before calling.

// Reads the calling thread's current upper bound on OpenMP team size.
extern "C" __attribute__((weak)) int omp_get_max_threads(void);
// Sets the calling thread's OpenMP team size for subsequent parallel regions.
extern "C" __attribute__((weak)) void omp_set_num_threads(int);
#endif // defined(__linux__) && defined(ET_USE_THREADPOOL)

namespace {
/// Scope guard that temporarily limits an OpenMP-threaded BLAS to one thread.
///
/// On construction, if a NoThreadPoolGuard is enabled on the calling thread and
/// the OpenMP runtime symbols are present (non-null weak references), the
/// current OpenMP thread count is recorded and, when it is greater than one,
/// lowered to 1. On destruction the recorded count is put back. In every other
/// case -- including builds where `__linux__` or `ET_USE_THREADPOOL` is not
/// defined -- both the constructor and the destructor do nothing.
///
/// The guard is neither copyable nor movable: it must be destroyed in the same
/// scope in which it was created so the save/restore pair stays balanced.
class ScopedSingleThreadBlas {
 public:
  ScopedSingleThreadBlas() {
#if defined(__linux__) && defined(ET_USE_THREADPOOL)
    using ::executorch::extension::threadpool::NoThreadPoolGuard;

    // Steering is only wanted while a NoThreadPoolGuard is enabled on this
    // thread, and only possible when both OpenMP entry points resolved.
    const bool can_steer = NoThreadPoolGuard::is_enabled() &&
        omp_get_max_threads != nullptr && omp_set_num_threads != nullptr;
    if (can_steer) {
      saved_thread_count_ = omp_get_max_threads();
      // Already single-threaded: nothing to change, nothing to undo later.
      must_restore_ = saved_thread_count_ > 1;
      if (must_restore_) {
        omp_set_num_threads(1);
      }
    }
#endif // defined(__linux__) && defined(ET_USE_THREADPOOL)
  }

  ~ScopedSingleThreadBlas() {
#if defined(__linux__) && defined(ET_USE_THREADPOOL)
    if (!must_restore_) {
      return;
    }
    // must_restore_ is only ever set after omp_set_num_threads was verified to
    // be non-null in the constructor, so the call below is safe.
    omp_set_num_threads(saved_thread_count_);
#endif // defined(__linux__) && defined(ET_USE_THREADPOOL)
  }

  // Non-copyable.
  ScopedSingleThreadBlas(const ScopedSingleThreadBlas&) = delete;
  ScopedSingleThreadBlas& operator=(const ScopedSingleThreadBlas&) = delete;
  // Non-movable.
  ScopedSingleThreadBlas(ScopedSingleThreadBlas&&) = delete;
  ScopedSingleThreadBlas& operator=(ScopedSingleThreadBlas&&) = delete;

 private:
  // Thread count observed at construction; meaningful only if must_restore_.
  [[maybe_unused]] int saved_thread_count_ = 1;
  // True when the constructor lowered the thread count and it must be undone.
  [[maybe_unused]] bool must_restore_ = false;
};
} // namespace
#endif // defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE)

namespace executorch {
namespace cpublas {

Expand Down Expand Up @@ -88,6 +168,9 @@ void gemm(
#ifdef ET_BUILD_FOR_APPLE
cblas_dgemm(CblasColMajor, to_cblas_transpose(transa), to_cblas_transpose(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
#else
// See note above: avoid a nested OpenMP team when called from inside an
// ExecuTorch threadpool worker.
ScopedSingleThreadBlas single_thread_blas;
int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
double alpha_ = alpha, beta_ = beta;
char transa_ = to_blas(transa), transb_ = to_blas(transb);
Expand Down Expand Up @@ -128,6 +211,9 @@ void gemm(
#ifdef ET_BUILD_FOR_APPLE
cblas_sgemm(CblasColMajor, to_cblas_transpose(transa), to_cblas_transpose(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
#else
// Avoid a nested OpenMP team if the BLAS (e.g. MKL) is multithreaded and we
// are already running inside an ExecuTorch threadpool worker. See note above.
ScopedSingleThreadBlas single_thread_blas;
int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
float alpha_ = alpha, beta_ = beta;
char transa_ = to_blas(transa), transb_ = to_blas(transb);
Expand Down Expand Up @@ -211,6 +297,9 @@ void gemm(
complex<double> *c, int64_t ldc) {
normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE)
// See note above: avoid a nested OpenMP team when called from inside an
// ExecuTorch threadpool worker.
ScopedSingleThreadBlas single_thread_blas;
complex<double> alpha_ = alpha, beta_ = beta;
int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
char transa_ = to_blas(transa), transb_ = to_blas(transb);
Expand Down Expand Up @@ -247,6 +336,9 @@ void gemm(
complex<float> *c, int64_t ldc) {
normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE)
// See note above: avoid a nested OpenMP team when called from inside an
// ExecuTorch threadpool worker.
ScopedSingleThreadBlas single_thread_blas;
complex<float> alpha_ = alpha, beta_ = beta;
int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
char transa_ = to_blas(transa), transb_ = to_blas(transb);
Expand Down
Loading