diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
index d1cb5733d3..12f9bcb5d3 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
+++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
@@ -58,7 +58,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
 using B1ElementOp = PassThrough;
 using CElementOp  = PassThrough;
 
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNOPadding;
 
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
@@ -77,7 +77,7 @@ using DeviceGemmInstance =
         Acc0ElementOp,
         B1ElementOp,
         CElementOp,
-        GemmDefault,
+        GemmSpec,
         1,
         256,
         128, // MPerBlock
@@ -166,8 +166,6 @@ int main(int argc, char* argv[])
     // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
     ck::index_t G0 = 7;
     ck::index_t G1 = 13;
-    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
-    std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
 
     if(argc == 1)
     {
@@ -204,6 +202,9 @@ int main(int argc, char* argv[])
         exit(0);
     }
 
+    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
+
     const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
     const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
     const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
index fff78a5266..6157cb7763 100644
--- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
@@ -693,9 +693,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         }
 
         // Check if C permute dimension matches GEMM + GEMM shape
-        const index_t c_g       = arg.c_grid_desc_g_m_n_.GetLength(I0);
-        const index_t c_m       = arg.c_grid_desc_g_m_n_.GetLength(I1);
-        const index_t c_gemm1n  = arg.c_grid_desc_g_m_n_.GetLength(I2);
+        const index_t c_g       = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
+        const index_t c_m       = arg.c_grid_desc_m_n_.GetLength(I0);
+        const index_t c_gemm1n  = arg.c_grid_desc_m_n_.GetLength(I1);
         const index_t a_m       = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
         const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
         if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
index e500ad84f1..6e69f9ddb0 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
@@ -594,10 +594,17 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
-        // selected_mfma.k_per_blk <= B1K1 <= selected_mfma.group_size
-        constexpr index_t Gemm1KPack = math::max(
-            math::gcd(MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size, B1K1),
-            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
+        // selected_mfma.k_per_blk <= Gemm1KPack
+        //
+        // Following a similar rationale to Gemm0KPack, let Gemm1KPack be the lowest common
+        // multiple of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
+        // Gemm1KPack can't be higher than A1K1 itself, because the A1 matrix is distributed in
+        // VGPRs with 'group_size' contiguous elements. Having Gemm1KPack greater than A1K1 would
+        // cause a mismatch in the summation index, e.g. c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
+        // Therefore we may just as well assign Gemm1KPack = group_size.
+        constexpr index_t Gemm1KPack =
+            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
 
         auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
             BlockSize,
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
index 1985457300..c8cdf3d7b6 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
@@ -645,10 +645,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
-        // selected_mfma.k_per_blk <= B1K1 <= selected_mfma.group_size
-        constexpr index_t Gemm1KPack = math::max(
-            math::gcd(MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size, B1K1),
-            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
+        // selected_mfma.k_per_blk <= Gemm1KPack
+        //
+        // Following a similar rationale to Gemm0KPack, let Gemm1KPack be the lowest common
+        // multiple of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
+        // Gemm1KPack can't be higher than A1K1 itself, because the A1 matrix is distributed in
+        // VGPRs with 'group_size' contiguous elements. Having Gemm1KPack greater than A1K1 would
+        // cause a mismatch in the summation index, e.g. c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
+        // Therefore we may just as well assign Gemm1KPack = group_size.
+        constexpr index_t Gemm1KPack =
+            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
 
         auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
             BlockSize,
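For illustration, the summation-index mismatch described in the new comment can be reproduced with a small standalone sketch. The numbers below are assumptions chosen for the demo (group_size = 4, and a hypothetical group_stride = 8 standing in for the K-distance between a lane's groups); they are not read from MfmaSelector:

    // Standalone sketch: why Gemm1KPack must not exceed the MFMA group_size.
    // All values here are assumed for illustration only.
    #include <cstdio>

    int main()
    {
        constexpr int group_size   = 4; // contiguous K-elements a lane holds per group (assumed)
        constexpr int group_stride = 8; // hypothetical K-distance between a lane's groups

        // K-indices a lane actually owns for its first 8 A1 elements:
        // two groups of 'group_size', separated by 'group_stride'.
        int a1_k[8];
        for(int g = 0; g < 2; ++g)
            for(int i = 0; i < group_size; ++i)
                a1_k[g * group_size + i] = g * group_stride + i;

        // A K-pack of 8 would pair these with b1 K-indices 0..7:
        for(int k = 0; k < 8; ++k)
            std::printf("a1 k=%2d  <->  b1 k=%d%s\n",
                        a1_k[k], k, a1_k[k] == k ? "" : "  <-- mismatch");
        // Prints a1 k = {0,1,2,3,8,9,10,11} against b1 k = {0,...,7}: the last four
        // products would use the wrong summation index, i.e.
        // c[0:7] = a1[[0:3, 8:11]] * b1[0:7], hence Gemm1KPack <= group_size.
        return 0;
    }

With Gemm1KPack = group_size a K-pack never straddles two of a lane's groups, which is why the diff can drop the old gcd/max computation entirely.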