diff --git a/CMakeLists.txt b/CMakeLists.txt index e8626b2cb9..b27e6ab4fb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -117,7 +117,7 @@ else() add_definitions(-DPROFILER_ONLY) set(GPU_TARGETS "" CACHE STRING "" FORCE) if(GPU_TARGETS) - message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12") endif() if(GPU_ARCH MATCHES "gfx90") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") @@ -127,8 +127,10 @@ else() rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") elseif(GPU_ARCH MATCHES "gfx11") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") + elseif(GPU_ARCH MATCHES "gfx12") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201") else() - message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12") endif() set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) endif() diff --git a/Jenkinsfile b/Jenkinsfile index 855fe8dff9..67e9b2fcb3 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){ def variant = env.STAGE_NAME def retimage + gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) @@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM pipeline { agent none - triggers { - parameterizedCron(CRON_SETTINGS) - } options { parallelsAlwaysFailFast() } diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index fb2b38d688..93fd306e98 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,7 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror + -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 8c52e4f7d7..f8afe8d6db 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle - < ALayout, - BLayout, - CLayout, - ADataType, + < ALayout, + BLayout, + CLayout, + ADataType, BDataType, - CDataType, - AccDataType, - CShuffleDataType, - AElementOp, - BElementOp, - CElementOp, - GemmDefault, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, 1, // Prefetch stage 128, // BlockSize 64, // MPerBlock 128, // NPerBlock 64, // KPerBlock - 8, // K1 + 2, // K1 16, // MPerWmma 16, // NPerWmma 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, 1, // C shuffle (M Repeat) Per store 1, // C shuffle (N Repeat) Per store - S<1, 32, 1, 4>, + S<1, 32, 1, 4>, 8>; // clang-format on diff --git 
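Note on the example retune above: `K1` and the per-thread copy widths drop from 8 to 2. A minimal sketch of the divisibility constraint involved, with `KPack` assumed for illustration (the gfx12 `BlockwiseGemmWMMA` added later in this patch asserts the same relation):

```cpp
// Hypothetical values; KPack really comes from the gridwise GEMM's template
// arguments. On gfx12 each wave holds K split across two rows (A_KRow == 2).
constexpr int KPack = 8, A_K1 = 2, A_KRow = 2;
static_assert(KPack % (A_K1 * A_KRow) == 0,
              "mirrors the static_assert in the new gfx12 BlockwiseGemmWMMA");
```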
a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index b04e4e53a8..cb15186c3b 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; case 4: - ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); break; case 5: diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index ab19f819e8..be47665a26 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32) set(target 1) endif() -endforeach() \ No newline at end of file +endforeach() diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 2bbf430c4e..f556be887f 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN = 2, 4, 4, - true, + false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, - true, + false, 1, 1, S<1, 64, 1, 2>, diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp index 4c92c5497f..fac19f8b5a 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp @@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial #define CK_MHA_USE_WAVE_1 #define CK_MHA_USE_WAVE_2 #define CK_MHA_USE_WAVE_4 -#define CK_MHA_USE_WAVE_8 +//#define CK_MHA_USE_WAVE_8 using DeviceMHAFactory = std::tuple< #ifdef CK_MHA_USE_WAVE_1 @@ -277,10 +277,10 @@ using DeviceMHAFactory = S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, // CShuffleBlockTransfer MN 1, 1, S<1, 64, 1, 2>, 8, - MaskingSpec>, + MaskingSpec> #endif #ifdef CK_MHA_USE_WAVE_8 - ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp index 8e037272b8..d463cc8716 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp @@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial #define CK_MHA_USE_WAVE_1 #define CK_MHA_USE_WAVE_2 #define CK_MHA_USE_WAVE_4 -#define CK_MHA_USE_WAVE_8 +//#define 
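The comma moves from the end of the wave-4 instance to the front of the wave-8 instance so the tuple stays well-formed now that `CK_MHA_USE_WAVE_8` is commented out. The pattern, distilled (names hypothetical):

```cpp
#include <tuple>
#define USE_WAVE_1
#define USE_WAVE_2
//#define USE_WAVE_4            // disabled members compile out cleanly
using Factory = std::tuple<
#ifdef USE_WAVE_1
    int                         // first member carries no comma
#endif
#ifdef USE_WAVE_2
    , float                     // leading commas survive any disabled suffix
#endif
#ifdef USE_WAVE_4
    , double
#endif
    >;
```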
CK_MHA_USE_WAVE_8 using DeviceMHAFactory = std::tuple< #ifdef CK_MHA_USE_WAVE_1 @@ -277,10 +277,10 @@ using DeviceMHAFactory = S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, // CShuffleBlockTransfer MN 1, 1, S<1, 64, 1, 2>, 8, - MaskingSpec>, + MaskingSpec> #endif #ifdef CK_MHA_USE_WAVE_8 - ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index fd9f5cd89d..c9781637d6 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() @@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 32eea551f5..9528a30b4b 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) #define __gfx11__ #endif +#if defined(__gfx1200__) || defined(__gfx1201__) +#define __gfx12__ +#endif // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code @@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx103__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#elif defined(__gfx11__) +#elif defined(__gfx11__) || defined(__gfx12__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif @@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 -#elif defined(__gfx11__) +#elif defined(__gfx11__) || defined(__gfx12__) #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8_GFX11 @@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define CK_USE_AMD_MFMA_GFX940 #endif -// WMMA instruction -#ifndef __HIP_DEVICE_COMPILE__ // for host code -#define CK_USE_AMD_WMMA -#elif defined(__gfx11__) // for GPU code -#define CK_USE_AMD_WMMA -#endif - // buffer load #define CK_USE_AMD_BUFFER_LOAD 1 diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 116bb3ea02..83af2efe88 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -84,4 +84,9 @@ inline bool is_gfx11_supported() ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; } +inline bool is_gfx12_supported() +{ 
+ return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; +} + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 873539f8b1..3ea19da741 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -13,6 +13,504 @@ namespace ck { +#ifdef __gfx12__ +template +/* Option: Read from LDS, big buffer hold all threads required data + * Source + * A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * Destination + * C, non-transpose + * thread level: MRepeat x NRepeat x MAccVgprs + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * KPACK == WMMA_K = 16 + * + * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) + * Source: + * A(if skip LDS): MRepeat x KPack + * B(if skip LDS): NRepeat x KPack + * Destination + * C, non-transpose + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + */ +struct BlockwiseGemmWMMA +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto WmmaK = Number<16>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. + static constexpr index_t WaveSize = 32; + + // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer + // When not use LDS, each Row read half of whole data from source buffer, exchange the data via + // permutation + static constexpr index_t A_KRow = 2; + static constexpr index_t B_KRow = 2; + + static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + // Default, Block buffer in LDS, thread level offset enabled + __device__ static auto CalculateAThreadOriginDataIndex() + { + if constexpr(AEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + if constexpr(BEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto WMMA_b_idx = 
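For orientation, a standalone model of `GetWaveIdx()` above: the merge transform makes the lane index fastest-varying, then the n-wave, then the m-wave (block shape assumed):

```cpp
#include <cstdio>
int main() {
    constexpr int MWaves = 2, NWaves = 2, WaveSize = 32; // example 128-thread block
    for (int tid : {0, 31, 32, 95}) {
        const int lane   = tid % WaveSize;
        const int wave_n = (tid / WaveSize) % NWaves;
        const int wave_m = (tid / WaveSize) / NWaves;
        std::printf("tid %3d -> wave_m %d, wave_n %d, lane %2d\n",
                    tid, wave_m, wave_n, lane);
    }
}
```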
wmma_gemm.CalculateBThreadOriginDataIndex(); + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); + + return make_tuple( + Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && + NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + + // Thread level, register decriptor. 
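The two single-stage adaptors in `CalculateCThreadOriginDataIndex` reduce to plain unmerge arithmetic; a worked instance (all values hypothetical):

```cpp
// m = (m0 * MWaves + waveId_m) * MPerWMMA + blk_idx_m, and likewise for n.
constexpr int MWaves = 2, MPerWMMA = 16;
constexpr int m0 = 1, waveId_m = 1, blk_m = 5;
constexpr int c_thread_m = (m0 * MWaves + waveId_m) * MPerWMMA + blk_m;
static_assert(c_thread_m == 53, "repeat 1, wave 1, element 5 within the WMMA tile");
```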
Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Provide dimension size + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Describe how data allocated in thread copy src buffer + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_assert(KPack % (A_K1 * A_KRow) == 0, ""); + static_assert(KPack % (B_K1 * B_KRow) == 0, ""); + + // 
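Where the `Run` body below branches on `MRepeat < NRepeat`, the effect is to pick which operand sits in the outer repeat loop; a compile-time model (the register-reuse rationale is an assumption, the source only encodes the branch):

```cpp
template <int MRepeat, int NRepeat>
constexpr const char* loop_order() {
    if constexpr (MRepeat < NRepeat)
        return "K outer, M mid, N inner: each A tile is reused NRepeat times";
    else
        return "N outer, M mid, K inner: B is re-read along the inner K sweep";
}
static_assert(loop_order<2, 4>()[0] == 'K', "");
```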
basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / KPack, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + template + struct AThreadCopySelector; + + template <> + struct AThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + }; + + template <> + struct AThreadCopySelector + { + using type = 
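`AThreadCopySelector`/`BThreadCopySelector` below are compile-time strategy switches keyed on the `EnableLds` flag; the skeleton, reduced to essentials (placeholder types):

```cpp
template <bool EnableLds>
struct AThreadCopySelectorSketch;            // primary template left undefined

template <>
struct AThreadCopySelectorSketch<true>       // LDS path: generic v4 transfer
{
    struct type { /* ThreadwiseTensorSliceTransfer_v4 stand-in */ };
};

template <>
struct AThreadCopySelectorSketch<false>      // direct path: intra-row static copy
{
    struct type { /* ..._StaticToStatic_IntraRow stand-in */ };
};

using ACopySketch = AThreadCopySelectorSketch<true>::type;
```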
ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatA, + FloatA, + decltype(a_block_desc_k0_m0_m1_m2_k1), + decltype(a_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + false>; + }; + + template + struct BThreadCopySelector; + + template <> + struct BThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct BThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatB, + FloatB, + decltype(b_block_desc_k0_n0_n1_n2_k1), + decltype(b_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + false>; + }; + + typename AThreadCopySelector::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; +}; +#else template ::type a_thread_copy_; typename BThreadCopySelector::type b_thread_copy_; }; +#endif } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index e2296a55f7..d3f6344c27 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // sync point. if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) { +#ifdef __gfx12__ + asm volatile("\ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("s_barrier" ::); +#endif __builtin_amdgcn_sched_barrier(0); } static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index a157595593..ab3f3856aa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto WmmaK = K1 == 16 ? 32 : 16; - static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; - static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? 
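gfx12 replaces the single `s_barrier` with a split signal/wait pair, as the xdlops hunk above shows; isolated for reference:

```cpp
__device__ inline void block_sync_sketch() {
#ifdef __gfx12__
    asm volatile("s_barrier_signal -1\n\t"
                 "s_barrier_wait -1" ::);   // split barrier on gfx12
#else
    asm volatile("s_barrier" ::);           // single instruction elsewhere
#endif
}
```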
false : true; // If true, LDS is used unconditionally static constexpr auto AEnableLds_manu = false; @@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { @@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle } else { - if(!(arg.a_kz_stride_ == 1 && - arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + if(!(arg.a_kz_stride_ == 1)) { - printf("DeviceOp: Vector Access A-k check failure\n"); - return false; + index_t LastK = + AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6); + if(LastK % ABlockTransferSrcScalarPerVector == 0) + { + printf("DeviceOp: Vector Access A-k check failure\n"); + return false; + } } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp index 8fd14afc0c..1b487502f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp @@ -70,8 +70,9 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ + defined(__gfx12__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index f1bc6a2261..f0f89f1d1b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -592,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle return false; } - if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" && - ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" && - std::is_same::value) + if(!ck::is_lds_direct_load_supported() && std::is_same::value) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index b84e181306..1edae33be3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl { // check device if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || - ck::is_gfx11_supported())) + ck::is_gfx11_supported() || ck::is_gfx12_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp 
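On the reworked A-k vector check above: the contiguous K length now depends on `AEnableLds` (descriptor axis I2 with LDS, I6 without). As written, the hunk reports failure when `LastK` divides evenly by the vector width; the conventional direction for such a guard is the opposite, sketched here as an assumption about intent:

```cpp
// Hypothetical standalone form of the divisibility guard.
bool a_k_vector_access_ok(bool a_k_unit_stride, int last_k, int scalar_per_vector) {
    if (!a_k_unit_stride)
        return last_k % scalar_per_vector == 0; // assumed intent: divisible length passes
    return true;
}
```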
b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp index bf96324d00..553143e286 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index b1784b3858..eb0fb55f5d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index 93ab8a7e1d..a7cc546f53 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm{}; - static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); - static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); - static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; - static constexpr auto AEnableLds_auto = - (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) && + is_same::value) + ? false + : true; static constexpr auto BEnableLds_auto = - (MWaves == 1 && is_same::value) ? false : true; + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) && + is_same::value) + ? 
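The `AEnableLds_auto`/`BEnableLds_auto` rewrite above gates the LDS bypass on reaching the maximal 16-byte load (or a repeat count of 1), on top of the existing wave-count and layout conditions. In boolean form, with example values:

```cpp
#include <cstdint>
constexpr int  K1 = 8, NWaves = 1, MRepeat = 1;
constexpr bool AIsKMajor = true;                             // A row-major in this sketch
constexpr bool MaxVectorLoadA = K1 * sizeof(uint16_t) == 16; // fp16 A: 16-byte loads
constexpr bool AEnableLds_auto =
    !(NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) && AIsKMajor);
static_assert(!AEnableLds_auto, "this configuration may skip LDS for A");
```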
false + : true; // If true, LDS is used unconditionally static constexpr auto AEnableLds_manu = false; @@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 6f74838fba..6bb5d431c9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { // check device - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 86091aeba9..cc26936fef 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -48,8 +48,9 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index 211185dfb0..5738be0fb3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { // check device - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index ce86ec54e5..c3fe54b075 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -90,8 +90,9 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) 
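These hunks repeat one pattern: every architecture guard that listed `__gfx11__` gains `__gfx12__`, in both the `#if` kernel guards and the runtime `IsSupportedArgument` checks. The guard shape, for reference:

```cpp
__global__ void guarded_kernel_sketch(/* args */) {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
    // real body: compiled for the host pass and for supported device targets
#else
    // other targets compile an empty stub (arguments consumed via `ignore = ...;`)
#endif
}
```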
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -667,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK // check device if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || - ck::is_gfx103_supported() || ck::is_gfx11_supported())) + ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index 5c9d63e2b0..c6b84b613c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -107,7 +107,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx11__)) + defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -603,7 +603,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index ac392cddc4..060a16d1e2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -39,8 +39,9 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ + defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp index 4e14ed3a51..cc88c1a104 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp @@ -60,7 +60,7 @@ __global__ void bool input_permute, bool output_permute) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** @@ -165,6 +165,7 @@ __global__ void ignore = O; ignore = G0; ignore = G1; + ignore = alpha; ignore = input_permute; ignore 
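The `ignore = alpha;` added above completes the stub branch of the MQA kernel: with `-Werror` in effect, every parameter must be consumed in the unsupported-arch path. A minimal model, with `std::ignore` standing in for ck's helper:

```cpp
#include <tuple>
void stub_branch_sketch(float alpha, bool input_permute, bool output_permute) {
    std::ignore = alpha;          // previously missing -> unused-parameter error
    std::ignore = input_permute;
    std::ignore = output_permute;
}
```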
= output_permute; #endif // end of if (defined(__gfx11__)) @@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma static bool IsSupportedArgument(const RawArg& arg) { - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp index 16717ff819..1754e07e6a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma if constexpr(B0EnableLds) { // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 - constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( B0BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma if constexpr(B1EnableLds) { // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 - constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); - constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); + constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_LRow = I2; +#else constexpr auto B_LRow = I1; +#endif return transform_tensor_descriptor( B1BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 499eb7eb01..21dac6f9e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -50,7 +50,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, @@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif return transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), 
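All of these gridwise files get the same `A_KRow`/`B_KRow` switch: 2 on gfx12, 1 otherwise, feeding the unmerge of the K0 axis. Numerically (sizes assumed):

```cpp
constexpr int AK0 = 8;          // K0 extent of the LDS tile
#ifdef __gfx12__
constexpr int A_KRow = 2;       // gfx12: K interleaved across two wave rows
#else
constexpr int A_KRow = 1;
#endif
static_assert(AK0 % A_KRow == 0, "KRow must divide K0");
// gfx12: the unmerge yields (4, 2) = (K0 per row, rows); gfx11 keeps (8, 1).
```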
make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 49a6dc3b0f..b3b057c80a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -54,7 +54,7 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -147,7 +147,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // printf("entry kernel launch"); __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; @@ -237,7 +237,7 @@ __global__ void const CDEElementwiseOperation cde_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; GridwiseOp::template Run(p_a_grid, @@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma } else { + constexpr auto A_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma } else { + constexpr auto B_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif return 
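The `K0PerWmma` edit above is value-preserving; the literal `2` was the K-row count all along:

```cpp
constexpr int WmmaK = 16, K1 = 4, A_KRow = 2;   // K1 assumed for illustration
static_assert(WmmaK / 2 / K1 == WmmaK / A_KRow / K1,
              "same K0PerWmma, with the row count now named");
```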
transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma // *Caution Here repeat is shuffle repeat GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; } @@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma const auto M = e_grid_desc_m_n.GetLength(I0); const auto N = e_grid_desc_m_n.GetLength(I1); - const auto MBlock = M / MPerBlock; - const auto NBlock = N / NPerBlock; + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( e_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 8e4117593c..4458b9356d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -45,7 +45,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, @@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma } else { + constexpr auto A_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma } else { + + constexpr auto B_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 
- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif + return transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; - using DefaultBlock2CTileMap = - remove_cvref_t; - struct SharedMemTrait { // LDS allocation for A and B: be careful of alignment @@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma b_block_space_size_aligned * sizeof(BDataType)); }; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + template __device__ static void Run(const ADataType* __restrict__ p_a_grid, const BDataType* __restrict__ p_b_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp index 6772524e0a..1740749907 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp @@ -35,8 +35,9 @@ __global__ void const Block2ETileMap block_2_tile_map, const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ + defined(__gfx12__)) GridwiseTensorRearrangeKernel::Run(in_grid_desc, p_in_global, out_grid_desc, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index bcce930fc7..d7a6a36244 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ElementwiseOperation element_op_; }; -// Specilized for WMMA +// Specilized for WMMA-Navi3 // A single Wave32 is composed by double row // Data exchange allowed between these two rows // This RowLane Dst buf will be filled from two Src buf @@ 
-1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ElementwiseOperation element_op_{}; }; +// Specilized for WMMA-Navi4 +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + ignore = src_idx; + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v_this_row; + // int type temp value due to intrinsic requirement + int temp = 0; + + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); + + // apply intra-row permute. 
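// Note: the permute just below routes through a 32-bit temporary because the
// permlane intrinsic operates on int values; type_convert_sp converts the
// element type in and out. With IntraRowSwizzlePerm false, the element-op
// result is written straight to dst_buf.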
diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
index 565195f53e..9a9ebf5595 100644
--- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
@@ -11,12 +11,17 @@ namespace ck {

 enum struct WmmaInstr
 {
+    // gfx11
     wmma_f32_16x16x16_f16 = 0,
     wmma_f32_16x16x16_bf16,
     wmma_f16_16x16x16_f16,
     wmma_bf16_16x16x16_bf16,
     wmma_i32_16x16x16_iu8,
-    wmma_i32_16x16x16_iu4
+    wmma_i32_16x16x16_iu4,
+    // gfx12
+    wmma_f32_16x16x16_f16_gfx12,
+    wmma_f32_16x16x16_bf16_gfx12,
+    wmma_i32_16x16x16_iu8_gfx12,
 };

 /*
@@ -279,6 +284,122 @@ struct wmma_type
+template <index_t WaveSize>
+struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
+                 WaveSize,
+                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
+{
+    // Absolute fixing property
+    // * Data Pixel
+    static constexpr index_t m_per_wmma = 16;
+    static constexpr index_t n_per_wmma = 16;
+    static constexpr index_t k_per_wmma = 16;
+    // static constexpr index_t src_a_data_size = 2;
+    // static constexpr index_t src_b_data_size = 2;
+    // static constexpr index_t acc_data_size = 4;
+    // * Thread mapping inside wave, num_thread_per_subgroups always along the N direction
+    static constexpr index_t acc_data_size            = 4;
+    static constexpr index_t acc_pack_number          = 1;
+    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+    // Wave mode dependent property
+    static constexpr index_t wave_size = Number<WaveSize>{};
+    // * Fixed on Navi3x, will be wave mode dependent on Navi4x
+    // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
+    // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
+    // * num_acc_vgprs_per_wave along the M direction
+    // * num_subgroups along the M direction
+    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
+    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;
+
+    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
+        if constexpr(wave_size == 32)
+        {
+            intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
+        }
+    }
+};
+
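The trait values in this specialization follow directly from the tile and wave sizes. A sketch of the arithmetic under the wave32 assumption used by the `static_assert` above:

```cpp
// Assumed values for the gfx12 wave32 case: a 16x16 accumulator tile holds
// 256 fp32 elements spread over 32 lanes, so each lane owns 8 accumulators,
// and the 32 lanes split into 2 subgroups of 16 threads along M.
constexpr int m_per_wmma               = 16;
constexpr int n_per_wmma               = 16;
constexpr int wave_size                = 32;
constexpr int num_thread_per_subgroups = n_per_wmma;                          // 16
constexpr int num_acc_vgprs_per_wave   = m_per_wmma * n_per_wmma / wave_size; // 8
constexpr int num_subgroups            = wave_size / num_thread_per_subgroups; // 2
static_assert(num_subgroups * num_thread_per_subgroups == wave_size,
              "subgroups must exactly tile the wave");
```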
+template <index_t WaveSize>
+struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16_gfx12,
+                 WaveSize,
+                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
+{
+    // Absolute fixing property
+    static constexpr index_t m_per_wmma = 16;
+    static constexpr index_t n_per_wmma = 16;
+    static constexpr index_t k_per_wmma = 16;
+    // static constexpr index_t src_a_data_size = 2;
+    // static constexpr index_t src_b_data_size = 2;
+    static constexpr index_t acc_data_size            = 4;
+    static constexpr index_t acc_pack_number          = 1;
+    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+    // Wave mode dependent property
+    static constexpr index_t wave_size = Number<WaveSize>{};
+    // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
+    // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
+    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
+    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;
+
+    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
+        if constexpr(wave_size == 32)
+        {
+            intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
+        }
+    }
+};
+
+template <index_t WaveSize>
+struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
+                 WaveSize,
+                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
+{
+    // Absolute fixing property
+    static constexpr index_t m_per_wmma = 16;
+    static constexpr index_t n_per_wmma = 16;
+    static constexpr index_t k_per_wmma = 16;
+    // static constexpr index_t src_a_data_size = 2;
+    // static constexpr index_t src_b_data_size = 2;
+    static constexpr index_t acc_data_size            = 4;
+    static constexpr index_t acc_pack_number          = 1;
+    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+    // Wave mode dependent property
+    static constexpr index_t wave_size = Number<WaveSize>{};
+    // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
+    // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
+    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
+    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;
+
+    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
+        if constexpr(wave_size == 32)
+        {
+            intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerWmma, NPerWmma, false, false, true>::Run(
+                a, b, reg_c);
+        }
+    }
+};
+
 template <>
 static constexpr auto GetWmma<half_t, half_t, float>()
 {
+#ifdef __gfx12__
+    return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
+#else
     return WmmaInstr::wmma_f32_16x16x16_f16;
+#endif
 }

 template <>
 static constexpr auto GetWmma<bhalf_t, bhalf_t, float>()
 {
+#ifdef __gfx12__
+    return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
+#else
     return WmmaInstr::wmma_f32_16x16x16_bf16;
+#endif
 }

 template <>
@@ -320,8 +449,13 @@ struct WmmaSelector
 template <>
 static constexpr auto GetWmma<int8_t, int8_t, int>()
 {
+#ifdef __gfx12__
+    return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
+#else
     return WmmaInstr::wmma_i32_16x16x16_iu8;
+#endif
 }
+
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 template <>
 static constexpr auto GetWmma()
@@ -502,6 +636,9 @@ struct WmmaGemm
     __device__ static auto GetSubGroupId()
     {
+        static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
+                          wmma_instr.wave_size,
+                      "");
         return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
     }

@@ -516,12 +653,20 @@ struct WmmaGemm
     __host__ __device__ static auto CalculateAThreadOriginDataIndex()
     {
+#ifdef __gfx12__
+        return GetLaneIdUnderSubGroup();
+#else
         return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
+#endif
     }

     __host__ __device__ static auto CalculateBThreadOriginDataIndex()
     {
+#ifdef __gfx12__
+        return GetLaneIdUnderSubGroup();
+#else
         return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
+#endif
     }

     __device__ static CIndex GetBeginOfThreadBlk()
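`GetWmma` resolves the same source-type triple to a different enumerator depending on the compile target, so callers never branch on the ISA themselves. A minimal sketch of that dispatch pattern, with stand-in types and enumerators rather than the CK definitions:

```cpp
// Stand-ins, not CK code: the real selector maps (half_t, half_t, float),
// (bhalf_t, bhalf_t, float), and (int8_t, int8_t, int) the same way.
enum class Instr { f32_16x16x16_f16, f32_16x16x16_f16_gfx12 };
struct f16_t {};

template <typename SrcA, typename SrcB, typename Dst>
constexpr Instr get_wmma();

template <>
constexpr Instr get_wmma<f16_t, f16_t, float>()
{
#ifdef __gfx12__
    return Instr::f32_16x16x16_f16_gfx12; // gfx12 builtin variant
#else
    return Instr::f32_16x16x16_f16;       // gfx11 variant
#endif
}
```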
diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp
index 1bb0140f3e..322a0f94bb 100644
--- a/include/ck/utility/amd_wmma.hpp
+++ b/include/ck/utility/amd_wmma.hpp
@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
     }
 };

+// gfx12
+/********************************WAVE32 MODE***********************************************/
+
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define __gfx12__
+#endif
+
+// src: fp16, dst: fp32
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
+
+template <>
+struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
+    {
+        // * Inline assembly would be needed to eliminate the duplicated data loads;
+        //   the compiler will not remove them for you.
+        // amd_assembly_wmma_f32_16x16x16_f16_w32(
+        //     reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
+#if defined(__gfx12__)
+        reg_c.template AsType<float8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
+                reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+// src: bf16, dst: fp32
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12;
+
+template <>
+struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx12__)
+        reg_c.template AsType<float8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
+                reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+// src: iu8, dst: i32
+template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12;
+
+template <bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx12__)
+        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+                neg_a,
+                bit_cast<int32x2_t>(reg_a),
+                neg_b,
+                bit_cast<int32x2_t>(reg_b),
+                reg_c.template AsType<int32x8_t>()[Number<0>{}],
+                clamp);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
 } // namespace ck
 #endif
diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp
index 93a1edefb6..4df14c6211 100644
--- a/include/ck/utility/data_type.hpp
+++ b/include/ck/utility/data_type.hpp
@@ -203,7 +203,7 @@ struct vector_type
     }
 };

-int static err = 0;
+__device__ int static err = 0;
 template <typename T, index_t N>
 struct vector_type<T __attribute__((ext_vector_type(N))), 1>
 {
diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp
index 4fe5e39504..d6b6eac26c 100644
--- a/include/ck/utility/synchronization.hpp
+++ b/include/ck/utility/synchronization.hpp
@@ -10,12 +10,20 @@ namespace ck {
 __device__ void block_sync_lds()
 {
 #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
+#ifdef __gfx12__
+    asm volatile("\
+    s_wait_dscnt 0x0 \n \
+    s_barrier_signal -1 \n \
+    s_barrier_wait -1 \
+    " ::);
+#else
     // asm volatile("\
     // s_waitcnt lgkmcnt(0) \n \
     // s_barrier \
     // " ::);
     __builtin_amdgcn_s_waitcnt(0xc07f);
     __builtin_amdgcn_s_barrier();
+#endif
 #else
     __syncthreads();
 #endif
@@ -23,11 +31,20 @@ __device__ void block_sync_lds()

 __device__ void block_sync_lds_direct_load()
 {
+#ifdef __gfx12__
+    asm volatile("\
+    s_wait_vmcnt 0x0 \n \
+    s_wait_dscnt 0x0 \n \
+    s_barrier_signal -1 \n \
+    s_barrier_wait -1 \
+    " ::);
+#else
     asm volatile("\
     s_waitcnt vmcnt(0) \n \
     s_waitcnt lgkmcnt(0) \n \
     s_barrier \
     " ::);
+#endif
 }

 __device__ void s_nop()
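The synchronization change above reflects gfx12's split barrier: the single `s_barrier` becomes a `s_barrier_signal`/`s_barrier_wait` pair, and `s_waitcnt lgkmcnt(0)` becomes `s_wait_dscnt 0x0` for LDS traffic. A free-standing helper mirroring the two branches, as a sketch rather than the CK function:

```cpp
// Sketch of an LDS barrier that follows the same #ifdef split as the patch.
__device__ inline void lds_barrier()
{
#if defined(__gfx12__)
    // Split barrier: drain LDS (dscnt), then signal and wait on barrier -1.
    asm volatile("s_wait_dscnt 0x0\n"
                 "s_barrier_signal -1\n"
                 "s_barrier_wait -1" ::);
#else
    __builtin_amdgcn_s_waitcnt(0xc07f); // lgkmcnt(0)
    __builtin_amdgcn_s_barrier();
#endif
}
```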
diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp
index 344343d931..83637e18e4 100644
--- a/include/ck_tile/core/config.hpp
+++ b/include/ck_tile/core/config.hpp
@@ -17,6 +17,9 @@
 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
 #define __gfx11__
 #endif
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define __gfx12__
+#endif

 #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
@@ -155,7 +158,7 @@
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx103__) // for GPU code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#elif defined(__gfx11__) // for GPU code
+#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
 #endif
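The `#elif` above only grows because gfx12 reuses the gfx11 buffer-resource encoding. The same selection, written as a constexpr function so the branch values (taken from the file's branches shown above) are visible in one place:

```cpp
#include <cstdint>

// Sketch: ISA-specific 4th dword of a buffer resource descriptor,
// with the constants from config.hpp above.
constexpr std::uint32_t buffer_resource_3rd_dword()
{
#if defined(__gfx103__)
    return 0x31014000u; // gfx103 encoding
#elif defined(__gfx11__) || defined(__gfx12__)
    return 0x31004000u; // shared gfx11/gfx12 encoding
#else
    return 0x00020000u; // older-target encoding from the same file
#endif
}
```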
Skipping.") set(add_inst 0) endif() diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index fa0eb6f887..5262ca33a6 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -59,7 +59,7 @@ if(GPU_TARGETS MATCHES "gfx9") endif() -if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) endif() @@ -134,7 +134,7 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) endif() -if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") +if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 49b67992b1..66b4d3d27d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -60,7 +60,7 @@ function(add_test_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -139,7 +139,7 @@ function(add_gtest_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 5ef0730668..aee80cb2cb 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -44,7 +44,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } } - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { // on gfx11 only support for 3d is implemented if constexpr(NDimSpatial{} != 3) diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index 49782bce6e..d9ec94771a 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -140,10 +140,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele]; } +#ifdef __gfx12__ + asm volatile("\ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("\ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif for(int ele = 0; ele < 16; ++ele) { @@ -155,10 +163,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8]; } +#ifdef __gfx12__ + asm volatile("\ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("\ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif // sync threads, similar to mma_sync // __syncthreads();