This commit is contained in:
coderfeli
2025-08-06 02:06:13 +00:00
parent 080bfd881a
commit cd45fe941d
5 changed files with 69 additions and 138 deletions

View File

@@ -10,6 +10,6 @@ endif()
# Build up EXAMPLE_FLATMM_COMPILE_OPTIONS and apply it (PRIVATE) to the
# tile_example_flatmm_basic target at the end of this section.
# Alternative MFMA selection, currently disabled:
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16_F8=1 -Wno-unused-local-typedef)
# Active selection: defines USING_MFMA_16x16x128_F8=1 and silences unused-local-typedef warnings.
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x128_F8=1 -Wno-unused-local-typedef)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps)
# NOTE(review): the next two appends differ only in `-mllvm --slp-threshold=-32`;
# they look like the before/after of a single change -- confirm whether both
# are meant to be present at the same time.
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0")
# Experimental scheduler flag set, currently disabled:
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --disable-schedmodel-in-sched-mi=1 -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental -mllvm --misched-bottomup=1")
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

View File

@@ -33,17 +33,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
// static constexpr index_t kM = BlockTile::at(number<0>{});
// static constexpr index_t kN = BlockTile::at(number<1>{});
// static constexpr index_t kK = BlockTile::at(number<2>{});
// static constexpr bool PermuteA = PermuteA_;
// static constexpr bool PermuteB = PermuteB_;
// static constexpr index_t flatNPerWarp = BlockWarps::at(number<1>{}); // 4
// static constexpr index_t flatKPerWarp = WarpTile::at(number<2>{}) * WarpTile::at(number<1>{});// 16 * 64
// static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(number<2>{}); // 16 * 128
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
FlatmmConfig::TileParitionerGroupNum,

View File

