mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[CK_TILE] Add Flatmm MX FP8 (#3208)
* Use async for flatmm mxfp4 * Fix preshuffle * Add flatmm mxfp8 * Thanks, Copilot * Thanks Copilot again~
This commit is contained in:
@@ -143,16 +143,24 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
|
||||
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
|
||||
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
|
||||
const index_t kFlatN = kargs.N / kNWarpTile;
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
|
||||
"wrong! vector size for B tensor");
|
||||
auto&& naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
|
||||
auto&& desc = transform_tensor_descriptor(
|
||||
naive_desc,
|
||||
make_tuple(make_pass_through_transform(kFlatN),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(kFlatKBlocks, number<flatKPerBlock>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
|
||||
Reference in New Issue
Block a user