diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 33d0ec6713..90f45170df 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -143,7 +143,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent constexpr ck::index_t ScaleBlockSize = 32; // scaling block size constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 -static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MPerBlock = 32; static constexpr bool MulRoutedWeight = true; // clang-format off @@ -151,14 +151,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic A0Layout, B0Layout, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - ScaleBlockSize, 256, - MPerBlock, 128, KPerBlock, + ScaleBlockSize, 64, + MPerBlock, 32, KPerBlock, 16, 16, 16, 16, - 4, 4, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + 2, 2, + S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 8, 1, 8>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on @@ -170,14 +170,14 @@ int main(int argc, char* argv[]) // per expert: // GEMM shape - constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t sorted_tile_num = 2; constexpr ck::index_t valid_tile_num = sorted_tile_num; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t N = 6144; ck::index_t K = 4096; - ck::index_t experts = 8; + ck::index_t experts = 2; ck::index_t tokens = 832; ck::index_t topk = 2; @@ -319,8 +319,8 @@ int main(int argc, char* argv[]) d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 3: - a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); @@ -337,12 +337,26 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d2_e_n.GenerateTensorValue(GeneratorTensor_1{1}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 6: a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 8: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; @@ -404,25 +418,40 @@ int main(int argc, char* argv[]) auto b_element_op = BElementOp{}; auto cde_element_op = CDEElementOp{}; -#if 0 +#if 1 printf("a0_t_k_k:\n"); + // for(int t = 0; t < tokens; ++t) + // { + // for(int tk = 0; tk < topk; ++tk) + // { + // for(int k = 0; k < K; ++k) + // { + // auto f4x2 = a0_t_k_k(t, tk, k).data; + // if(k % 2 == 0) + // { + // ck::f4_t f4 = (f4x2 >> 4) & 0xf; + // printf("%.2f ", ck::type_convert(f4)); + // } + // else + // { + // ck::f4_t f4 = (f4x2 >> 0) & 0xf; + // printf("%.2f ", ck::type_convert(f4)); + // } + // } + // printf("\n"); + // } + // printf("\n"); + // } + for(int t = 0; t < tokens; ++t) { for(int tk = 0; tk < topk; ++tk) { - for(int k = 0; k < K; ++k) + for(int k = 0; k < K;) { - auto f4x2 = a0_t_k_k(t, tk, k).data; - if(k % 2 == 0) - { - ck::f4_t f4 = (f4x2 >> 4) & 0xf; - printf("%.2f ", ck::type_convert(f4)); - } - else - { - ck::f4_t f4 = (f4x2 >> 0) & 0xf; - printf("%.2f ", ck::type_convert(f4)); - } + printf("0x%08x ", + *(reinterpret_cast(&(a0_t_k_k(t, tk, k).data)))); // 4 bytes + k += 8; } printf("\n"); } @@ -464,23 +493,37 @@ int main(int argc, char* argv[]) } printf("b0_e_n_k:\n"); + // for(int e = 0; e < experts; ++e) + // { + // for(int n = 0; n < N; ++n) + // { + // for(int k = 0; k < K; ++k) + // { + // auto f4x2 = b0_e_n_k(e, k, n).data; + // if(k % 2 == 0) + // { + // ck::f4_t f4 = f4x2 >> 4 & 0xf; + // printf("%.2f ", ck::type_convert(f4)); + // } + // else + // { + // ck::f4_t f4 = f4x2 >> 0 & 0xf; + // printf("%.2f ", ck::type_convert(f4)); + // } + // } + // printf("\n"); + // } + // printf("\n"); + // } for(int e = 0; e < experts; ++e) { for(int n = 0; n < N; ++n) { - for(int k = 0; k < K; ++k) + for(int k = 0; k < K;) { - auto f4x2 = b0_e_n_k(e, k, n).data; - if(k % 2 == 0) - { - ck::f4_t f4 = f4x2 >> 4 & 0xf; - printf("%.2f ", ck::type_convert(f4)); - } - else - { - ck::f4_t f4 = f4x2 >> 0 & 0xf; - printf("%.2f ", ck::type_convert(f4)); - } + printf("0x%08x ", + *(reinterpret_cast(&(b0_e_n_k(e, k, n).data)))); // 4 bytes + k += 8; } printf("\n"); } @@ -509,6 +552,7 @@ int main(int argc, char* argv[]) printf("%.2f ", ck::type_convert(d2_e_n(i, n))); } } + printf("\n"); #endif // do GEMM @@ -625,7 +669,7 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); -#if 0 +#if 1 printf("e_t_n_device_result:\n"); for(int t = 0; t < tokens; ++t) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp index 86dd9600e4..89b47ccdeb 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp @@ -472,6 +472,25 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + *(reinterpret_cast(&(a_block_bufs(I0)[0].data))), + *(reinterpret_cast(&(a_block_bufs(I0)[16].data))), + *(reinterpret_cast(&(a_block_bufs(I0)[32].data))), + *(reinterpret_cast(&(a_block_bufs(I0)[48].data))), + *(reinterpret_cast(&(a_block_bufs(I0)[64].data))), + *(reinterpret_cast(&(a_block_bufs(I0)[80].data))), + *(reinterpret_cast(&(a_block_bufs(I0)[96].data))), + *(reinterpret_cast(&(a_block_bufs(I0)[112].data))), + *(reinterpret_cast(&(a_block_bufs(I0)[1024 + 0].data))), + *(reinterpret_cast(&(a_block_bufs(I0)[1024 + 112].data)))); + +#endif + static_for<0, KRepeat, 1>{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); @@ -1080,11 +1099,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx(), c_thread_buf.GetVectorTypeReference(Number{})); -#if 0 +#if 1 printf( "blkIdx: %u, blkIdy: %u, tidx: %u, imxdl: %d, inxdl: " - "%d, ikxdl: %d, a_thread_vec=<%.2f, %.2f, %.2f, %.2f>, " - "b_thread_vec=<%.2f, %.2f, %.2f, %.2f>, a_scale=%08x, " + "%d, ikxdl: %d, a_thread_vec=<%08x, %08x, %08x, %08x>, " + "b_thread_vec=<%08x, %08x, %08x, %08x>, a_scale=%08x, " "b_scale=%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n", blockIdx.x, blockIdx.y, @@ -1092,38 +1111,22 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx( - a_thread_vec - .template AsType()[Number<0>{}] - .unpack(Number<0>{})), - type_convert( - a_thread_vec - .template AsType()[Number<0>{}] - .unpack(Number<1>{})), - type_convert( - a_thread_vec - .template AsType()[Number<1>{}] - .unpack(Number<0>{})), - type_convert( - a_thread_vec - .template AsType()[Number<1>{}] - .unpack(Number<1>{})), - type_convert( - b_thread_vec - .template AsType()[Number<0>{}] - .unpack(Number<0>{})), - type_convert( - b_thread_vec - .template AsType()[Number<0>{}] - .unpack(Number<1>{})), - type_convert( - b_thread_vec - .template AsType()[Number<1>{}] - .unpack(Number<0>{})), - type_convert( - b_thread_vec - .template AsType()[Number<1>{}] - .unpack(Number<1>{})), + *(reinterpret_cast(&( + a_thread_vec.template AsType()[Number<0>{}]))), + *(reinterpret_cast(&( + a_thread_vec.template AsType()[Number<1>{}]))), + *(reinterpret_cast(&( + a_thread_vec.template AsType()[Number<2>{}]))), + *(reinterpret_cast(&( + a_thread_vec.template AsType()[Number<3>{}]))), + *(reinterpret_cast(&( + b_thread_vec.template AsType()[Number<0>{}]))), + *(reinterpret_cast(&( + b_thread_vec.template AsType()[Number<1>{}]))), + *(reinterpret_cast(&( + b_thread_vec.template AsType()[Number<2>{}]))), + *(reinterpret_cast(&( + b_thread_vec.template AsType()[Number<3>{}]))), *(reinterpret_cast(&( a_scale_thread_vec .template AsType()[Number<0>{}]))), diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp index 24a95a27d9..9ba1849661 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -68,15 +68,20 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad static constexpr auto block_slice_lengths = BlockSliceLengths{}; static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; - static constexpr auto wave_thread_cluster_lengths = Sequence{}; - static constexpr auto wave_cluster_lengths = Sequence<1, ThreadGroup::GetNumOfThread()/64, 1>{}; + static constexpr auto wave_thread_cluster_lengths = + Sequence{}; + static constexpr auto wave_cluster_lengths = + Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{}; static constexpr auto thread_single_load_size = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); // After a load, each thread moves by `thread_steps` instead of loading the next elements. // It makes the whole wavefront load contiguous memory, what is required for direct loads. - static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; - static constexpr auto wave_single_load_size= wave_thread_cluster_lengths*thread_single_load_size; + static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto wave_single_load_size = + wave_thread_cluster_lengths * thread_single_load_size; static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps; static __device__ constexpr bool AreThreadClusterLengthsValid() @@ -171,17 +176,17 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); - - const auto wave_cluster_idx = - wave_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()/64)); + + const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId() / 64)); const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size; - const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size; + const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size; SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin); // We don't need threadwise offset for lds since it was calculate by HW // We still need input the wavewise offset. - SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin); + SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -240,6 +245,22 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad src_buf.template DirectCopyToLds, ScalarPerVector>( dst_buf, src_offset, dst_offset, is_src_valid); +#if 0 + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + printf("blkx: %u, blky: %u, tid: %u, src: %d, b_dst_offset: " + "%d, b_dst_buffer=<%02x, %02x, %02x, %02x>\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + src_offset, + dst_offset, + static_cast(dst_buf[dst_offset].data), + static_cast(dst_buf[dst_offset + 16].data), + static_cast(dst_buf[dst_offset + 32].data), + static_cast(dst_buf[dst_offset + 48].data)); +#endif + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; @@ -292,6 +313,23 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad }); }); +#if 0 + block_sync_lds(); + if(threadIdx.x == 0) + { + // Print the contents of the destination buffer. + printf("blkx: %u, blky: %u, tid: %u, B_dst_buffer=<%02x, %02x, %02x, %02x>\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + static_cast(dst_buf[Number<0>{}].data), + static_cast(dst_buf[Number<16>{}].data), + static_cast(dst_buf[Number<32>{}].data), + static_cast(dst_buf[Number<48>{}].data)); + } + +#endif + // Reset the destination slice since the entire buffer has been already filled. ResetDstSliceWindow(dst_desc); } diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp index 816c628135..8ad8e3b94d 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp @@ -66,15 +66,27 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; static constexpr auto block_slice_lengths = BlockSliceLengths{}; static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; + static constexpr auto wave_thread_cluster_lengths = + Sequence{}; + static constexpr auto wave_cluster_lengths = + Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{}; static constexpr auto thread_single_load_size = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); + + // CK_PRINT(); + // After a load, each thread moves by `thread_steps` instead of loading the next elements. // It makes the whole wavefront load contiguous memory, what is required for direct loads. - static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto wave_single_load_size = + wave_thread_cluster_lengths * thread_single_load_size; static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps; static constexpr index_t gather_num = thread_slice_lengths.At(Number{}); @@ -172,10 +184,16 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); + const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId() / 64)); + const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size; + const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size; SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin); - SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin); + // We don't need threadwise offset for lds since it was calculate by HW + // We still need input the wavewise offset. + SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -188,6 +206,9 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad return idx; }(); + // CK_PRINT(); + // CK_PRINT(); + src_coord_ = make_tensor_coordinate(src_desc, adjusted_src_origin_idx); src_slice_origin_ = adjusted_src_origin_idx; } @@ -230,20 +251,45 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad // Loop over the destination block and copy data. static_ford{}([&](auto ordered_dst_access_idx) { - // CK_PRINT(); - auto gather_offset = gather_offsets_(Number{}); - const auto src_offset = src_coord_.GetOffset() + gather_offset; - const auto dst_offset = dst_coord_.GetOffset(); - // printf("Tid: %03d, src_offset: %d, dst_offset: %d\n", get_thread_local_1d_id(), - // src_coord_.GetOffset(), dst_coord_.GetOffset()); + IndexType gather_offset = gather_offsets_[ordered_dst_access_idx[Number{}]]; + const IndexType src_offset = src_coord_.GetOffset() + gather_offset; + const IndexType dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset()); + // Check if src data is not in the logic padding area. - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + // Leave the HW for oob checking + // const bool is_src_valid = + // coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, + // src_coord_); src_buf.template DirectCopyToLds, ScalarPerVector>( - dst_buf, src_offset, dst_offset, is_src_valid); + dst_buf, src_offset, dst_offset, true); - constexpr auto move_on_dim = [&]() constexpr +#if 1 + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + printf("blkx: %u, blky: %u, tid: %u, red_id: %d src: %d (cal: %d, gather: %d), " + "dst_offset: " + "%d, a_dst_buffer=<0x%08x, 0x%08x, 0x%08x, 0x%08x>\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + static_cast(ordered_dst_access_idx[Number{}]), + src_offset, + src_coord_.GetOffset(), + gather_offset, + dst_offset, + // *(reinterpret_cast(&(dst_buf[dst_offset + 0].data))), + *(reinterpret_cast( + &(dst_buf[dst_offset + 16 * threadIdx.x].data))), + *(reinterpret_cast( + &(dst_buf[dst_offset + 16 * threadIdx.x].data))), + *(reinterpret_cast( + &(dst_buf[dst_offset + 32 * threadIdx.x].data))), + *(reinterpret_cast( + &(dst_buf[dst_offset + 48 * threadIdx.x].data)))); +#endif + + constexpr auto move_src_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; @@ -260,6 +306,22 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad } (); + constexpr auto move_dst_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + // Decide whether to move forward or backward. constexpr auto forward_sweep = [&]() { StaticallyIndexedArray forward_sweep_; @@ -280,22 +342,58 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad }(); static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) + // Move the source coordinate. + if constexpr(move_src_on_dim[i]) { if constexpr(forward_sweep[i]) { - move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]); move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]); } else { - move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]); move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]); } } + + // Move the destination coordinate. + if constexpr(move_dst_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]); + } + else + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]); + } + } }); }); +#if 0 + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + + if(threadIdx.x == 0) + { + // Print the contents of the destination buffer. + printf("blkx: %u, blky: %u, tid: %u, a_dst_buf_offset=<%d, %d, %d, %d>, " + "a_dst_buffer=<%02x, %02x, %02x, %02x>\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + 0, + 16, + 32, + 48, + static_cast(dst_buf[Number<0>{}].data), + static_cast(dst_buf[Number<16>{}].data), + static_cast(dst_buf[Number<32>{}].data), + static_cast(dst_buf[Number<48>{}].data)); + } + +#endif + // Reset the destination slice since the entire buffer has been already filled. ResetDstSliceWindow(dst_desc); } @@ -325,6 +423,8 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad private: static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + static constexpr auto wave_cluster_desc_ = + make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{}); SrcCoord src_coord_; DstCoord dst_coord_; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index 4b2579b6b3..46b71737f0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -1387,6 +1387,7 @@ struct GridwiseMoeGemmMXBNS BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + ignore = a_element_op; ignore = b_element_op; const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, @@ -1657,35 +1658,22 @@ struct GridwiseMoeGemmMXBNS p_b_grid_up + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto b_blockwise_copy_up = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; const auto b_scale_grid_buf_up = make_dynamic_buffer( @@ -2167,6 +2155,7 @@ struct GridwiseMoeGemmMXBNS BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + ignore = a_element_op; ignore = b_element_op; const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, @@ -2253,7 +2242,7 @@ struct GridwiseMoeGemmMXBNS gather_offsets(m0) = static_cast(token_offset) * problem.K; }); -#if 0 +#if 1 printf("blkx: %u, blky: %u, tidx: %u, token_pos: %d, gather_offsets:<%d, %d, %d, %d>\n", blockIdx.x, blockIdx.y, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 3d97fdae20..9d38d3e1e9 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1504,6 +1504,23 @@ struct ThreadwiseTensorSliceTransfer_v4 dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; }); +#if 1 + printf("blky: %u, tid: %u, src_offset: %d, repeat_id: %d, dst_tmp_vec=<0x%08x, " + "0x%08x, 0x%08x, " + "0x%08x\n", + blockIdx.y, + threadIdx.x, + static_cast(ordered_access_idx[Number<1>{}]), + src_data_coord.GetOffset(), + *(reinterpret_cast( + &(dst_tmp_vector.template AsType()[Number<0>{}]))), + *(reinterpret_cast( + &(dst_tmp_vector.template AsType()[Number<1>{}]))), + *(reinterpret_cast( + &(dst_tmp_vector.template AsType()[Number<2>{}]))), + *(reinterpret_cast( + &(dst_tmp_vector.template AsType()[Number<3>{}])))); +#endif } }); } diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 1d80f196b5..2aae9ee3ac 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -202,6 +202,21 @@ struct DynamicBuffer static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, "Destination data must be stored in an LDS memory buffer."); +#if 0 + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) + // { + // printf("DirectCopyToLds: src_offset=%d, dst_offset=%d\n", src_offset, dst_offset); + // } + printf("blkx: %u, blky: %u, tid: %u, src_offset: %d, dst_offset: %d, sizeof(src_offset): " + "%lu, sizeof(dst_offset): %lu\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + src_offset, + dst_offset, + sizeof(src_offset), + sizeof(dst_offset)); +#endif amd_direct_load_global_to_lds(p_data_, src_offset, dst_buf.p_data_,