mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
temp
This commit is contained in:
@@ -10,6 +10,6 @@ endif()
|
||||
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16_F8=1 -Wno-unused-local-typedef)
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x128_F8=1 -Wno-unused-local-typedef)
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps)
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0")
|
||||
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --disable-schedmodel-in-sched-mi=1 -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental -mllvm --misched-bottomup=1")
|
||||
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -33,17 +33,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
// static constexpr index_t kM = BlockTile::at(number<0>{});
|
||||
// static constexpr index_t kN = BlockTile::at(number<1>{});
|
||||
// static constexpr index_t kK = BlockTile::at(number<2>{});
|
||||
|
||||
// static constexpr bool PermuteA = PermuteA_;
|
||||
// static constexpr bool PermuteB = PermuteB_;
|
||||
|
||||
// static constexpr index_t flatNPerWarp = BlockWarps::at(number<1>{}); // 4
|
||||
// static constexpr index_t flatKPerWarp = WarpTile::at(number<2>{}) * WarpTile::at(number<1>{});// 16 * 64
|
||||
// static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(number<2>{}); // 16 * 128
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
|
||||
@@ -350,6 +350,7 @@ struct FlatmmKernel
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
|
||||
@@ -71,10 +71,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
static_assert(PipelinePolicy::template GetVectorSizeA<Problem>()==16);
|
||||
return PipelinePolicy::template GetVectorSizeA<Problem>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
static_assert(PipelinePolicy::template GetVectorSizeB<Problem>()==16);
|
||||
return PipelinePolicy::template GetVectorSizeB<Problem>();
|
||||
}
|
||||
|
||||
@@ -706,19 +708,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// Prefill A0
|
||||
// if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
|
||||
// {
|
||||
// auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
// PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
|
||||
// shuffle_tile(a_shuffle_tmp, a_block_tile);
|
||||
// const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
// store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
// }
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -760,17 +749,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
// {
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// prefetch B(2i+1)
|
||||
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
// move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
// {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
// b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
// });
|
||||
// });
|
||||
|
||||
b_warp_tensor_pong = load_tile(b_flat_dram_window);
|
||||
|
||||
@@ -805,15 +783,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// if constexpr(mIter==0 && nIter ==0)
|
||||
// if(threadIdx.x % 16== 0 && threadIdx.x<64){
|
||||
// for(int i=0;i<b_warp_tensor_ping(nIter)(kIter).get_thread_buffer_size();i++) {
|
||||
// printf("tid=%u, i0 %d bval=%f\n", threadIdx.x, i, type_convert<float>(b_warp_tensor_ping(nIter)(kIter).thread_buf_(i)));
|
||||
// }
|
||||
// }
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -840,22 +811,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// HotLoopScheduler();
|
||||
HotLoopScheduler();
|
||||
|
||||
//Next K
|
||||
|
||||
// prefetch B(2i+2)
|
||||
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
// move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
// {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
// b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
// });
|
||||
// });
|
||||
|
||||
b_warp_tensor_ping = load_tile(b_flat_dram_window);
|
||||
|
||||
// Prefill A(2i+2)
|
||||
@@ -883,23 +841,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tensor_ping(AwarpIter).get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1>{}, a_warp_y_lengths));
|
||||
// warp GEMM
|
||||
// if(threadIdx.x % 16 == 0 && threadIdx.x>=192){
|
||||
// for(int i=0;i<b_warp_tensor_ping(nIter)(kIter).get_thread_buffer_size();i++) {
|
||||
// printf("tid=%u, aval01 %f bval=%f\n", threadIdx.x, type_convert<float>(a_warp_tensor.thread_buf_(i)), type_convert<float>(b_warp_tensor_ping(nIter)(kIter).thread_buf_(i)));
|
||||
// }
|
||||
|
||||
// for(int i=0;i<a_warp_tensor_ping(AwarpIter).get_thread_buffer_size();i++) {
|
||||
// printf("tid=%u, aval2 %f\n", threadIdx.x, type_convert<float>(a_warp_tensor_ping(AwarpIter).thread_buf_[i]));
|
||||
// }
|
||||
// }
|
||||
// if constexpr(mIter==0 && nIter ==0)
|
||||
// if(threadIdx.x % 16== 0 && threadIdx.x<64){
|
||||
// for(int i=0;i<b_warp_tensor_pong(nIter)(kIter).get_thread_buffer_size();i++) {
|
||||
// printf("tid=%u, i1 %d bval=%f\n", threadIdx.x, i, type_convert<float>(b_warp_tensor_pong(nIter)(kIter).thread_buf_(i)));
|
||||
|
||||
// }
|
||||
// }
|
||||
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tensor_pong.get_y_sliced_thread_data(
|
||||
@@ -932,7 +873,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// HotLoopScheduler();
|
||||
HotLoopScheduler();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
@@ -1042,45 +983,45 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
// // __builtin_amdgcn_sched_barrier(0);
|
||||
// }
|
||||
// else if constexpr(TailNum == TailNumber::Odd)
|
||||
if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = number<mIter % m_preload>{};
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
// if constexpr(TailNum == TailNumber::Odd)
|
||||
// {
|
||||
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// constexpr auto AwarpIter = number<mIter % m_preload>{};
|
||||
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// // read C warp tensor from C block tensor
|
||||
// CWarpTensor c_warp_tensor;
|
||||
|
||||
AWarpTensor a_warp_tensor;
|
||||
// AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tensor_ping(AwarpIter).get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1>{}, a_warp_y_lengths));
|
||||
// a_warp_tensor.get_thread_buffer() = a_warp_tensor_ping(AwarpIter).get_y_sliced_thread_data(
|
||||
// merge_sequences(sequence<kIter>{}, a_warp_y_index_zeros),
|
||||
// merge_sequences(sequence<1>{}, a_warp_y_lengths));
|
||||
|
||||
// set_tile(a_warp_tensor, 1.0f);
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tensor_ping.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
// // set_tile(a_warp_tensor, 1.0f);
|
||||
// c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
// BWarpTensor b_warp_tensor;
|
||||
// b_warp_tensor.get_thread_buffer() = b_warp_tensor_ping.get_y_sliced_thread_data(
|
||||
// merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
// merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
// __builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
});
|
||||
if constexpr(mIter < MIterPerWarp - 2)
|
||||
{
|
||||
a_warp_tensor_ping(AwarpIter) = load_tile(a_warp_windows_ping(number<mIter + 2>{}));
|
||||
}
|
||||
});
|
||||
}
|
||||
// // write C warp tensor into C block tensor
|
||||
// c_block_tile.set_y_sliced_thread_data(
|
||||
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
// c_warp_tensor.get_thread_buffer());
|
||||
// // __builtin_amdgcn_sched_barrier(0x7F6);
|
||||
// });
|
||||
// });
|
||||
// if constexpr(mIter < MIterPerWarp - 2)
|
||||
// {
|
||||
// a_warp_tensor_ping(AwarpIter) = load_tile(a_warp_windows_ping(number<mIter + 2>{}));
|
||||
// }
|
||||
// });
|
||||
// }
|
||||
// }
|
||||
|
||||
return c_block_tile;
|
||||
|
||||
@@ -23,38 +23,38 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
// if constexpr(MPerXdl == 16 && NPerXdl == 16)
|
||||
// {
|
||||
// /*reduce transform layers,compare with old ck*/
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
if constexpr(MPerXdl == 16 && NPerXdl == 16)
|
||||
{
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
// constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
// make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
// make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
// number<KPack>{},
|
||||
// number<1>{});
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack *MPerBlock >{}, number<KPack>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
// constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
// a_lds_block_desc_0,
|
||||
// make_tuple(make_xor_transform(
|
||||
// make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
// make_pass_through_transform(number<KPack>{})),
|
||||
// make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
// make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
// constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
// a_lds_block_desc_permuted,
|
||||
// make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
// make_merge_transform_v3_division_mod(
|
||||
// make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
// make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// return a_lds_block_desc;
|
||||
// }
|
||||
// else
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
Reference in New Issue
Block a user