This commit is contained in:
Feng Shijie
2025-08-11 11:24:34 +00:00
parent 200a11afc8
commit edb58d0680
7 changed files with 112 additions and 50 deletions

View File

@@ -157,12 +157,12 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
const auto scale_b_flat_view =
make_naive_tensor_view<address_space_enum::global>(scale_n.ptr,
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const e8m0_t*>(scale_n.ptr),
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
return make_tuple(
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
@@ -297,7 +297,11 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_block_window);
}
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
@@ -326,7 +330,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& scale_block_window = gemm_tile_windows.at(I3);
const auto& scale_block_window = gemm_tile_windows.at(I4);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
b_flat_block_window,
scale_block_window,