@@ -350,6 +350,7 @@ struct FlatmmKernel
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{

View File

@@ -71,10 +71,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
// Forwards PipelinePolicy::GetVectorSizeA<Problem>() to callers.
// The static_assert rejects, at compile time, any policy/problem
// combination for which that value is not exactly 16.
static constexpr index_t GetVectorSizeA()
{
    constexpr auto policy_vector_size_a = PipelinePolicy::template GetVectorSizeA<Problem>();
    static_assert(policy_vector_size_a == 16);
    return policy_vector_size_a;
}
// Forwards PipelinePolicy::GetVectorSizeB<Problem>() to callers.
// The static_assert rejects, at compile time, any policy/problem
// combination for which that value is not exactly 16.
static constexpr index_t GetVectorSizeB()
{
    constexpr auto policy_vector_size_b = PipelinePolicy::template GetVectorSizeB<Problem>();
    static_assert(policy_vector_size_b == 16);
    return policy_vector_size_b;
}
@@ -706,19 +708,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// Prefill A0
// if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
// {
// auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
// PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
// shuffle_tile(a_shuffle_tmp, a_block_tile);
// const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
// store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
// }
// else
// {
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile));
// }
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
__builtin_amdgcn_sched_barrier(0);
@@ -760,17 +749,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
// {
while(iCounter > 0)
{
// prefetch B(2i+1)
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
// move_tile_window(b_flat_dram_windows(nIter)(kIter),
// {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
// b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
// });
// });
b_warp_tensor_pong = load_tile(b_flat_dram_window);
@@ -805,15 +783,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// if constexpr(mIter==0 && nIter ==0)
// if(threadIdx.x % 16== 0 && threadIdx.x<64){
// for(int i=0;i<b_warp_tensor_ping(nIter)(kIter).get_thread_buffer_size();i++) {
// printf("tid=%u, i0 %d bval=%f\n", threadIdx.x, i, type_convert<float>(b_warp_tensor_ping(nIter)(kIter).thread_buf_(i)));
// }
// }
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -840,22 +811,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// HotLoopScheduler();
HotLoopScheduler();
//Next K
// prefetch B(2i+2)
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
// move_tile_window(b_flat_dram_windows(nIter)(kIter),
// {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
// b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
// });
// });
b_warp_tensor_ping = load_tile(b_flat_dram_window);
// Prefill A(2i+2)
@@ -883,23 +841,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
a_warp_tensor.get_thread_buffer() = a_warp_tensor_ping(AwarpIter).get_y_sliced_thread_data(
merge_sequences(sequence<kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1>{}, a_warp_y_lengths));
// warp GEMM
// if(threadIdx.x % 16 == 0 && threadIdx.x>=192){
// for(int i=0;i<b_warp_tensor_ping(nIter)(kIter).get_thread_buffer_size();i++) {
// printf("tid=%u, aval01 %f bval=%f\n", threadIdx.x, type_convert<float>(a_warp_tensor.thread_buf_(i)), type_convert<float>(b_warp_tensor_ping(nIter)(kIter).thread_buf_(i)));
// }
// for(int i=0;i<a_warp_tensor_ping(AwarpIter).get_thread_buffer_size();i++) {
// printf("tid=%u, aval2 %f\n", threadIdx.x, type_convert<float>(a_warp_tensor_ping(AwarpIter).thread_buf_[i]));
// }
// }
// if constexpr(mIter==0 && nIter ==0)
// if(threadIdx.x % 16== 0 && threadIdx.x<64){
// for(int i=0;i<b_warp_tensor_pong(nIter)(kIter).get_thread_buffer_size();i++) {
// printf("tid=%u, i1 %d bval=%f\n", threadIdx.x, i, type_convert<float>(b_warp_tensor_pong(nIter)(kIter).thread_buf_(i)));
// }
// }
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tensor_pong.get_y_sliced_thread_data(
@@ -932,7 +873,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// HotLoopScheduler();
HotLoopScheduler();
iCounter--;
}
@@ -1042,45 +983,45 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
// // __builtin_amdgcn_sched_barrier(0);
// }
// else if constexpr(TailNum == TailNumber::Odd)
if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = number<mIter % m_preload>{};
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// if constexpr(TailNum == TailNumber::Odd)
// {
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// constexpr auto AwarpIter = number<mIter % m_preload>{};
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// // read C warp tensor from C block tensor
// CWarpTensor c_warp_tensor;
AWarpTensor a_warp_tensor;
// AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tensor_ping(AwarpIter).get_y_sliced_thread_data(
merge_sequences(sequence<kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1>{}, a_warp_y_lengths));
// a_warp_tensor.get_thread_buffer() = a_warp_tensor_ping(AwarpIter).get_y_sliced_thread_data(
// merge_sequences(sequence<kIter>{}, a_warp_y_index_zeros),
// merge_sequences(sequence<1>{}, a_warp_y_lengths));
// set_tile(a_warp_tensor, 1.0f);
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tensor_ping.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// // set_tile(a_warp_tensor, 1.0f);
// c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// BWarpTensor b_warp_tensor;
// b_warp_tensor.get_thread_buffer() = b_warp_tensor_ping.get_y_sliced_thread_data(
// merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// __builtin_amdgcn_sched_barrier(0x7F6);
});
});
if constexpr(mIter < MIterPerWarp - 2)
{
a_warp_tensor_ping(AwarpIter) = load_tile(a_warp_windows_ping(number<mIter + 2>{}));
}
});
}
// // write C warp tensor into C block tensor
// c_block_tile.set_y_sliced_thread_data(
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
// c_warp_tensor.get_thread_buffer());
// // __builtin_amdgcn_sched_barrier(0x7F6);
// });
// });
// if constexpr(mIter < MIterPerWarp - 2)
// {
// a_warp_tensor_ping(AwarpIter) = load_tile(a_warp_windows_ping(number<mIter + 2>{}));
// }
// });
// }
// }
return c_block_tile;

View File

@@ -23,38 +23,38 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
// if constexpr(MPerXdl == 16 && NPerXdl == 16)
// {
// /*reduce transform layers,compare with old ck*/
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// constexpr index_t KPack = GetSmemPackA<Problem>();
if constexpr(MPerXdl == 16 && NPerXdl == 16)
{
/*reduce transform layers,compare with old ck*/
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
// constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
// make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
// make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
// number<KPack>{},
// number<1>{});
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack *MPerBlock >{}, number<KPack>{}, number<1>{}),
number<KPack>{},
number<1>{});
// constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
// a_lds_block_desc_0,
// make_tuple(make_xor_transform(
// make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
// make_pass_through_transform(number<KPack>{})),
// make_tuple(sequence<1, 0>{}, sequence<2>{}),
// make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(
make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
// constexpr auto a_lds_block_desc = transform_tensor_descriptor(
// a_lds_block_desc_permuted,
// make_tuple(make_pass_through_transform(number<MPerBlock>{}),
// make_merge_transform_v3_division_mod(
// make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
// make_tuple(sequence<1>{}, sequence<0, 2>{}),
// make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// return a_lds_block_desc;
// }
// else
return a_lds_block_desc;
}
else
{
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;