From 94fb9190bebbdbdedc356710d01c6cfbe8b8585a Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Fri, 16 May 2025 14:46:09 -0500 Subject: [PATCH] init moe mx f4 scale shuffle --- .../moe_gemm2_xdl_mx_fp4.cpp | 112 +++++++++++---- ...pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp | 13 +- .../gpu/grid/gridwise_moe_mx_gemm.hpp | 135 +++++++++--------- 3 files changed, 165 insertions(+), 95 deletions(-) diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp index 8fddaaa04e..7d0cd6542f 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -82,6 +82,7 @@ struct MulABScaleExpertWeight using CDEElementOp = MulABScaleExpertWeight; +// B preshuffle void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) { int KPack = 32; @@ -113,6 +114,54 @@ void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) } } +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough; @@ -286,20 +335,27 @@ int main(int argc, char* argv[]) Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + // B preshuffle Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {N * Scale_Stride_BN, 1, Scale_Stride_BN})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + e_t_n_device_result.SetZero(); std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; - std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; - std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; switch(init_method) @@ -310,8 +366,6 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_2{0, 1}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_2{0, 1}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove d2_e_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); break; case 2: @@ -319,8 +373,6 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 3: @@ -328,8 +380,6 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; case 4: @@ -337,8 +387,6 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; case 5: @@ -346,8 +394,6 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 6: @@ -355,8 +401,6 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; default: @@ -364,8 +408,6 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * @@ -373,22 +415,42 @@ int main(int argc, char* argv[]) DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize() / 2); - DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a_scale_sorted.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); - DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); - DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int tokenid = sorted_token_ids.mData[i] & 0x00FFFFFF; + int topkid = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(tokenid = = tokens) + { + a_scale_sorted(i, k) = 0; + } + else + { + a_scale_sorted(i, k) = a1_t_k_k(tokenid, topkid, k); + } + } + } + + preShuffleBuffer>( + a_scale_sorted.mData.data(), a_scale_preshuffled.mData.data(), sorted_size, K); + preShuffleBuffer>( + b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); a0_device_buf.ToDevice(a0_t_k_k.mData.data()); - a1_device_buf.ToDevice(a1_t_k_k.mData.data()); - b1_device_buf.ToDevice(b1_e_n_k.mData.data()); - d0_device_buf.ToDevice(d0_t_n.mData.data()); - d1_device_buf.ToDevice(d1_e_n.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); d2_device_buf.ToDevice(d2_e_n.mData.data()); e_device_buf.ToDevice(e_t_n_device_result.mData.data()); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp index d9d1928b87..bc63f57b93 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp @@ -187,12 +187,23 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3 How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() - static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() static constexpr auto ScalesPerXdlopsRunPerThread = ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index 8f3748d15f..0a8b38c2c6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -190,6 +190,10 @@ struct GridwiseMoeGemmMX static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + static constexpr bool is_single_rate_mfma = false; static constexpr auto is_scale_mfma = true; using mfma_selector = MfmaSelector; - static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk / APackedSize); static constexpr index_t KLane = mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); @@ -209,10 +213,6 @@ struct GridwiseMoeGemmMX static constexpr index_t NLane = NPerXdl; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; static constexpr index_t MWave = MPerBlock / MPerXdl / MXdlPerWave; - static constexpr auto ScalesPerXdlopsRun = - (KPack * mfma_selector::selected_mfma.num_input_blks) / ScaleBlockSize; - static constexpr auto ScalesPerXdlopsRunPerThread = - ScalesPerXdlopsRun / mfma_selector::selected_mfma.num_input_blks; // static constexpr index_t NumTokens = 1; static constexpr index_t SortedTileSize = MPerBlock; @@ -712,10 +712,10 @@ struct GridwiseMoeGemmMX TopK_, M_, N_, - K_, - StrideA_, + K_ / APackedSize, + StrideA_ / APackedSize, StrideScaleA_, - StrideB_, + StrideB_ / APackedSize, StrideScaleB_, StrideDs_, StrideC_, @@ -784,21 +784,23 @@ struct GridwiseMoeGemmMX // Calculate A scale offset if constexpr(is_same_v) { - a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize; + a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize); } else if constexpr(is_same_v) { - a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize * karg.StrideScaleA; + a_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / PackedSize) * karg.StrideScaleA; } // Calculate B scale offset if constexpr(is_same_v) { - b_scale_k_split_offset = k_id * (karg.KRead / ScaleBlockSize) * karg.StrideScaleB; + b_scale_k_split_offset = + k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB; } else if constexpr(is_same_v) { - b_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize; + b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize); } if(k_id < karg.KBatch - 1) @@ -1011,6 +1013,9 @@ struct GridwiseMoeGemmMX (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0, + "KPerBlock should be multiple of ScaleBlockSize"); + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || @@ -1211,6 +1216,14 @@ struct GridwiseMoeGemmMX // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, // NPerBlock>; + using mx_scale_t = e8m0_bexp_t; + static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + template @@ -1246,17 +1259,17 @@ struct GridwiseMoeGemmMX problem.NPadded, problem.StrideC); - const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( - make_tuple(IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, - math::integer_divide_ceil(problem.K, ScaleBlockSize) / - ScalesPerXdlopsRunPerThread, - ScalesPerXdlopsRunPerThread), - make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), - ScalesPerXdlopsRunPerThread, - 1)); - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)), - make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1)); + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -1331,6 +1344,7 @@ struct GridwiseMoeGemmMX p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + // A, B scale buffer const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); const auto b_scale_grid_buf = make_dynamic_buffer( @@ -1430,60 +1444,43 @@ struct GridwiseMoeGemmMX static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / - mfma.selected_mfma.num_threads_per_blk; + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; - auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl; + auto a_thread_offset_m = waveId_m; - // get each thread's offset int the scale tensor - const index_t token_scale_pos = block_m_id * MPerBlock; - if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) - return; - - StaticallyIndexedArray scale_gather_offsets; - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { - const index_t fused_token = - p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWave + a_thread_offset_m]; - index_t token_offset = fused_token & 0xffffff; - if constexpr(!IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } - scale_gather_offsets(m0) = - token_offset * math::integer_divide_ceil(problem.K, ScaleBlockSize); - }); - - auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2_gather< + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< AScaleDataType, AScaleDataType, decltype(a_scale_grid_desc_am_ak), decltype(BlockwiseGemmPipe::a_scale_thread_desc), - Sequence<1, 1, 1>, // SliceLengths - Sequence<0, 1, 2>, // DimAccessOrder - 2, // SrcVectorDim - 1, // SrcScalarPerVector - 1, // SrcScalarStrideInVector - true, - MXdlPerWave, - KRepeat>( - a_scale_grid_desc_am_ak, make_multi_index(0, thread_offset_k, 0), scale_gather_offsets); + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); // B scale load - auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl; + auto b_thread_offset_n = waveId_n; - auto b_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // SrcScalarPerVector - 1, - true>( - b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k)); + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); if constexpr(IsInputGemm) {