mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE MOE] add NT & preshuffle permute to cktile MOE (#3377)
* update coherence --------- Co-authored-by: Zzz9990 <Zzz9990>
This commit is contained in:
@@ -595,16 +595,44 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
// Builds the tensor view over the B operand (`b_flat_ptr`) in the global address space.
// The view is produced by an immediately-invoked lambda; its layout depends on the
// compile-time flag FlatmmPipeline::BPreShufflePermute (see the `if constexpr` below).
const auto& b_flat_tensor_view = [&]() {
|
||||
// NOTE(review): this unconditional `return` (through `number<1>{});`) looks like the
// removed side of the diff hunk rendered inline without +/- markers. Taken literally it
// makes everything after it unreachable, and it relies on the outer kFlatN / kFlatK
// declared just above the lambda. Confirm against the real file before relying on it.
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
// Branch 1: B is not pre-shuffle-permuted -> plain 2-D row-major view.
if constexpr(!FlatmmPipeline::BPreShufflePermute)
|
||||
{
|
||||
// Row length of the flattened view: K scaled by WarpTile::at(I1).
index_t kFlatK =
|
||||
kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
|
||||
// Row count such that kFlatN * kFlatK == N * K (when the division is exact).
// NOTE(review): N * K is evaluated in index_t before the division; if index_t is
// 32-bit this intermediate product could overflow for large N, K — TODO confirm.
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
// Same view as before, but now also passing memory_operation_enum::set and
// FlatmmPipeline::BMemNTType as template arguments.
return make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
FlatmmPipeline::BMemNTType>(
|
||||
b_flat_ptr,
|
||||
// lengths: (rows minus the padded-zero rows expressed in NPerXdl units, kFlatK)
make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK),
|
||||
// strides: rows are kFlatK elements apart, elements within a row are contiguous
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
// Branch 2: B is pre-shuffle-permuted -> 3-D view (K0, N0, kFlatK) re-exposed as 2-D.
else
|
||||
{
|
||||
index_t kFlatK = FlatmmPipeline::flatKPerWarp;
|
||||
// N >> 4 == N / 16 and K >> 7 == K / 128 (for non-negative values).
// NOTE(review): the 16 and 128 are hard-coded here; presumably they must agree with
// the pre-shuffle tile sizes / flatKPerWarp — TODO confirm and derive from the pipeline.
index_t kFlatN0 = (kargs.N >> 4);
|
||||
index_t kFlatK0 = (kargs.K >> 7);
|
||||
|
||||
// Underlying 3-D view: dim0 = kFlatK0, dim1 = kFlatN0 minus padded rows, dim2 = kFlatK.
auto b_tensor_view_naive = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
FlatmmPipeline::BMemNTType>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatK0, kFlatN0 - kargs.n_padded_zeros / NPerXdl, kFlatK),
|
||||
// strides: dim0 steps over a whole (N0' x kFlatK) slab, dim1 over kFlatK, dim2 is unit
make_tuple(kFlatK * (kFlatN0 - kargs.n_padded_zeros / NPerXdl), kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
// Re-express the 3-D view as 2-D: source dim 1 is passed through as output dim 0,
// and source dims (0, 2) are merged into output dim 1 (per the sequence<> mappings).
return transform_tensor_view(
|
||||
b_tensor_view_naive,
|
||||
make_tuple(
|
||||
make_pass_through_transform(kFlatN0 - kargs.n_padded_zeros / NPerXdl),
|
||||
make_merge_transform_v3_division_mod(make_tuple(kFlatK0, kFlatK))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
|
||||
Reference in New Issue
Block a user