mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
update scale for mxfp4
This commit is contained in:
@@ -40,6 +40,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
static constexpr int QuantPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
static constexpr int N_Pack = 2;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -47,6 +48,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
static constexpr auto I4 = number<4>();
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
@@ -149,7 +151,21 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
|
||||
auto scale_n = kargs.scale_n_ptr;
|
||||
|
||||
index_t FlatScaleK =
|
||||
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
const auto scale_b_flat_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(scale_n.ptr,
|
||||
make_tuple(FlatScaleN, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
return make_tuple(
|
||||
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -215,7 +231,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
@@ -275,6 +291,12 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
auto scale_block_window =
|
||||
make_tile_window(views.at(I4),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
|
||||
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
|
||||
}
|
||||
|
||||
@@ -304,8 +326,13 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
|
||||
const auto& scale_block_window = gemm_tile_windows.at(I3);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
|
||||
b_flat_block_window,
|
||||
scale_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(false && (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
|
||||
|
||||
Reference in New Issue
Block a user