add mixed_prec fp16xfp4

This commit is contained in:
Feng Shijie
2025-08-08 20:19:16 +00:00
parent 3dea10a277
commit f788d3d629
9 changed files with 252 additions and 123 deletions

View File

@@ -39,6 +39,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
// The type below is actually the accumulation data type, i.e. the output of the block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr int QuantPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr index_t NumDTensor = DsDataType::size();
static constexpr auto I0 = number<0>();
@@ -89,16 +91,15 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
}
}();
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
return make_naive_tensor_view<address_space_enum::global>(b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<32>{},
number<1>{});
}();
const auto& ds_tensor_view = generate_tuple(
@@ -307,7 +308,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
if constexpr(false && (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0))
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
@@ -346,8 +348,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset / QuantPackedSize;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS