mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Add v3 support. Add compile_option
This commit is contained in:
@@ -16,3 +16,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
set(EXAMPLE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
list(APPEND EXAMPLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm -misched=gcn-iterative-max-occupancy-experimental")
|
||||
target_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
@@ -158,11 +158,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
MPerBlock, 128, 128,
|
||||
16, 16,
|
||||
32, 32,
|
||||
2, 2,
|
||||
4, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 1, S<1, 8, 1, 32>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
|
||||
1, 1, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, false, false, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -746,25 +746,21 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
@@ -786,25 +782,21 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
@@ -970,25 +962,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
|
||||
@@ -1464,43 +1464,6 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
// // print C
|
||||
// printf("tid: %d, blkid: %d, "
|
||||
// "c_thread_buf = <%1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f,
|
||||
// %1.f, %1.f, %1.f, %1.f, %1.f, %1.f\n", get_thread_local_1d_id(), block_m_id,
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<0>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<1>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<2>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<3>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<4>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<5>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<6>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<7>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<8>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<9>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<10>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<11>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<12>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<13>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<14>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<15>{}]);
|
||||
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
@@ -1811,6 +1774,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
const BDataType* p_b_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
CDataType* p_c_grid,
|
||||
const AScaleType* p_a_scale_grid,
|
||||
const BScaleType* p_b_scale_grid,
|
||||
void* p_shared,
|
||||
void* p_shared1,
|
||||
const Problem& problem,
|
||||
@@ -1834,6 +1799,17 @@ struct GridwiseMoeGemmBlockScale
|
||||
problem.N,
|
||||
problem.NPadded,
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
|
||||
: problem.NumTokens * problem.TopK,
|
||||
ScaleBlockM),
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
@@ -1891,7 +1867,9 @@ struct GridwiseMoeGemmBlockScale
|
||||
gather_offsets(m0) = token_offset * problem.K;
|
||||
});
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
|
||||
|
||||
const index_t expert_scale_stride =
|
||||
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N, ScaleBlockN) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK));
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
@@ -1902,6 +1880,12 @@ struct GridwiseMoeGemmBlockScale
|
||||
p_b_grid + expert_id * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + expert_id * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
|
||||
@@ -1932,7 +1916,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
1,
|
||||
2>(a_grid_desc_ak0_m_ak1,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
@@ -1984,19 +1968,111 @@ struct GridwiseMoeGemmBlockScale
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_bufs,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
//scale
|
||||
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
|
||||
constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
|
||||
constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
|
||||
|
||||
// ScaleSliceSizeK is last dimension in A/B scale for vector memory access
|
||||
// ScaleSliceSizeK is first dimension in C scale for packed math
|
||||
constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
|
||||
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
auto a_thread_offset =
|
||||
get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl;
|
||||
|
||||
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
|
||||
|
||||
constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<ScaleSliceSizeK>{}, Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeN>{}));
|
||||
|
||||
// get each thread's offset in the scale tensor
|
||||
// A scale
|
||||
const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
|
||||
|
||||
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
|
||||
const index_t fused_token =
|
||||
p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
if constexpr(!IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
scale_gather_offsets(m0) =
|
||||
token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
|
||||
});
|
||||
|
||||
// printf("blkid: %d, tid:%d, a_thread_offset: %d, scale_gather_offsets: %d\n", block_m_id,
|
||||
// threadIdx.x, a_thread_offset,
|
||||
// scale_gather_offsets(Number<0>{}));
|
||||
|
||||
auto a_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2_gather<AScaleType,
|
||||
AScaleType,
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(a_scale_thread_desc),
|
||||
Sequence<1, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false,
|
||||
MXdlPerWave>(
|
||||
a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
|
||||
|
||||
auto b_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleType,
|
||||
BScaleType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(b_scale_thread_desc),
|
||||
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false>(
|
||||
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
|
||||
|
||||
// constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
|
||||
constexpr auto a_scale_thread_slice_copy_step =
|
||||
make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
|
||||
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
|
||||
|
||||
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_bufs,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
c_scale_thread_desc,
|
||||
c_thread_buf,
|
||||
a_scale_grid_desc_am_ak,
|
||||
a_scale_thread_desc,
|
||||
a_scale_thread_copy,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_slice_copy_step,
|
||||
b_scale_grid_desc_bn_ak,
|
||||
b_scale_thread_desc,
|
||||
b_scale_thread_copy,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_slice_copy_step,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
@@ -2006,23 +2082,24 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
// transposed XDL
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
//only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
|
||||
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
|
||||
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
|
||||
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
@@ -2031,24 +2108,24 @@ struct GridwiseMoeGemmBlockScale
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
M3,
|
||||
M4)),
|
||||
M2)), // M2 = MPerXdl
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
|
||||
N1, // N1 = NWave
|
||||
N2))), // N2 = NPerXdl
|
||||
N2, // N2 * N3 * N4 = NPerXdl
|
||||
N3,
|
||||
N4))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(
|
||||
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
|
||||
Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -2058,56 +2135,56 @@ struct GridwiseMoeGemmBlockScale
|
||||
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
|
||||
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
N2,
|
||||
I1,
|
||||
N4>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
n_thread_data_on_block_idx[I2],
|
||||
n_thread_data_on_block_idx[I3],
|
||||
n_thread_data_on_block_idx[I4]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
using EDataType = CDataType;
|
||||
@@ -2162,7 +2239,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
using CDEBlockTransferCluster =
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
|
||||
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix
|
||||
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
|
||||
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
@@ -2202,16 +2279,16 @@ struct GridwiseMoeGemmBlockScale
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, 1, N2, 1, N4>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
1,
|
||||
1,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>>{};
|
||||
N2,
|
||||
1,
|
||||
N4>>{};
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
@@ -2230,7 +2307,6 @@ struct GridwiseMoeGemmBlockScale
|
||||
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
|
||||
constexpr auto ENThreads =
|
||||
CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<index_t, EMRepeats>
|
||||
@@ -2243,16 +2319,14 @@ struct GridwiseMoeGemmBlockScale
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
float weight = token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
|
||||
: 0.0;
|
||||
float weight = token_offset < problem.NumTokens ? 1 : 0.0;
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
else
|
||||
{
|
||||
const float* p_sorted_weights_2 = p_ds_grid[I2];
|
||||
const float* p_sorted_weights_2 = p_ds_grid[I0];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
@@ -2262,10 +2336,10 @@ struct GridwiseMoeGemmBlockScale
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
|
||||
Reference in New Issue
Block a user