mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
update
This commit is contained in:
@@ -252,10 +252,11 @@ struct MoeFlatmmKernel
|
||||
#else
|
||||
false;
|
||||
#endif
|
||||
static constexpr int MXFP4M_Pack = MXF8F6F4MFMA ? 1 : 2;
|
||||
static constexpr int MXFP4N_Pack = MXF8F6F4MFMA ? 1 : 2;
|
||||
static constexpr int MXFP4K_Pack = MXF8F6F4MFMA ? 4 : 2;
|
||||
static constexpr int MXFP4M_Pack = 2;
|
||||
static constexpr int MXFP4N_Pack = 2;
|
||||
static constexpr int MXFP4K_Pack = 2;
|
||||
|
||||
static constexpr int M_Pack = AQUANT_Pipeline ? MXFP4M_Pack : 1;
|
||||
static constexpr int N_Pack = BMXFP4_Pipeline ? MXFP4N_Pack : 1;
|
||||
static constexpr int K_Pack = BMXFP4_Pipeline ? MXFP4K_Pack : 1;
|
||||
|
||||
@@ -647,12 +648,12 @@ struct MoeFlatmmKernel
|
||||
constexpr int AGranularityK = 32;
|
||||
|
||||
//TODO: enable e8m0_t scale
|
||||
using AScaleType = float; //std::conditional_t<MXF8F6F4MFMA, e8m0_t, float>;
|
||||
// using AScaleType = e8m0_t; //std::conditional_t<MXF8F6F4MFMA, e8m0_t, float>;
|
||||
// using AScaleType = float; //std::conditional_t<MXF8F6F4MFMA, e8m0_t, float>;
|
||||
using AScaleType = e8m0_t; //std::conditional_t<MXF8F6F4MFMA, e8m0_t, float>;
|
||||
|
||||
const auto& scale_a_tensor_view = [&]() {
|
||||
// if constexpr(std::is_same_v<AScaleType, float>)
|
||||
// {
|
||||
if constexpr(std::is_same_v<AScaleType, float>)
|
||||
{
|
||||
index_t scale_m = kargs.M;
|
||||
index_t scale_k = AGranularityK == 0 ? 1 : (kargs.K + AGranularityK - 1) / AGranularityK;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -661,57 +662,57 @@ struct MoeFlatmmKernel
|
||||
make_tuple(scale_k, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
// }
|
||||
// else if constexpr(std::is_same_v<ScaleType, e8m0_t>)
|
||||
// {
|
||||
// constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
|
||||
// constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
|
||||
// index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
|
||||
// index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
|
||||
// // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
|
||||
// const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
// make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl));
|
||||
// const auto scale_a_desc = transform_tensor_descriptor(
|
||||
// scale_a_naive_desc,
|
||||
// make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
|
||||
// make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
// make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
// return make_tensor_view<address_space_enum::global>(
|
||||
// reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
|
||||
// }
|
||||
}
|
||||
else if constexpr(std::is_same_v<ScaleType, e8m0_t>)
|
||||
{
|
||||
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
|
||||
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
|
||||
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
|
||||
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
|
||||
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl));
|
||||
const auto scale_a_desc = transform_tensor_descriptor(
|
||||
scale_a_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
|
||||
}
|
||||
}();
|
||||
|
||||
auto scale_n = kargs.scale_n;
|
||||
|
||||
constexpr int BGranularityK = decltype(scale_n)::GranularityK;
|
||||
const auto scale_b_flat_view = [&]() {
|
||||
// if constexpr(AQUANT_Pipeline)
|
||||
// {
|
||||
// constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(I1);
|
||||
// constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I1);
|
||||
// index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl);
|
||||
// index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl);
|
||||
// const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
|
||||
// make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
|
||||
// const auto scale_b_desc = transform_tensor_descriptor(
|
||||
// scale_b_navie_desc,
|
||||
// make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
|
||||
// make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
// make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
//
|
||||
// return make_tensor_view<address_space_enum::global>(
|
||||
// reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
|
||||
//
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
if constexpr(AQUANT_Pipeline)
|
||||
{
|
||||
constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I1);
|
||||
index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl);
|
||||
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl);
|
||||
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
|
||||
const auto scale_b_desc = transform_tensor_descriptor(
|
||||
scale_b_navie_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
|
||||
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t scale_k = BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
|
||||
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
using ScaleType = std::conditional_t<BMXFP4_Pipeline, e8m0_t, float>;
|
||||
using ScaleType = e8m0_t;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
|
||||
@@ -719,7 +720,7 @@ struct MoeFlatmmKernel
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
// }
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -819,21 +820,36 @@ struct MoeFlatmmKernel
|
||||
|
||||
constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
|
||||
auto a_scale_block_window =
|
||||
// make_tile_window(views.at(I3),
|
||||
// make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
// number<TilePartitioner::KPerBlock / GranularityK>{}),
|
||||
// {coord_m, 0});
|
||||
make_tile_window(views.at(I3),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GranularityK>{}),
|
||||
{coord_m, 0});
|
||||
make_tuple(number<TilePartitioner::MPerBlock / M_Pack>{},
|
||||
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
|
||||
{i_m / M_Pack, 0});
|
||||
|
||||
// constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
|
||||
constexpr int XDLPerLoadScaleB =
|
||||
BMXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
|
||||
|
||||
auto b_scale_block_window =
|
||||
make_tile_window(views.at(I4),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
|
||||
XDLPerLoadScaleB / GranularityK>{}),
|
||||
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
auto b_scale_block_window = [&]() {
|
||||
if constexpr(MXF8F6F4MFMA)
|
||||
{
|
||||
return make_tile_window(views.at(I4),
|
||||
make_tuple(number<TilePartitioner::NPerBlock / N_Pack>{},
|
||||
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
|
||||
{coord_n / N_Pack, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(views.at(I4),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
|
||||
XDLPerLoadScaleB / GranularityK>{}),
|
||||
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_block_window,
|
||||
b_flat_block_window,
|
||||
@@ -947,26 +963,9 @@ struct MoeFlatmmKernel
|
||||
// so don't need extra processing
|
||||
if constexpr(AQUANT_Pipeline)
|
||||
{
|
||||
constexpr int AGranularityK = decltype(kargs.scale_m)::GranularityK;
|
||||
constexpr auto a_scale_dram_dist = FlatmmPipeline::GetAScaleDramTileDistribution();
|
||||
constexpr ck_tile::index_t DramMScaleRepeat =
|
||||
decltype(a_scale_dram_dist)::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
|
||||
statically_indexed_array<ck_tile::index_t, DramMScaleRepeat> a_scale_offsets;
|
||||
static_for<0, DramMScaleRepeat, 1>{}([&](auto m0) {
|
||||
const auto row_idx =
|
||||
coord_m + m0 * (TilePartitioner::MPerBlock / DramMScaleRepeat) + a_coord[I0];
|
||||
index_t gather_token_id = row_to_token_idx(row_idx);
|
||||
a_scale_offsets[m0] = gather_token_id * kargs.stride_A / AGranularityK;
|
||||
});
|
||||
auto a_scale_gather_block_tile =
|
||||
ck_tile::make_tile_scatter_gather(a_scale_block_window.get_bottom_tensor_view(),
|
||||
a_scale_block_window.get_window_lengths(),
|
||||
a_scale_block_window.get_window_origin(),
|
||||
a_scale_dram_dist,
|
||||
a_scale_offsets); // K DRAM tile window for
|
||||
return FlatmmPipeline{}(a_gather_block_tile,
|
||||
b_block_window,
|
||||
a_scale_gather_block_tile, // weight scale with granularityK = 32
|
||||
a_scale_block_window, // weight scale with granularityK = 32
|
||||
b_scale_block_window, // weight scale with granularityK = 32
|
||||
num_loop,
|
||||
// kargs.k_padded_zeros,
|
||||
|
||||
@@ -2493,7 +2493,7 @@ template <typename ADataType_,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct F8xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
struct F8xMXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
ADataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
@@ -2512,16 +2512,14 @@ struct F8xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
static constexpr int ScaleGranularityK = 32;
|
||||
|
||||
static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
|
||||
static constexpr int MXdlPack = 1; // it's fixed for fp4
|
||||
static constexpr int NXdlPack = 1; // it's fixed for fp4
|
||||
static constexpr int KXdlPack = 4;
|
||||
static constexpr int ContinuousScaleNPerThread = 1; // it's fixed for fp4
|
||||
static constexpr int ContinuousScaleKPerThread = 4; // it's fixed for fp4
|
||||
static constexpr int MXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int NXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int KXdlPack = 2;
|
||||
// static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
|
||||
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = F8xMXF4FlatmmPipelineAgBgCrPolicy>
|
||||
template <typename Problem, typename PipelinePolicy = MXF8FlatmmPipelineAgBgCrPolicy>
|
||||
struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
|
||||
{
|
||||
using Underlying = FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>;
|
||||
@@ -2945,11 +2943,6 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Pr
|
||||
return PipelinePolicy::template MakeADramTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAScaleDramTileDistribution()
|
||||
{
|
||||
return PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
@@ -2989,7 +2982,7 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Pr
|
||||
a_copy_dram_window_tmp.get_bottom_tensor_view()),
|
||||
a_copy_dram_window_tmp.get_window_lengths(),
|
||||
a_copy_dram_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMXFP4_ADramTileDistribution<Problem>());
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
|
||||
@@ -237,7 +237,292 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
};
|
||||
|
||||
struct F8xMXF4FlatmmPipelineAgBgCrPolicy : MXF4FlatmmPipelineAgBgCrPolicy
|
||||
// struct F8xMXF4FlatmmPipelineAgBgCrPolicy : MXF4FlatmmPipelineAgBgCrPolicy
|
||||
// {
|
||||
// static constexpr auto I0 = number<0>{};
|
||||
// static constexpr auto I1 = number<1>{};
|
||||
// static constexpr auto I2 = number<2>{};
|
||||
//
|
||||
// static constexpr index_t kDramLoadPackBytes = 128;
|
||||
//
|
||||
// static constexpr int MXdlPack = 1;
|
||||
// static constexpr int NXdlPack = 1;
|
||||
// static constexpr int KXdlPack = 4;
|
||||
//
|
||||
// template <typename Problem>
|
||||
// static inline constexpr auto wg_attr_num_access =
|
||||
// std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
|
||||
// ? WGAttrNumAccessEnum::Single
|
||||
// : WGAttrNumAccessEnum::Double;
|
||||
//
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
|
||||
// {
|
||||
// using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
// using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
// static_assert(
|
||||
// sizeof(ADataType) * numeric_traits<BDataType>::PackedSize ==
|
||||
// sizeof(BDataType) * numeric_traits<ADataType>::PackedSize,
|
||||
// "sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!");
|
||||
// using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
// using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
// using WarpGemm = WarpGemmDispatcher< //
|
||||
// ADataType,
|
||||
// BDataType,
|
||||
// typename Problem::CDataType,
|
||||
// WarpTile::at(I0),
|
||||
// WarpTile::at(I1),
|
||||
// WarpTile::at(I2),
|
||||
// Problem::TransposeC,
|
||||
// false,
|
||||
// false,
|
||||
// wg_attr_num_access<Problem>>;
|
||||
// using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
|
||||
// ADataType,
|
||||
// BDataType,
|
||||
// typename Problem::CDataType,
|
||||
// BlockWarps,
|
||||
// WarpGemm>;
|
||||
// return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
|
||||
// }
|
||||
//
|
||||
// template <typename Problem, typename TensorView>
|
||||
// CK_TILE_DEVICE static constexpr auto
|
||||
// MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
|
||||
// {
|
||||
// using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
// using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
// constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
// constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
// static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
// static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
//
|
||||
// const auto& naive_desc = naive_view.get_tensor_descriptor();
|
||||
// constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
|
||||
// static_assert(ndims == 2, "only support 2D tensor");
|
||||
// const auto rows = naive_desc.get_length(number<0>{});
|
||||
// const auto cols = naive_desc.get_length(number<1>{});
|
||||
//
|
||||
// constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
// constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
|
||||
// constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
// const index_t K0 = cols / (K1 * K2);
|
||||
// const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
|
||||
//
|
||||
// constexpr index_t M1 = 4; // so that we can use imm offset to load lds
|
||||
// const index_t M0 = rows / M1;
|
||||
// const auto row_lens = make_tuple(M0, number<M1>{});
|
||||
//
|
||||
// const auto desc_0 =
|
||||
// make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
|
||||
// const auto desc_1 = transform_tensor_descriptor(
|
||||
// desc_0,
|
||||
// make_tuple(make_pass_through_transform(M0),
|
||||
// make_xor_transform(make_tuple(number<M1>{}, number<K1>{})),
|
||||
// make_pass_through_transform(K0),
|
||||
// make_pass_through_transform(number<K2>{})),
|
||||
// make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
|
||||
// const auto desc = transform_tensor_descriptor( //
|
||||
// desc_1,
|
||||
// make_tuple(make_merge_transform_v3_division_mod(row_lens),
|
||||
// make_merge_transform_v3_division_mod(col_lens)),
|
||||
// make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
// // printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));
|
||||
//
|
||||
// return tensor_view<typename TensorView::buffer_view,
|
||||
// remove_cvref_t<decltype(desc)>,
|
||||
// TensorView::DstInMemOp>{naive_view.buf_, desc};
|
||||
// }
|
||||
//
|
||||
// template <typename Problem>
|
||||
// CK_TILE_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
// {
|
||||
//
|
||||
// using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
// using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
// static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
//
|
||||
// constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
//
|
||||
// constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
|
||||
// constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
// constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
//
|
||||
// constexpr index_t M2 = get_warp_size() / K1; // 8
|
||||
// constexpr index_t M1 = BlockSize / get_warp_size(); // 4
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
|
||||
// static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
//
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding< //
|
||||
// sequence<1>,
|
||||
// tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
|
||||
// tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
|
||||
// tuple<sequence<1>, sequence<2, 1>>,
|
||||
// sequence<1, 2, 2>, // M0,K0,K2
|
||||
// sequence<0, 0, 2>>{});
|
||||
// }
|
||||
//
|
||||
// template <typename Problem>
|
||||
// CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
|
||||
// {
|
||||
// using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
// using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
// constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
// constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
// static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
// static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
//
|
||||
// /*reduce transform layers,compare with old ck*/
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
// constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
|
||||
// constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
// constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
// static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
//
|
||||
// constexpr index_t M3 = 4; // so that we can use imm offset to load lds
|
||||
// constexpr index_t M2 = get_warp_size() / K1 / M3; // 2
|
||||
// constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
|
||||
// constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
|
||||
// static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");
|
||||
//
|
||||
// constexpr index_t Pad = 4 * K2; // 4 * 32
|
||||
//
|
||||
// constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
|
||||
// make_tuple(number<M0>{},
|
||||
// number<M1>{},
|
||||
// number<K0>{},
|
||||
// number<M2>{},
|
||||
// number<M3>{},
|
||||
// number<K1>{},
|
||||
// number<K2>{}),
|
||||
// make_tuple(number<M1*(K0 * (M2 * M3 * K1 * K2) + (K0 - 1) * Pad)>{},
|
||||
// number<K0*(M2 * M3 * K1 * K2) + (K0 - 1) * Pad>{},
|
||||
// number<M2 * M3 * K1 * K2 + Pad>{},
|
||||
// number<M3 * K1 * K2>{},
|
||||
// number<K1 * K2>{},
|
||||
// number<K2>{},
|
||||
// number<1>{}),
|
||||
// number<K2>{},
|
||||
// number<1>{});
|
||||
//
|
||||
// constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor(
|
||||
// a_lds_block_desc_0,
|
||||
// make_tuple(make_pass_through_transform(M0),
|
||||
// make_pass_through_transform(M1),
|
||||
// make_pass_through_transform(K0),
|
||||
// make_pass_through_transform(M2),
|
||||
// make_xor_transform(make_tuple(number<M3>{}, number<K1>{})),
|
||||
// make_pass_through_transform(number<K2>{})),
|
||||
// make_tuple(sequence<0>{},
|
||||
// sequence<1>{},
|
||||
// sequence<2>{},
|
||||
// sequence<3>{},
|
||||
// sequence<4, 5>{},
|
||||
// sequence<6>{}),
|
||||
// make_tuple(sequence<0>{},
|
||||
// sequence<1>{},
|
||||
// sequence<2>{},
|
||||
// sequence<3>{},
|
||||
// sequence<4, 5>{},
|
||||
// sequence<6>{}));
|
||||
// constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
// a_lds_block_desc_1,
|
||||
// make_tuple(make_merge_transform_v3_division_mod(
|
||||
// make_tuple(number<M0>{}, number<M1>{}, number<M2>{}, number<M3>{})),
|
||||
// make_merge_transform_v3_division_mod(
|
||||
// make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
|
||||
// make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
//
|
||||
// // return a_lds_block_desc_permuted;
|
||||
// return a_lds_block_desc;
|
||||
// }
|
||||
//
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
|
||||
// {
|
||||
// using TileShape = typename Problem::BlockGemmShape;
|
||||
//
|
||||
// static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
|
||||
// static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
//
|
||||
// constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
|
||||
// constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
|
||||
// constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16
|
||||
//
|
||||
// constexpr int K_Lane = 64 / M_Lane; // 4
|
||||
//
|
||||
// constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32
|
||||
// constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
|
||||
// constexpr int K1 = K_Thread / num_access_v; // 16
|
||||
//
|
||||
// return make_static_tile_distribution(
|
||||
// std::conditional_t<
|
||||
// num_access_v == 1,
|
||||
// tile_distribution_encoding<
|
||||
// sequence<N_warps>,
|
||||
// tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
|
||||
// tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
// tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
// sequence<2>,
|
||||
// sequence<1>>,
|
||||
// tile_distribution_encoding< //
|
||||
// sequence<N_warps>,
|
||||
// tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<num_access_v, K_Lane, K1>>,
|
||||
// tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
// tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
// sequence<2, 2>,
|
||||
// sequence<0, 2>>>{});
|
||||
// }
|
||||
//
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
|
||||
// {
|
||||
// using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
//
|
||||
// constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// constexpr index_t WaveSize = get_warp_size();
|
||||
// constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
//
|
||||
// constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0);
|
||||
//
|
||||
// constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
|
||||
// constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
|
||||
//
|
||||
// static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
//
|
||||
// constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
|
||||
// constexpr index_t K_Lanes = 64 / M_Lanes;
|
||||
//
|
||||
// // Y dimension (M) decomposition
|
||||
// constexpr index_t Y2 = M_Lanes;
|
||||
// constexpr index_t Y1 = M_Warps;
|
||||
// constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2);
|
||||
//
|
||||
// // X dimension (K) decomposition
|
||||
// constexpr index_t X0 = K_Lanes;
|
||||
// constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
//
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding<sequence<N_Warps>, // repeat N_warps
|
||||
// tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
// tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
// tuple<sequence<1, 0>, sequence<0, 2>>,
|
||||
// sequence<1, 2>,
|
||||
// sequence<0, 1>>{});
|
||||
// }
|
||||
// };
|
||||
|
||||
struct F8xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
@@ -245,9 +530,9 @@ struct F8xMXF4FlatmmPipelineAgBgCrPolicy : MXF4FlatmmPipelineAgBgCrPolicy
|
||||
|
||||
static constexpr index_t kDramLoadPackBytes = 128;
|
||||
|
||||
static constexpr int MXdlPack = 1;
|
||||
static constexpr int NXdlPack = 1;
|
||||
static constexpr int KXdlPack = 4;
|
||||
static constexpr int MXdlPack = 2;
|
||||
static constexpr int NXdlPack = 2;
|
||||
static constexpr int KXdlPack = 2;
|
||||
|
||||
template <typename Problem>
|
||||
static inline constexpr auto wg_attr_num_access =
|
||||
|
||||
Reference in New Issue
Block a user