Add torch.Tensor fast path for StridedMemoryView via AOTI tensor bridge #1894
Merged

Changes from all commits (37 commits)
82ad598  Add torch.Tensor fast path for StridedMemoryView via AOTI tensor bridge (leofang)
f8f8d8c  Clean up tensor bridge: remove unused AOTI decls, lazy dtype, drop em… (leofang)
af06e9b  Move torch tensor fast path into each from_* classmethod (leofang)
44be580  Add stream ordering for torch tensor bridge (leofang)
6e6b8a6  Extract reusable sync_torch_stream and apply to CAI path (leofang)
85caaaf  Nits: add check_aoti helper, size_t itemsize, 2D sliced test (leofang)
9fad471  Revert itemsize to int, memoize int(stream_ptr) (leofang)
cc4558a  Use except?-1 instead of except* for check_aoti (leofang)
5f49e7a  Require PyTorch >= 2.3 for tensor bridge, move imports to module level (leofang)
b98fe71  Add tensor bridge entry to 1.0.0 release notes (leofang)
30ba7d5  Update speedup range in release notes to match benchmarks (leofang)
0f57646  Document THPVariable layout change across PyTorch versions (leofang)
74798e7  Cache type check in _is_torch_tensor for ~20% speedup (leofang)
00b8ec9  Add upper bound to torch version check (cap at 2.11) (leofang)
0c31df1  Update module docstring to document both THPVariable layouts (leofang)
8c20237  Use except?-1 for sync_torch_stream instead of except* (leofang)
8c019b9  Fix linter errors (leofang)
6682646  Fix pyobj_to_aten_handle for PyTorch 2.3–2.9 MaybeOwned layout (leofang)
0b7245b  Consolidate torch tensor bridge tests into TestViewCPU/TestViewGPU (leofang)
626736a  Extract _arr_size helper for torch/numpy size compatibility (leofang)
d1d3841  Fix ruff formatting in test_utils.py (leofang)
b9d80e7  Add readonly comment and fix vendored header license to BSD-3-Clause (leofang)
7d46123  Merge bfloat16 test into test_torch_tensor_bridge_dtypes parametrization (leofang)
c7331a9  Fix SPDX linter: use PyTorch copyright in vendored header (leofang)
0e75229  Fix Windows build: generate stub import library for AOTI symbols (leofang)
37fce1a  Merge branch 'main' into tensor-bridge-749 (leofang)
7f5dda6  Exclude torch DLLs from delvewheel repair on Windows (leofang)
f6a3032  Merge branch 'tensor-bridge-749' of https://github.com/leofang/cuda-p… (leofang)
2748a52  Fix delvewheel flag: use --exclude instead of --no-dll (leofang)
d543be1  Merge branch 'main' into tensor-bridge-749 (leofang)
84ff2ec  Merge branch 'main' into tensor-bridge-749 (leofang)
a66c0d0  fix merge conflict resolution (leofang)
833bcf8  [pre-commit.ci] auto code formatting (pre-commit-ci[bot])
615f984  Merge branch 'main' into tensor-bridge-749 (leofang)
b5ec10d  Add strided layout guard to tensor bridge, reject sparse tensors (leofang)
1ac154a  Revert strided layout guard (symbols missing in torch 2.3–2.8) (leofang)
6b0dffe  Address review comments: dtypes, stale cache, stream_ptr, sync notes (leofang)
aoti_shim.def (new file, +37 lines):

```
; Stub import library definition for PyTorch's AOTI stable C ABI symbols.
; Used on Windows only: 'lib /DEF:aoti_shim.def /OUT:aoti_shim.lib /MACHINE:X64'
; generates a minimal import library that satisfies the MSVC linker.
; At runtime the symbols resolve from torch_cpu.dll (loaded by 'import torch').
;
; IMPORTANT: Keep this export list in sync with the AOTI_SHIM_API declarations
; in aoti_shim.h. build_hooks.py turns this file into the stub import library
; that MSVC uses to link _tensor_bridge, so any added/removed/renamed AOTI
; symbol must be updated in both files.
LIBRARY torch_cpu.dll
EXPORTS
    aoti_torch_get_data_ptr
    aoti_torch_get_dim
    aoti_torch_get_sizes
    aoti_torch_get_strides
    aoti_torch_get_dtype
    aoti_torch_dtype_float16
    aoti_torch_dtype_float32
    aoti_torch_dtype_float64
    aoti_torch_dtype_bfloat16
    aoti_torch_dtype_uint8
    aoti_torch_dtype_uint16
    aoti_torch_dtype_uint32
    aoti_torch_dtype_uint64
    aoti_torch_dtype_int8
    aoti_torch_dtype_int16
    aoti_torch_dtype_int32
    aoti_torch_dtype_int64
    aoti_torch_dtype_bool
    aoti_torch_dtype_complex32
    aoti_torch_dtype_complex64
    aoti_torch_dtype_complex128
    aoti_torch_get_device_type
    aoti_torch_get_device_index
    aoti_torch_device_type_cpu
    aoti_torch_device_type_cuda
    aoti_torch_get_current_cuda_stream
```
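The .def comments above say build_hooks.py turns this file into the stub import library via MSVC's librarian. A minimal sketch of that step, assuming `lib.exe` is on PATH on Windows; the helper names `stub_lib_command` and `build_stub_import_lib` are illustrative, not the actual build_hooks.py code:

```python
import subprocess
import sys


def stub_lib_command(def_file: str, out_lib: str, machine: str = "X64") -> list[str]:
    """Build the MSVC librarian command that turns a .def export list
    into a stub import library (import thunks only, no code)."""
    return ["lib", f"/DEF:{def_file}", f"/OUT:{out_lib}", f"/MACHINE:{machine}"]


def build_stub_import_lib(def_file: str, out_lib: str) -> None:
    # Only meaningful on Windows, where the symbols live in torch_cpu.dll
    # and MSVC needs an import library to link against.
    if sys.platform != "win32":
        raise RuntimeError("stub import library is only needed on Windows")
    subprocess.run(stub_lib_command(def_file, out_lib), check=True)


print(" ".join(stub_lib_command("aoti_shim.def", "aoti_shim.lib")))
```

The resulting `aoti_shim.lib` satisfies the linker at build time; no AOTI code is compiled in, and the real symbols are bound from torch_cpu.dll when `import torch` loads it.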
aoti_shim.h (new file, +117 lines):

```c
/*
 * Vendored subset of PyTorch's AOT Inductor (AOTI) stable C ABI.
 * Original: torch/csrc/inductor/aoti_torch/c/shim.h
 *
 * These are declarations only -- no definitions are provided. The actual
 * symbols are exported by libtorch (loaded via torch._C with RTLD_GLOBAL)
 * and resolved at runtime by the dynamic linker. This means PyTorch is
 * NOT required at compile time.
 *
 * From PyTorch:
 *
 * Copyright (c) 2016- Facebook, Inc (Adam Paszke)
 * Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
 * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
 * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
 * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
 * Copyright (c) 2011-2013 NYU (Clement Farabet)
 * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
 * Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
 * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
 *
 * SPDX-License-Identifier: BSD-3-Clause
 * See https://github.com/pytorch/pytorch/blob/main/LICENSE
 */

#ifndef CUDA_CORE_AOTI_SHIM_H
#define CUDA_CORE_AOTI_SHIM_H

#include <stdint.h>

/*
 * On Windows the AOTI symbols live in torch_cpu.dll. We consume them
 * via __declspec(dllimport) and a stub import library generated from
 * aoti_shim.def at build time. On Linux/macOS the symbols are made
 * visible at runtime through ctypes.CDLL(torch._C, RTLD_GLOBAL).
 */
#ifdef _WIN32
#  define AOTI_SHIM_API __declspec(dllimport)
#else
#  define AOTI_SHIM_API
#endif

#ifdef __cplusplus
extern "C" {
#endif

typedef int32_t AOTITorchError;

/* Opaque tensor handle -- corresponds to at::Tensor on the C++ side. */
struct AtenTensorOpaque;
typedef struct AtenTensorOpaque* AtenTensorHandle;

/*
 * IMPORTANT: Keep the AOTI_SHIM_API declaration list below in sync with
 * aoti_shim.def. On Windows, build_hooks.py turns that .def file into the
 * stub import library that MSVC needs to link _tensor_bridge without making
 * PyTorch a build-time dependency. If you add, remove, or rename an imported
 * AOTI symbol here, update aoti_shim.def in the same change.
 */

/* ---- tensor metadata --------------------------------------------------- */

AOTI_SHIM_API AOTITorchError aoti_torch_get_data_ptr(
    AtenTensorHandle tensor, void** ret_data_ptr);

AOTI_SHIM_API AOTITorchError aoti_torch_get_dim(
    AtenTensorHandle tensor, int64_t* ret_dim);

AOTI_SHIM_API AOTITorchError aoti_torch_get_sizes(
    AtenTensorHandle tensor, int64_t** ret_sizes);

AOTI_SHIM_API AOTITorchError aoti_torch_get_strides(
    AtenTensorHandle tensor, int64_t** ret_strides);

/* ---- dtype ------------------------------------------------------------- */

AOTI_SHIM_API AOTITorchError aoti_torch_get_dtype(
    AtenTensorHandle tensor, int32_t* ret_dtype);

AOTI_SHIM_API int32_t aoti_torch_dtype_float16(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_float32(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_float64(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_bfloat16(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_uint8(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_uint16(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_uint32(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_uint64(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_int8(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_int16(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_int32(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_int64(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_bool(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_complex32(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_complex64(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_complex128(void);

/* ---- device ------------------------------------------------------------ */

AOTI_SHIM_API AOTITorchError aoti_torch_get_device_type(
    AtenTensorHandle tensor, int32_t* ret_device_type);

AOTI_SHIM_API AOTITorchError aoti_torch_get_device_index(
    AtenTensorHandle tensor, int32_t* ret_device_index);

AOTI_SHIM_API int32_t aoti_torch_device_type_cpu(void);
AOTI_SHIM_API int32_t aoti_torch_device_type_cuda(void);

/* ---- stream ------------------------------------------------------------ */

AOTI_SHIM_API AOTITorchError aoti_torch_get_current_cuda_stream(
    int32_t device_index, void** ret_stream);

#ifdef __cplusplus
} /* extern "C" */
#endif

#endif /* CUDA_CORE_AOTI_SHIM_H */
```
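The header comment states that on Linux/macOS the `aoti_torch_*` declarations resolve at runtime because torch's extension module is loaded with RTLD_GLOBAL, which puts its exports into the process-wide symbol namespace. A minimal sketch of that loading mechanism, using libm instead of libtorch so it runs without PyTorch installed (the library choice is purely illustrative):

```python
import ctypes
import ctypes.util

# Load a shared library with RTLD_GLOBAL so its exported symbols join the
# process-wide namespace -- the same mechanism the tensor bridge relies on:
# once 'import torch' has loaded torch's C extension this way, the
# aoti_torch_* symbols declared (but not defined) in aoti_shim.h can be
# bound by the dynamic linker when _tensor_bridge is loaded.
libm_path = ctypes.util.find_library("m")
libm = ctypes.CDLL(libm_path, mode=ctypes.RTLD_GLOBAL)

# Symbols exported by the loaded library are callable through the handle.
libm.cos.restype = ctypes.c_double
libm.cos.argtypes = [ctypes.c_double]
print(libm.cos(0.0))  # → 1.0
```

This is why the header stresses that PyTorch is not a compile-time dependency: the compiler only needs the declarations, and symbol binding is deferred until torch's libraries are already resident in the process.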