[CK_TILE] Add Flatmm MX FP8 (#3208)

* Use async for flatmm mxfp4

* Fix preshuffle

* Add flatmm mxfp8

* Thanks, Copilot

* Thanks Copilot again~
This commit is contained in:
Yi DING
2025-11-20 10:35:15 +08:00
committed by GitHub
parent 4e49e0228b
commit 47e2ed838e
17 changed files with 698 additions and 595 deletions

View File

@@ -143,16 +143,24 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
}
}();
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
// Build the global-memory tensor view over the flattened B operand (b_flat_ptr).
// Compile-time tile extents taken from the pipeline and the warp tile shape.
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
// Element count of one flattened chunk: kKPerBlock * kNWarpTile.
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
// Runtime extents of the flattened layout.
// NOTE(review): both divisions truncate; presumably kargs.K and kargs.N are
// required to be multiples of kKPerBlock and kNWarpTile respectively — confirm
// against the kernel's argument validation.
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
const index_t kFlatN = kargs.N / kNWarpTile;
const auto& b_flat_tensor_view = [&]() {
// The compile-time innermost length must be a whole number of B vector accesses.
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
"wrong! vector size for B tensor");
// 3-D packed descriptor: (kFlatN, kFlatKBlocks, flatKPerBlock), with the
// innermost length known at compile time via number<>.
auto&& naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
// Re-expose the descriptor as 2-D: dim 0 passes through unchanged, while
// dims 1 and 2 are merged into a single dim 1 of length
// kFlatKBlocks * flatKPerBlock.
auto&& desc = transform_tensor_descriptor(
naive_desc,
make_tuple(make_pass_through_transform(kFlatN),
make_merge_transform_v3_division_mod(
make_tuple(kFlatKBlocks, number<flatKPerBlock>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// Bind the descriptor to the B pointer in the global address space.
return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
}();
const auto& ds_tensor_view = generate_tuple(