 #include "resource_handles_cxx_api.hpp"
 #include <cuda.h>
 #include <cstdint>
+#include <cstring>
 #include <mutex>
 #include <unordered_map>
 #include <vector>
@@ -704,58 +705,99 @@ DevicePtrHandle deviceptr_create_with_owner(CUdeviceptr ptr, PyObject* owner) {
 // The first cuMemFreeAsync incorrectly unmaps the memory even when the pointer
 // was imported multiple times. We work around this by caching imported pointers
 // and returning the same handle for duplicate imports.
+//
+// The cache key is the export_data bytes (CUmemPoolPtrExportData), not the
+// returned pointer, because we must check the cache BEFORE calling
+// cuMemPoolImportPointer (which fails with CUDA_ERROR_ALREADY_MAPPED if
+// the pointer is already imported).
 
 // TODO: When driver fix is available, add version check here to bypass cache.
 static bool use_ipc_ptr_cache() {
     return true;
 }
 
+// Wrapper for CUmemPoolPtrExportData to use as a map key
+struct ExportDataKey {
+    CUmemPoolPtrExportData data;
+
+    bool operator==(const ExportDataKey& other) const {
+        return std::memcmp(&data, &other.data, sizeof(data)) == 0;
+    }
+};
+
+struct ExportDataKeyHash {
+    std::size_t operator()(const ExportDataKey& key) const {
+        // Simple polynomial hash of the raw bytes
+        std::size_t h = 0;
+        const auto* bytes = reinterpret_cast<const unsigned char*>(&key.data);
+        for (std::size_t i = 0; i < sizeof(key.data); ++i) {
+            h = h * 31 + bytes[i];
+        }
+        return h;
+    }
+};
+
 static std::mutex ipc_ptr_cache_mutex;
-static std::unordered_map<CUdeviceptr, std::weak_ptr<DevicePtrBox>> ipc_ptr_cache;
+static std::unordered_map<ExportDataKey, std::weak_ptr<DevicePtrBox>, ExportDataKeyHash> ipc_ptr_cache;
 
 DevicePtrHandle deviceptr_import_ipc(MemoryPoolHandle h_pool, const void* export_data, StreamHandle h_stream) {
     if (!ensure_driver_loaded()) {
         err = CUDA_ERROR_NOT_INITIALIZED;
         return {};
     }
-    GILReleaseGuard gil;
-    CUdeviceptr ptr;
+
     auto data = const_cast<CUmemPoolPtrExportData*>(
         reinterpret_cast<const CUmemPoolPtrExportData*>(export_data));
-    if (CUDA_SUCCESS != (err = p_cuMemPoolImportPointer(&ptr, *h_pool, data))) {
-        return {};
-    }
 
     if (use_ipc_ptr_cache()) {
+        // Check cache BEFORE calling cuMemPoolImportPointer
+        ExportDataKey key;
+        std::memcpy(&key.data, data, sizeof(key.data));
+
+        // Release the GIL before taking the cache mutex: the deleter below
+        // locks it while the GIL may be held, so the reverse order can deadlock.
+        GILReleaseGuard gil;
         std::lock_guard<std::mutex> lock(ipc_ptr_cache_mutex);
 
-        // Check for existing handle
-        auto it = ipc_ptr_cache.find(ptr);
+        auto it = ipc_ptr_cache.find(key);
         if (it != ipc_ptr_cache.end()) {
             if (auto box = it->second.lock()) {
+                // Cache hit - return existing handle
                 return DevicePtrHandle(box, &box->resource);
             }
             ipc_ptr_cache.erase(it); // Expired entry
         }
 
+        // Cache miss - import the pointer
+        CUdeviceptr ptr;
+        if (CUDA_SUCCESS != (err = p_cuMemPoolImportPointer(&ptr, *h_pool, data))) {
+            return {};
+        }
+
         // Create new handle with cache-clearing deleter
         auto box = std::shared_ptr<DevicePtrBox>(
             new DevicePtrBox{ptr, h_stream},
-            [h_pool, ptr](DevicePtrBox* b) {
+            [h_pool, key](DevicePtrBox* b) {
                 {
                     std::lock_guard<std::mutex> lock(ipc_ptr_cache_mutex);
-                    ipc_ptr_cache.erase(ptr);
+                    ipc_ptr_cache.erase(key);
                 }
                 GILReleaseGuard gil;
                 p_cuMemFreeAsync(b->resource, native(b->h_stream));
                 delete b;
             }
         );
-        ipc_ptr_cache[ptr] = box;
+        ipc_ptr_cache[key] = box;
         return DevicePtrHandle(box, &box->resource);
 
     } else {
         // No caching - simple handle creation
+        GILReleaseGuard gil;
+        CUdeviceptr ptr;
+        if (CUDA_SUCCESS != (err = p_cuMemPoolImportPointer(&ptr, *h_pool, data))) {
+            return {};
+        }
+
         auto box = std::shared_ptr<DevicePtrBox>(
             new DevicePtrBox{ptr, h_stream},
             [h_pool](DevicePtrBox* b) {
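
Once a driver carrying the fix ships, the TODO in use_ipc_ptr_cache() can be resolved by gating the cache on the driver version. A minimal sketch, assuming a hypothetical fix version (13000 is a placeholder, not a known release) and that the dynamic loader exposes p_cuDriverGetVersion alongside the other p_-prefixed entry points:

// Sketch only: gate the workaround on the installed driver version.
static bool use_ipc_ptr_cache() {
    // Placeholder: replace with the first driver version containing the fix.
    constexpr int kFixedDriverVersion = 13000;
    static const bool cache_needed = [] {
        int version = 0;
        // cuDriverGetVersion reports the installed driver's CUDA version.
        if (CUDA_SUCCESS != p_cuDriverGetVersion(&version)) {
            return true;  // Unknown driver: keep the workaround enabled.
        }
        return version < kFixedDriverVersion;
    }();
    return cache_needed;
}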
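For reference, the caller-visible behavior the cache guarantees: importing the same export data twice yields one mapping, so the second call is a cache hit rather than another cuMemPoolImportPointer call (which would fail with CUDA_ERROR_ALREADY_MAPPED). A hypothetical usage sketch; the example names and the handle comparison via native() are assumptions, not part of this change:

#include <cassert>

// Hypothetical caller: pool, blob, and stream come from an earlier
// cuMemPoolExportPointer exchange with the exporting process.
void import_twice(MemoryPoolHandle pool, const void* blob, StreamHandle stream) {
    DevicePtrHandle a = deviceptr_import_ipc(pool, blob, stream);
    DevicePtrHandle b = deviceptr_import_ipc(pool, blob, stream);
    // Second call is a cache hit: both handles share one DevicePtrBox,
    // so they refer to the same device pointer.
    assert(native(a) == native(b));
    // cuMemFreeAsync runs once, after the last handle is released.
}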