[CK_TILE MOE] add NT & preshuffle permute to cktile MOE (#3377)

* update coherence
---------

Co-authored-by: Zzz9990 <Zzz9990>
This commit is contained in:
Zzz9990
2025-12-10 10:03:28 +08:00
committed by GitHub
parent 934ba1208a
commit 1aa93ef551
8 changed files with 88 additions and 29 deletions

View File

@@ -595,16 +595,44 @@ struct MoeFlatmmKernel
}
}();
// NOTE(review): this span appears to be a rendered diff hunk in which removed and added lines
// are shown together without +/- markers. The two declarations directly below and the first
// unconditional `return` inside the lambda look like the pre-change lines -- confirm against
// the real file before relying on the literal statement order shown here.
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
index_t kFlatN = kargs.N * kargs.K / kFlatK;
// Builds the global-memory tensor view over b_flat_ptr via an immediately-invoked lambda.
// The layout of the view is selected at compile time by FlatmmPipeline::BPreShufflePermute.
const auto& b_flat_tensor_view = [&]() {
// Read literally, this unconditional return precedes the `if constexpr` below, so nothing after
// it in the lambda would be reached (see the review note at the top of this span).
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
if constexpr(!FlatmmPipeline::BPreShufflePermute)
{
// Non-permuted path: a 2-D view with lengths (kFlatN - n_padded_zeros / NPerXdl, kFlatK)
// and strides (kFlatK, 1). These locals shadow the same-named variables declared before
// the lambda.
index_t kFlatK =
kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
index_t kFlatN = kargs.N * kargs.K / kFlatK;
// Compared with the plain call above, this one additionally passes
// memory_operation_enum::set and FlatmmPipeline::BMemNTType as template arguments.
return make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
FlatmmPipeline::BMemNTType>(
b_flat_ptr,
make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
{
// Pre-shuffle-permute path: the innermost extent comes from FlatmmPipeline::flatKPerWarp,
// and N and K are scaled down by hard-coded shifts (>> 4 and >> 7).
// NOTE(review): presumably the shifts (divide by 16 and by 128) correspond to tile
// dimensions of the pre-shuffled layout -- TODO confirm and consider deriving them from
// named constants.
index_t kFlatK = FlatmmPipeline::flatKPerWarp;
index_t kFlatN0 = (kargs.N >> 4);
index_t kFlatK0 = (kargs.K >> 7);
// 3-D naive view with lengths (kFlatK0, kFlatN0 - n_padded_zeros / NPerXdl, kFlatK) and
// strides (kFlatK * (kFlatN0 - n_padded_zeros / NPerXdl), kFlatK, 1).
auto b_tensor_view_naive = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
FlatmmPipeline::BMemNTType>(
b_flat_ptr,
make_tuple(kFlatK0, kFlatN0 - kargs.n_padded_zeros / NPerXdl, kFlatK),
make_tuple(kFlatK * (kFlatN0 - kargs.n_padded_zeros / NPerXdl), kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
// Re-expose the 3-D view as 2-D: source dim 1 is passed through as result dim 0, and
// source dims (0, 2) are merged into result dim 1, per the dimension sequences below.
return transform_tensor_view(
b_tensor_view_naive,
make_tuple(
make_pass_through_transform(kFlatN0 - kargs.n_padded_zeros / NPerXdl),
make_merge_transform_v3_division_mod(make_tuple(kFlatK0, kFlatK))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}();
// TODO: enable vector write for C in ColMajor