diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index da37159aeb..20cc202176 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -23,7 +23,7 @@ args: -n n dimension (default:2048) -k k dimension (default:64) -a_layout Tensor A data layout (default: R) - -b_layout Tensor B data layout (default: R) + -b_layout Tensor B data layout (default: C) -c_layout Tensor C data layout (default: R) -stride_a Tensor A stride (default:0) -stride_b Tensor B stride (default:0) diff --git a/example/ck_tile/37_transpose/transpose_policy.hpp b/example/ck_tile/37_transpose/transpose_policy.hpp index ea1a4130fe..b7e52a94f7 100644 --- a/example/ck_tile/37_transpose/transpose_policy.hpp +++ b/example/ck_tile/37_transpose/transpose_policy.hpp @@ -48,8 +48,8 @@ struct TransposePolicy constexpr auto input_dstr = MakeLdsLoadTileDistribution(); using OutTileDstrEncode = - typename OutputTileDistributionTraits, - typename Problem::DataType>::OutDstrEncode; + typename OutputTileDistributionTraits::TransposedDstrEncode; constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{}); return block_dstr; @@ -131,7 +131,9 @@ struct TransposePolicy constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim; constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations; + constexpr index_t kLaneGroupSize = 16; constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); } - else if constexpr(std::is_same_v, ck_tile::fp8_t>) + else if constexpr(std::is_same_v, ck_tile::fp8_t> || + std::is_same_v, ck_tile::bf8_t> || + std::is_same_v, ck_tile::int8_t>) { - typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t; - __attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>( + typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t; + __attribute__((address_space(3))) llvm_i32x2_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>( reinterpret_cast(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index ca4ff8ca7e..568a5be64c 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -2611,11 +2611,13 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) reinterpret_cast(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); } - else if constexpr(std::is_same_v, ck_tile::fp8_t>) + else if constexpr(std::is_same_v, ck_tile::fp8_t> || + std::is_same_v, ck_tile::bf8_t> || + std::is_same_v, ck_tile::int8_t>) { - typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t; - __attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>( + typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t; + __attribute__((address_space(3))) llvm_i32x2_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>( reinterpret_cast(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } diff --git a/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp b/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp index 7ffe6dc0fb..665be1b167 100644 --- a/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp +++ b/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp @@ -10,53 +10,55 @@ namespace ck_tile { // this generate wave level tile distribution -template +template struct LaneGroupTransposeTraits; -template -struct LaneGroupTransposeTraits> +template +struct LaneGroupTransposeTraits> { + static_assert(LaneGroupSize == 16 || LaneGroupSize == 32 || LaneGroupSize == 64, + "LaneGroupSize must be 16, 32, or 64"); // before transpose, 4x16 static constexpr index_t ksecondDim = 4; - static constexpr index_t kleadDim = 16; + static constexpr index_t kleadDim = LaneGroupSize; // after transpose, 16x4 - static constexpr index_t ksecondDimT = 16; + static constexpr index_t ksecondDimT = LaneGroupSize; static constexpr index_t kleadDimT = 4; template - using TileDistribution = - tile_distribution_encoding, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2, 1, 2>, - sequence<1, 1, 3>>; + using TileDistribution = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 1, 2>, + sequence<1, 1, 4>>; }; -template -struct LaneGroupTransposeTraits> +template +struct LaneGroupTransposeTraits> { static constexpr index_t ksecondDim = 8; - static constexpr index_t kleadDim = 16; + static constexpr index_t kleadDim = LaneGroupSize; - static constexpr index_t ksecondDimT = 16; + static constexpr index_t ksecondDimT = LaneGroupSize; static constexpr index_t kleadDimT = 8; template - using TileDistribution = - tile_distribution_encoding, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2, 1, 2>, - sequence<1, 1, 3>>; + using TileDistribution = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 1, 2>, + sequence<1, 1, 4>>; }; /* @@ -72,15 +74,15 @@ struct LaneGroupTransposeTraits> * consecutive. */ template CK_TILE_DEVICE constexpr auto make_transposed_distr_encode() { - using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits:: - template TileDistribution; - return xdllevel_dstr_encoding{}; + return typename LaneGroupTransposeTraits:: + template TileDistribution{}; } } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 5cae332007..13b038bc48 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -994,51 +994,34 @@ struct buffer_view" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix + // clang-format off static_assert( - (std::is_same_v, int8_t> && - std::is_same_v, int8_t>) || - (std::is_same_v, int8_t> && - std::is_same_v, int8x2_t>) || - (std::is_same_v, int8_t> && - std::is_same_v, int8x4_t>) || - (std::is_same_v, int8_t> && - std::is_same_v, int8x8_t>) || - (std::is_same_v, int8_t> && - std::is_same_v, int8x16_t>) || - (std::is_same_v, int8x4_t> && - std::is_same_v, int8x4_t>) || - (std::is_same_v, int8x8_t> && - std::is_same_v, int8x8_t>) || - (std::is_same_v, int8x16_t> && - std::is_same_v, int8x16_t>) || + (std::is_same_v, int8_t> && std::is_same_v, int8_t>) || + (std::is_same_v, int8_t> && std::is_same_v, int8x2_t>) || + (std::is_same_v, int8_t> && std::is_same_v, int8x4_t>) || + (std::is_same_v, int8_t> && std::is_same_v, int8x8_t>) || + (std::is_same_v, int8_t> && std::is_same_v, int8x16_t>) || + (std::is_same_v, int8x4_t> && std::is_same_v, int8x4_t>) || + (std::is_same_v, int8x8_t> && std::is_same_v, int8x8_t>) || + (std::is_same_v, int8x16_t> && std::is_same_v, int8x16_t>) || // int8 on thread buffer - (std::is_same_v, int8_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, int8_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, int8_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, int8_t> && - std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || // ext_vector_type for pk_int4 must use int8_t as type - (std::is_same_v, pk_int4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4x4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4x8_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4x16_t> && - std::is_same_v, thread_buffer>), + (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x16_t> && std::is_same_v, thread_buffer>), "wrong! not implemented for this combination, please add " "implementation"); + // clang-format on if constexpr((std::is_same_v, int8_t> && std::is_same_v, int8_t>) || @@ -1090,6 +1073,8 @@ struct buffer_view, int8_t> && std::is_same_v, int8x16_t>) || + (std::is_same_v, int8_t> && + std::is_same_v, thread_buffer>) || (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>)) { diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index d178ccb72c..ceb7e18556 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -17,6 +17,11 @@ namespace ck_tile { +constexpr int DS_READ_TR_SIZE() +{ + return 8; // Literal constant, evaluated at compile time +} + namespace util { template struct is_sequence_suffix @@ -45,48 +50,60 @@ constexpr bool is_sequence_suffix_v = is_sequence_suffix::valu template struct DefaultTranspose { + template struct Quad16 { - using InputEncoding = tile_distribution_encoding, - tuple, sequence<4, 4>>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16, + "LaneGroupSize must be 64, 32, or 16"); + using InputEncoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<2>>; - using OutputEncoding = tile_distribution_encoding, - tuple, sequence<4>>, - tuple>, - tuple>, - sequence<2>, - sequence<0>>; + using OutputEncoding = + tile_distribution_encoding, + tuple, sequence<4>>, + tuple>, + tuple>, + sequence<2>, + sequence<0>>; }; + template struct Quad8 { - using InputEncoding = tile_distribution_encoding, - tuple, sequence<2, 8>>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16, + "LaneGroupSize must be 64, 32, or 16"); + using InputEncoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<2>>; - using OutputEncoding = tile_distribution_encoding, - tuple, sequence<8>>, - tuple>, - tuple>, - sequence<2>, - sequence<0>>; + using OutputEncoding = + tile_distribution_encoding, + tuple, sequence<8>>, + tuple>, + tuple>, + sequence<2>, + sequence<0>>; }; // Select based on data size + template using QuadInputEncoding = std::conditional_t; + typename Quad16::InputEncoding, + typename Quad8::InputEncoding>; + template using QuadOutputEncoding = std::conditional_t; + typename Quad16::OutputEncoding, + typename Quad8::OutputEncoding>; // Always swap last two dimensions static constexpr auto transpose_dims = sequence<1, 0>{}; @@ -96,51 +113,79 @@ struct DefaultTranspose return idx; // Identity mapping }; - template - struct ValidationTraits + template + struct ValidationTraitsImpl { - static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_; - static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_; + using QuadEncoding = std::conditional_t, + QuadInputEncoding>; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto input_hs = InDstrEncode::hs_lengthss_; + static constexpr auto quad_hs = QuadEncoding::hs_lengthss_; // 1. Must be 2D tensor static constexpr bool dims_valid = (InDstrEncode::NDimX == 2); // 2. Quad pattern must be suffix of input pattern static constexpr bool suffix_valid_dim0 = - util::is_sequence_suffix_v()), - decltype(input_hs_lengthss.template get<0>())>; + util::is_sequence_suffix_v; static constexpr bool suffix_valid_dim1 = - util::is_sequence_suffix_v()), - decltype(input_hs_lengthss.template get<1>())>; + util::is_sequence_suffix_v; // 3. PS→RHS mapping constraints - static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_; - static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_; + static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_; + static constexpr auto input_ps_minor = InDstrEncode::ps_to_rhss_minor_; - static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1; - static constexpr index_t ndimp_inner = - input_ps_to_rhss_major[number{}].size() - 1; + static constexpr auto quad_ps_major0 = QuadEncoding::ps_to_rhss_major_[I0]; + static constexpr auto quad_ps_minor0 = QuadEncoding::ps_to_rhss_minor_[I0]; + + static constexpr auto input_ps_major_last = + input_ps_major[number{}]; + static constexpr auto input_ps_minor_last = + input_ps_minor[number{}]; + + using psys_offset = ck_tile::sequence; + static constexpr auto shifted_quad_ps_minor0 = generate_sequence_v2( + [](auto i) { + return number{}; + }, + number{}); static constexpr bool ps_mapping_valid = - (input_ps_to_rhss_major[number{}][number{}] == 2) && - (input_ps_to_rhss_minor[number{}][number{}] == - input_hs_lengthss[number<1>{}].size() - 2) && - (input_ps_to_rhss_major[number{}][number{}] == 1) && - (input_ps_to_rhss_minor[number{}][number{}] == - input_hs_lengthss[number<0>{}].size() - 1); + util::is_sequence_suffix_v && + util::is_sequence_suffix_v; // 4. YS→RHS mapping constraints - static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_; - static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_; + static constexpr auto input_ys_major = InDstrEncode::ys_to_rhs_major_; + static constexpr auto input_ys_minor = InDstrEncode::ys_to_rhs_minor_; + static constexpr auto quad_ys_major = QuadEncoding::ys_to_rhs_major_; + static constexpr auto quad_ys_minor = QuadEncoding::ys_to_rhs_minor_; + static_assert(quad_ys_major.size() == 1 && quad_ys_minor.size() == 1, + "YS->RHS mapping must be single dimension"); + static_assert(quad_ys_major.back() == 2 && quad_ys_minor.back() == quad_hs[I1].size() - 1, + "YS->RHS mapping must be the last dimension"); static constexpr bool ys_mapping_valid = - (input_ys_to_rhs_major.back() == 2) && - (input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) && - (input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) && - (input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] == - input_hs_lengthss[number<0>{}].size() - 2); + (input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1); static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 && ps_mapping_valid && ys_mapping_valid; }; + + template + struct ValidationTraits + { + static constexpr bool value = + ValidationTraitsImpl::value || + ValidationTraitsImpl::value || + ValidationTraitsImpl::value; + static constexpr index_t LaneGroupSize = + ValidationTraitsImpl::value ? 64 + : ValidationTraitsImpl::value ? 32 + : ValidationTraitsImpl::value ? 16 + : 0; + }; }; template struct TransposeTileDistrChecker @@ -154,111 +199,150 @@ struct TransposeTileDistrChecker // this is used to generate the transposed output tile distribution encoding // based on the input tile distribution encoding -template > -struct OutputTileDistributionTraits + typename Policy = DefaultTranspose, + bool ReverseDirection = false> +struct TransposeTileDistributionTraits { - using InDstrEncode = typename remove_cvref_t::DstrEncode; - static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_; - static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_; - static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_; + using InDstrEncode = remove_cvref_t; + static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_; + static constexpr index_t LaneGroupSize = + Policy::template ValidationTraits::LaneGroupSize; + static_assert(Policy::template ValidationTraits::value, + "The input tile distribution encoding is not valid for transpose!"); + + using QuadInputEncoding = std::conditional_t< // + ReverseDirection, + typename Policy::template QuadOutputEncoding, + typename Policy::template QuadInputEncoding>; + using QuadOutputEncoding = std::conditional_t< // + ReverseDirection, + typename Policy::template QuadInputEncoding, + typename Policy::template QuadOutputEncoding>; + + static constexpr auto quad_input_hs_lengthss = QuadInputEncoding::hs_lengthss_; + static constexpr auto quad_output_hs_lengthss = QuadOutputEncoding::hs_lengthss_; static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_; static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_; static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_; static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_; - static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_; - static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_; + static constexpr auto I0 = number<0>{}; + static constexpr auto quad_input_ps_to_rhss_major0 = QuadInputEncoding::ps_to_rhss_major_[I0]; + static constexpr auto quad_input_ps_to_rhss_minor0 = QuadInputEncoding::ps_to_rhss_minor_[I0]; + static constexpr auto quad_output_ps_to_rhss_major0 = QuadOutputEncoding::ps_to_rhss_major_[I0]; + static constexpr auto quad_output_ps_to_rhss_minor0 = QuadOutputEncoding::ps_to_rhss_minor_[I0]; + static constexpr auto quad_output_ys_to_rhs_major = QuadOutputEncoding::ys_to_rhs_major_; + static constexpr auto quad_output_ys_to_rhs_minor = QuadOutputEncoding::ys_to_rhs_minor_; + + static constexpr index_t dim0 = Policy::transpose_dims[0]; + static constexpr index_t dim1 = Policy::transpose_dims[1]; + + static constexpr auto swap_one_and_two = [](const index_t idx) { + return (idx == 1) ? 2 : (idx == 2) ? 1 : idx; + }; // for transpose load - // append the reversed quad output hs lengths to the input hs lengthss after removing - // the quad_input_hs_lengthss - // then reverse the whole sequence to get the dst_out_hs_lengthss - static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss); - - static constexpr auto full_out_hs_lengthss = generate_tuple( + // remove the quad_input_hs_lengthss from the input_hs_lengthss for each dimension and reverse + // dims and append the quad_output_hs_lengthss to the end of each dimension + static constexpr auto outer_hs_lengthss = generate_tuple( [](auto i) { - return input_hs_lengthss[i] - .extract(typename arithmetic_sequence_gen<0, - input_hs_lengthss[i].size() - - quad_input_hs_lengthss[i].size(), - 1>::type{}) - .push_back(reversed_quad_output_hs_lengthss[i]); + constexpr auto input_i = input_hs_lengthss[i]; + constexpr auto outer_len = input_i.size() - quad_input_hs_lengthss[i].size(); + return typename sequence_split::left_type{}; + }, + number{}); + static constexpr auto reversed_outer_hs_lengthss = tuple_reverse(outer_hs_lengthss); + static constexpr auto dst_out_hs_lengthss = generate_tuple( + [](auto i) { + auto outer_i = reversed_outer_hs_lengthss[i]; + // append the reversed quad output hs lengths to the outer hs lengths + return outer_i.push_back(quad_output_hs_lengthss[i]); }, number{}); - static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss); - - // for PS→RHS mapping(both major and minor), we need to modify the last element of the major - // sequence - static constexpr auto modified_ps_to_rhss_major = generate_tuple( + // for PS→RHS mapping(both major and minor), we need to modify the last element (which is for + // thread distr) of the major sequence + static constexpr auto dst_ps_to_rhss_major = generate_tuple( + // for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed [](auto i) { if constexpr(i == input_ps_to_rhss_major.size() - 1) { constexpr auto current_size = input_ps_to_rhss_major[i].size(); - constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size(); + constexpr auto reduce_size = quad_input_ps_to_rhss_major0.size(); + constexpr auto quad_out = quad_output_ps_to_rhss_major0; constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract( typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{}); - return reduced_ps_to_rhss_major.push_back(number<2>{}); + return reduced_ps_to_rhss_major.transform(swap_one_and_two).push_back(quad_out); } else { - // For all other sequences, keep them unchanged - return input_ps_to_rhss_major[i]; + // For all other sequences (i.e. warp), keep them unchanged + return input_ps_to_rhss_major[i].transform(swap_one_and_two); } }, number{}); - static constexpr auto minor_last_index = - full_out_hs_lengthss[number{}].size() - 1; - static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1; + static constexpr auto quad_idx_offset = + transform_tuples([](auto x) { return number{}; }, reversed_outer_hs_lengthss); + + // minus 1 because RsLength is not counted + static constexpr auto quad_output_ps_minor_offset = to_sequence(generate_tuple_for( + [](auto x) { return quad_idx_offset[number{}]; }, quad_output_ps_to_rhss_major0)); + static constexpr auto quad_output_ys_minor_offset = to_sequence(generate_tuple_for( + [](auto x) { return quad_idx_offset[number{}]; }, quad_output_ys_to_rhs_major)); static constexpr auto dst_ps_to_rhss_minor = generate_tuple( [](auto i) { + constexpr auto input_i = input_ps_to_rhss_minor[i]; if constexpr(i == input_ps_to_rhss_minor.size() - 1) { - constexpr auto current_size = input_ps_to_rhss_minor[i].size(); - constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size(); - constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract( - typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{}); - return reduced_ps_to_rhss_minor.push_back(number{}); + constexpr auto outer_len = input_i.size() - quad_input_ps_to_rhss_minor0.size(); + constexpr auto outer_ps = + typename sequence_split::left_type{}; + + return outer_ps.push_back(quad_output_ps_minor_offset + + quad_output_ps_to_rhss_minor0); } else { // For all other sequences, keep them unchanged - return input_ps_to_rhss_minor[i]; + return input_i; } }, number{}); + static constexpr auto outer_input_ys_to_rhs_major = input_ys_to_rhs_major.pop_back(); + // for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed - static constexpr auto swap_one_and_two = [](const index_t idx) { - return (idx == 1) ? 2 : (idx == 2) ? 1 : idx; - }; - static constexpr auto dst_ps_to_rhss_major = generate_tuple( - [](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); }, - number{}); + static constexpr auto dst_ys_to_rhs_major = + outer_input_ys_to_rhs_major.transform(swap_one_and_two).push_back(number<2>{}); - static constexpr auto modified_input_ys_to_rhs_major = - input_ys_to_rhs_major.pop_back().push_back(number<1>{}); + static constexpr auto dst_ys_to_rhs_minor = input_ys_to_rhs_minor.pop_back().push_back( + number<(quad_output_ys_minor_offset + quad_output_ys_to_rhs_minor)[I0]>{}); - static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2( - [](auto i) { return number{}; }, - number{}); - - static constexpr auto dst_ys_to_rhs_minor = - input_ys_to_rhs_minor.pop_back().push_back(number{}); - - using OutDstrEncode = tile_distribution_encoding, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t>; + using TransposedDstrEncode = + tile_distribution_encoding, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>; }; +template > +using OutputTileDistributionTraits = + TransposeTileDistributionTraits; +template > +using InputTileDistributionTraits = + TransposeTileDistributionTraits; + template & tile_window) { - using OutTileDstrEncode = - typename OutputTileDistributionTraits::OutDstrEncode; + using OutTileDstrEncode = typename OutputTileDistributionTraits< + typename TileDistribution_::DstrEncode, + typename BottomTensorView_::DataType>::TransposedDstrEncode; auto out_tensor = make_static_distributed_tensor( make_static_tile_distribution(OutTileDstrEncode{})); auto trans_tensor = tile_window.template load_transpose(); diff --git a/include/ck_tile/core/utility/debug.hpp b/include/ck_tile/core/utility/debug.hpp new file mode 100644 index 0000000000..261bf50148 --- /dev/null +++ b/include/ck_tile/core/utility/debug.hpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include +#include +#include + +#include "ck_tile/core/numeric/integer.hpp" + +namespace ck_tile { +template +[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() +{ +} +template +[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() +{ +} + +template +struct str_literal +{ + static constexpr const char data[] = {Xs..., '\0'}; + static constexpr const size_t size = sizeof...(Xs); + + template + CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal /*rhs*/) const + { + return str_literal{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal sep) + { + if constexpr(N == 0) + return str_literal<>{}; + else if constexpr(N == 1) + return str_literal{}; + else + return duplicate_n(sep) + str_literal{}; + } +}; + +#define make_str_literal(lit_) \ + std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \ + makeTuple(std::make_index_sequence())) + +template +constexpr std::tuple...> + makeTuple(std::index_sequence) noexcept +{ + return {}; +} +constexpr size_t constexpr_strlen(const char* c) +{ + size_t t = 0; + while(*c++) + ++t; + return t; +} + +template +struct static_distributed_tensor; + +template +struct thread_buffer; + +// Usage example: CK_PRINTF{}(tensor); +template , + typename PREFIX = str_literal<>, + typename SUFFIX = str_literal<>> +struct CK_PRINTF; +template +struct CK_PRINTF, + str_literal, + str_literal> +{ + template + CK_TILE_HOST_DEVICE static constexpr auto default_format() + { + if constexpr(std::is_same_v) + return make_str_literal("%8.3f"); + else if constexpr(std::is_same_v) + return make_str_literal("%5d"); + else if constexpr(std::is_same_v) + return make_str_literal("%5u"); + else + return make_str_literal("0x%08x"); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_prefix() + { + constexpr auto fmt_tid = make_str_literal("tid %03d: [%02d] "); + if constexpr(sizeof...(PREFIXChars) == 0) + return fmt_tid; + else + return fmt_tid + make_str_literal(" ") + str_literal{}; + } + CK_TILE_HOST_DEVICE static constexpr auto get_suffix() + { + constexpr auto lf = make_str_literal("\n"); + if constexpr(sizeof...(SUFFIXChars) == 0) + return lf; + else + return str_literal{} + lf; + } + + template + CK_TILE_HOST_DEVICE void impl(const thread_buffer& buf, + std::integer_sequence) const + { + using FMT1 = std::conditional_t()), + str_literal>; + constexpr auto fmt_v = FMT1::template duplicate_n(make_str_literal(" ")); + constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix(); + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" + printf(fmt_wrap_v.data, get_thread_id(), N, type_convert(buf[Is])...); +#pragma clang diagnostic pop + } + + template + CK_TILE_HOST_DEVICE void operator()(const thread_buffer& buf) const + { + using ConvertTo_ = std::conditional_t, T, ConvertTo>; + impl(buf, std::make_integer_sequence{}); + } + + template + CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor& tensor) const + { + return operator()(tensor.get_thread_buffer()); + } +}; + +template , + typename PREFIX = str_literal<>, + typename SUFFIX = str_literal<>> +struct CK_PRINTF_WARP0 : public CK_PRINTF +{ + using base_t = CK_PRINTF; + + template + CK_TILE_HOST_DEVICE void operator()(const T& buf) const + { + if(get_thread_id() < get_warp_size()) + base_t::operator()(buf); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index f21136d2a8..30bea193b7 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -15,9 +15,9 @@ #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" @@ -29,14 +29,14 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index f1e8bcc0a8..b396f03244 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -13,9 +13,9 @@ #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" @@ -44,10 +44,10 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index 8dd1d1ec28..862fa0bbe3 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" namespace ck_tile { @@ -15,6 +16,19 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { +#if defined(__gfx950__) + constexpr bool is_a_load_tr = std::is_same_v, + tensor_layout::gemm::ColumnMajor>; + constexpr bool is_b_load_tr = std::is_same_v, + tensor_layout::gemm::RowMajor>; +#else + constexpr bool is_a_load_tr = false; + constexpr bool is_b_load_tr = false; +#endif + constexpr auto wg_attr_num_access = (is_a_load_tr || is_b_load_tr) + ? WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single; + if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) @@ -40,14 +54,34 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); } #else - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + using WG = WarpGemmMfmaDispatcher; + return make_tuple(WG{}, 4, 1); #endif } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); + using WG = WarpGemmMfmaDispatcher; + return make_tuple(WG{}, 4, 1); } else { diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index d4e23d12dd..e1b0792ecf 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -218,10 +218,16 @@ struct BlockUniversalGemmAsBsCr BLdsTile b_warp_tile_; // C += A * B - template + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as correspoinding " @@ -300,14 +306,23 @@ struct BlockUniversalGemmAsBsCr ALdsTile a_warp_tile_; BLdsTile b_warp_tile_; - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { if constexpr(std::is_same_v) { load_interleaved_pk_type(a_warp_tile_, a_block_window); } + else if constexpr(ALoadTranspose) + { + a_warp_tile_ = load_tile_transpose(a_block_window); + } else { load_tile(a_warp_tile_, a_block_window); @@ -316,6 +331,10 @@ struct BlockUniversalGemmAsBsCr { load_interleaved_pk_type(b_warp_tile_, b_block_window); } + else if constexpr(BLoadTranspose) + { + b_warp_tile_ = load_tile_transpose(b_block_window); + } else { load_tile(b_warp_tile_, b_block_window); @@ -323,10 +342,16 @@ struct BlockUniversalGemmAsBsCr } // C += A * B - template + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, [[maybe_unused]] ASmemBlockWindow& a_block_window, - [[maybe_unused]] BSmemBlockWindow& b_block_window) + [[maybe_unused]] BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as correspoinding " @@ -382,40 +407,73 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; static constexpr auto ALdsTileDistr = - decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + make_static_tile_distribution(MakeABlockDistributionEncode()); static constexpr auto BLdsTileDistr = - decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + make_static_tile_distribution(MakeBBlockDistributionEncode()); using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); ALdsTile a_warp_tile_; - ALdsTile b_warp_tile_; + BLdsTile b_warp_tile_; - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(MakeBBlockDistributionEncode()); + constexpr auto a_lds_load_distr = [&]() { + if constexpr(ALoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeABlockDistributionEncode()), + ADataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeABlockDistributionEncode()); + }(); + constexpr auto b_lds_load_distr = [&]() { + if constexpr(BLoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeBBlockDistributionEncode()), + BDataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeBBlockDistributionEncode()); + }(); + constexpr auto a_lds_shape = []() { + if constexpr(ALoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto b_lds_shape = []() { + if constexpr(BLoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto k_idx_offset = KIdx * KPerInnerLoop; + constexpr auto a_offset = + ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; + constexpr auto b_offset = + BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; auto a_lds_gemm_window = make_tile_window( - a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {0, KIdx * KPerInnerLoop}, - a_lds_load_tile_distr); + a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr); auto b_lds_gemm_window = make_tile_window( - b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {0, KIdx * KPerInnerLoop}, - b_lds_load_tile_distr); + b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); if constexpr(std::is_same_v) { load_interleaved_pk_type(a_warp_tile_, a_block_window); } + else if constexpr(ALoadTranspose) + { + a_warp_tile_ = load_tile_transpose(a_lds_gemm_window); + } else { load_tile(a_warp_tile_, a_lds_gemm_window); @@ -424,6 +482,10 @@ struct BlockUniversalGemmAsBsCr { load_interleaved_pk_type(b_warp_tile_, b_block_window); } + else if constexpr(BLoadTranspose) + { + b_warp_tile_ = load_tile_transpose(b_lds_gemm_window); + } else { load_tile(b_warp_tile_, b_lds_gemm_window); @@ -431,10 +493,16 @@ struct BlockUniversalGemmAsBsCr } // C += A * B - template + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as correspoinding " @@ -442,7 +510,7 @@ struct BlockUniversalGemmAsBsCr // hot loop: static_for<0, KRepeat, 1>{}([&](auto kIter) { - LocalPrefetch(a_block_window, b_block_window); + LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); __builtin_amdgcn_sched_barrier(0); // NOTE: Synchronize threads in a workgroup at the start of each MAC // cluster, but except the first, as we can shorten non-MAC cluster a bit @@ -543,29 +611,45 @@ struct BlockUniversalGemmAsBsCr return c_block_tensor; } - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); } // C += A * B - template + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { - block_gemm_impl_(c_block_tensor, a_block_window, b_block_window); + block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr); } // C = A * B - template + template CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { auto c_block_tensor = MakeCBlockTile(); - block_gemm_impl_(c_block_tensor, a_block_window, b_block_window); + block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr); return c_block_tensor; } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 6861adb153..2bee550b3c 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -20,6 +20,13 @@ struct GemmPipelineAgBgCrImplBase static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; +#if defined(__gfx950__) + static constexpr bool is_a_load_tr = std::is_same_v; + static constexpr bool is_b_load_tr = std::is_same_v; +#else + static constexpr bool is_a_load_tr = false; + static constexpr bool is_b_load_tr = false; +#endif CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } @@ -50,11 +57,15 @@ struct GemmPipelineAgBgCrImplBase store_tile(lds_tile_window, block_tile_tmp); } - template + template CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile, - const SrcTileWindow& lds_tile_window) const + const SrcTileWindow& lds_tile_window, + bool_constant = {}) const { - load_tile(dst_block_tile, lds_tile_window); + if constexpr(LoadTranspose) + dst_block_tile = load_tile_transpose(lds_tile_window); + else + load_tile(dst_block_tile, lds_tile_window); } CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const @@ -96,14 +107,25 @@ struct GemmPipelineAgBgCrImplBase Policy::template MakeADramTileDistribution()); // A LDS tile window for store - auto a_copy_lds_window = make_tile_window( - a_lds_block_view, make_tuple(number{}, number{}), {0, 0}); + auto a_lds_shape = []() { + if constexpr(is_a_load_tr) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}); + auto a_lds_load_tile_distr = []() { + if constexpr(is_a_load_tr) + return make_static_tile_distribution( + typename InputTileDistributionTraits< + typename ALdsLoadTileDistr::DstrEncode, + typename Problem::ADataType>::TransposedDstrEncode{}); + else + return ALdsLoadTileDistr{}; + }(); auto a_lds_gemm_window = - make_tile_window(a_lds_block_view, - make_tuple(number{}, number{}), - {0, 0}, - ALdsLoadTileDistr{}); + make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr); return make_tuple(std::move(a_copy_dram_window), std::move(a_copy_lds_window), @@ -130,14 +152,25 @@ struct GemmPipelineAgBgCrImplBase // TODO: Do we really need those two tile windows??? // They're exactly same... // B LDS tile window for store - auto b_copy_lds_window = make_tile_window( - b_lds_block_view, make_tuple(number{}, number{}), {0, 0}); + auto b_lds_shape = []() { + if constexpr(is_b_load_tr) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); + auto b_lds_load_tile_distr = []() { + if constexpr(is_b_load_tr) + return make_static_tile_distribution( + typename InputTileDistributionTraits< + typename BLdsLoadTileDistr::DstrEncode, + typename Problem::BDataType>::TransposedDstrEncode{}); + else + return BLdsLoadTileDistr{}; + }(); auto b_lds_gemm_window = - make_tile_window(b_lds_block_view, - make_tuple(number{}, number{}), - {0, 0}, - BLdsLoadTileDistr{}); + make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr); return make_tuple(std::move(b_copy_dram_window), std::move(b_copy_lds_window), diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 6d0db060cd..8f54e4eda6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -153,6 +153,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Problem::TailNum; // Base::GetBlockLoopTailNum(Problem::num_loop); static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; using Base::UsePersistentKernel; @@ -467,7 +470,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -478,7 +481,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -494,7 +497,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); __builtin_amdgcn_sched_barrier(0); @@ -506,7 +510,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -517,7 +521,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -536,7 +540,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); @@ -578,7 +583,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); } block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } // __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index 8e6bab21be..ac91c2f58f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -141,6 +141,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off @@ -305,17 +308,23 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); - auto a_copy_lds_window0 = make_tile_window( - a_lds_block0, make_tuple(number{}, number{}), {0, 0}); + constexpr auto a_lds_shape = []() { + if constexpr(is_a_load_tr_v()) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0}); + auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0}); - auto a_copy_lds_window1 = make_tile_window( - a_lds_block1, make_tuple(number{}, number{}), {0, 0}); - - auto b_copy_lds_window0 = make_tile_window( - b_lds_block0, make_tuple(number{}, number{}), {0, 0}); - - auto b_copy_lds_window1 = make_tile_window( - b_lds_block1, make_tuple(number{}, number{}), {0, 0}); + constexpr auto b_lds_shape = []() { + if constexpr(is_b_load_tr_v()) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0}); + auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0}); // Block GEMM auto block_gemm = BlockGemm(); @@ -325,7 +334,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -336,7 +345,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -354,51 +363,53 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 block_sync_lds(); - constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution( - BlockGemm::MakeABlockDistributionEncode())){}; - constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution( - BlockGemm::MakeBBlockDistributionEncode())){}; + constexpr auto ALdsTileDistr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto BLdsTileDistr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + ALdsTile a_block_tile0, a_block_tile1; + BLdsTile b_block_tile0, b_block_tile1; - ALdsTile a_block_tile0; - ALdsTile a_block_tile1; - - BLdsTile b_block_tile0; - BLdsTile b_block_tile1; - + constexpr auto a_lds_input_tile_distr = [&]() { + if constexpr(is_a_load_tr_v()) + return make_static_tile_distribution( + typename InputTileDistributionTraits< + decltype(BlockGemm::MakeABlockDistributionEncode()), + typename Problem::ADataType>::TransposedDstrEncode{}); + else + return ALdsTileDistr; + }(); + constexpr auto b_lds_input_tile_distr = [&]() { + if constexpr(is_b_load_tr_v()) + return make_static_tile_distribution( + typename InputTileDistributionTraits< + decltype(BlockGemm::MakeBBlockDistributionEncode()), + typename Problem::BDataType>::TransposedDstrEncode{}); + else + return BLdsTileDistr; + }(); auto a_lds_ld_window0 = - make_tile_window(a_lds_block0, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr); auto a_lds_ld_window1 = - make_tile_window(a_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr); auto b_lds_ld_window0 = - make_tile_window(b_lds_block0, - make_tuple(number{}, number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr); auto b_lds_ld_window1 = - make_tile_window(b_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr); - static_assert( - !(is_tile_window_linear_v)&&!(is_tile_window_linear_v)&&!( - is_tile_window_linear_v< - decltype(b_lds_ld_window0)>)&&!(is_tile_window_linear_v), - "LDS windows must not be linear"); + static_assert(!is_tile_window_linear_v && + !is_tile_window_linear_v && + !is_tile_window_linear_v && + !is_tile_window_linear_v, + "LDS windows must not be linear"); - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); - Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -409,7 +420,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -433,10 +444,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // ping { block_sync_lds(); - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); - Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -448,7 +459,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 Base::LocalPrefill( a_copy_lds_window0, a_global_load_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -473,10 +484,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // pong { block_sync_lds(); - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); - Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -488,7 +499,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 Base::LocalPrefill( a_copy_lds_window1, a_global_load_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -521,9 +532,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // 3 { block_sync_lds(); - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); - Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); - if constexpr(is_a_col_major) + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -534,7 +545,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -550,8 +561,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // 2 { block_sync_lds(); - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); block_gemm(c_block_tile, a_block_tile1, b_block_tile1); } // 1 @@ -565,8 +576,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // 2 { block_sync_lds(); - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); - Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); block_gemm(c_block_tile, a_block_tile0, b_block_tile0); static_for<0, 8, 1>{}([&](auto i) { ignore = i; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index a42ddd93a0..4e9a70140e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -21,15 +21,27 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { // using AccDataType = float; - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + constexpr bool single_load_tr_length = + (DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType)) == + (WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size()); + constexpr auto wg_attr_num_access = + ((is_a_load_tr || is_b_load_tr)&&!single_load_tr_length) + ? WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single; + using WarpGemm = WarpGemmMfmaDispatcher; + Problem::TransposeC, + false, + false, + wg_attr_num_access>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off @@ -272,10 +275,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto& b_lds_block = ab_lds_blocks.at(I1{}); // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution( - BlockGemm::MakeABlockDistributionEncode())){}; - constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution( - BlockGemm::MakeBBlockDistributionEncode())){}; + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); // A DRAM tile window for load // A LDS tile window for store @@ -332,7 +335,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -343,7 +346,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -373,12 +376,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -394,7 +398,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -427,12 +431,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -445,7 +450,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -461,14 +466,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem }); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); }; if constexpr(TailNum == TailNumber::One) { block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } else if constexpr(TailNum == TailNumber::Two) @@ -558,10 +565,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto& b_lds_block = ab_lds_blocks.at(I1{}); // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution( - BlockGemm::MakeABlockDistributionEncode())){}; - constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution( - BlockGemm::MakeBBlockDistributionEncode())){}; + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); // A DRAM tile window for load // A LDS tile window for store @@ -617,7 +624,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -628,7 +635,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -658,10 +665,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, + a_lds_gemm_window, + b_lds_gemm_window, + is_a_load_tr_v, + is_b_load_tr_v); // no second block_sync_lds because it's interwave - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -677,7 +688,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -709,10 +720,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto HotLoopTail = [&](auto tail_num) { static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, + a_lds_gemm_window, + b_lds_gemm_window, + is_a_load_tr_v, + is_b_load_tr_v); // no second block_sync_lds because it's interwave - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -725,7 +740,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -741,13 +756,21 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem }); block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, + a_lds_gemm_window, + b_lds_gemm_window, + is_a_load_tr_v, + is_b_load_tr_v); }; if constexpr(TailNum == TailNumber::One) { block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, + a_lds_gemm_window, + b_lds_gemm_window, + is_a_load_tr_v, + is_b_load_tr_v); } else if constexpr(TailNum == TailNumber::Two) { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 881467cb94..2335c4eced 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -47,6 +47,8 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; + static constexpr bool Preshuffle = Problem::Preshuffle; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr index_t kLdsAlignmentInBytes = 16; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index c19d42ce25..52bd07c9e2 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -49,6 +49,9 @@ struct GemmPipelineProblemBase static constexpr auto Scheduler = GemmPipelineScheduler::Default; static constexpr index_t VectorLoadSize = Traits::_VectorSize; + // In the base situation, the Preshuffle setting should be false. + static constexpr bool Preshuffle = false; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index d5f2eedf2d..6820e82d09 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -12,6 +12,20 @@ namespace ck_tile { template struct UniversalGemmBasePolicy { +#if defined(__gfx950__) + template + static constexpr bool is_a_load_tr = + std::is_same_v, tensor_layout::gemm::ColumnMajor>; + template + static constexpr bool is_b_load_tr = + std::is_same_v, tensor_layout::gemm::RowMajor>; +#else + template + static constexpr bool is_a_load_tr = false; + template + static constexpr bool is_b_load_tr = false; +#endif + static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; @@ -22,51 +36,65 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { - using ADataType = remove_cvref_t; + using ADataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); - constexpr auto DataTypeSize = sizeof(ADataType); - constexpr auto MLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + if constexpr(is_a_load_tr) + { + // TODO: better lds descriptor for performance + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( // + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return a_lds_block_desc_0; + } + else + { + constexpr index_t KPack = GetSmemPackA(); - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - return a_lds_block_desc; + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } } /** @@ -78,14 +106,24 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - // using BLayout = remove_cvref_t; using BDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; #if 1 - // if constexpr(std::is_same_v) + if constexpr(is_b_load_tr) + { + // TODO: better lds descriptor for performance + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( // + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return b_lds_block_desc_0; + } + else + // else if constexpr(std::is_same_v) { constexpr index_t KPack = GetSmemPackB(); constexpr auto BK0 = number{}; @@ -584,8 +622,18 @@ struct UniversalGemmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + constexpr index_t vector_size = DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); + constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); + constexpr auto wg_attr_num_access = + !(is_a_load_tr || is_b_load_tr) ? WGAttrNumAccessEnum::Single + : vector_size == thread_elements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; + using WarpGemm = WarpGemmMfmaDispatcher; + Problem::UseStructuredSparsity, + wg_attr_num_access>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr index_t Preshuffle = Problem::Preshuffle; + static constexpr bool Preshuffle = Problem::Preshuffle; using Base::UsePersistentKernel; [[nodiscard]] CK_TILE_HOST static const std::string GetName() diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 185abccd3f..ae25bf0711 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -21,22 +21,29 @@ using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl< WarpGemmAtrributeMfma>>; #if defined(__gfx950__) +template using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; - + WarpGemmAtrributeMfma, + AttrNumAccess>>; #else +template using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) +template using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; #else +template using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl>>; #if defined(__gfx950__) +template using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl>>; + WarpGemmAttributeMfmaImplF16F16F32M32N32K16, + AttrNumAccess>>; #else +template using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) +template using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl>>; + WarpGemmAttributeMfmaImplF16F16F32M16N16K32, + AttrNumAccess>>; #else +template using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) @@ -123,22 +138,29 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl< WarpGemmAtrributeMfma>>; #if defined(__gfx950__) +template using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; - + WarpGemmAtrributeMfma, + AttrNumAccess>>; #else +template using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) +template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; #else +template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl>>; #if defined(__gfx950__) +template using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmImpl>>; + WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16, + AttrNumAccess>>; #else +template using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) +template using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl>>; + WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32, + AttrNumAccess>>; #else +template using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) @@ -247,17 +277,25 @@ using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl>>; +template using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; +template using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; +template using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; +template using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl +// Number of groups of consecutive elements to fill in a ABKLane +enum class WGAttrNumAccessEnum +{ + Single = 1, + Double = 2, + Quad = 4, + Invalid = -1 +}; + +template struct WarpGemmAtrributeMfma { - using Impl = remove_cvref_t; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; @@ -31,21 +43,35 @@ struct WarpGemmAtrributeMfma static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - using AWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; - - using BWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + template + static constexpr auto get_warp_dstr_encoding() + { + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -73,12 +99,16 @@ struct WarpGemmAtrributeMfma } }; -template +template struct WarpGemmAtrributeMfmaIterateK { static_assert(kKIter > 0, "wrong!"); - using Impl = remove_cvref_t; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; @@ -104,17 +134,37 @@ struct WarpGemmAtrributeMfmaIterateK { if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } } else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); // each M blocks share the same data return tile_distribution_encoding< sequence, @@ -127,6 +177,8 @@ struct WarpGemmAtrributeMfmaIterateK } else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); // single block to multi-block thread mapping return tile_distribution_encoding< sequence<>, @@ -143,17 +195,38 @@ struct WarpGemmAtrributeMfmaIterateK { if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } } else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); // single block to multi-block thread mapping return tile_distribution_encoding< sequence<>, @@ -166,6 +239,8 @@ struct WarpGemmAtrributeMfmaIterateK } else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); // each N blocks share the same data return tile_distribution_encoding< sequence, @@ -289,10 +364,13 @@ struct WarpGemmAtrributeMfmaIterateK } }; -template +template struct WarpGemmAtrributeMfmaTransposedCDistribution { - using Impl = remove_cvref_t; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); using ADataType = typename Impl::BDataType; using BDataType = typename Impl::ADataType; @@ -312,21 +390,35 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - using AWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; - - using BWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + template + static constexpr auto get_warp_dstr_encoding() + { + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -450,10 +542,13 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB } }; -template +template struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution { - using Impl = remove_cvref_t; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; // swap A and B using ADataType = typename Impl::BDataType; @@ -478,80 +573,14 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() { - if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) - { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) - { - // single block to multi-block thread mapping - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) - { - // each N blocks share the same data - return tile_distribution_encoding< - sequence, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } + return WarpGemmAtrributeMfmaIterateK:: + get_bwarp_dstr_encoding(); } CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() { - if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) - { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) - { - // each M blocks share the same data - return tile_distribution_encoding< - sequence, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) - { - // single block to multi-block thread mapping - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } + return WarpGemmAtrributeMfmaIterateK:: + get_awarp_dstr_encoding(); } CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index b6ada83532..4e5d102e35 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -16,8 +16,9 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false, + WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single> struct WarpGemmMfmaDispatcher; // clang-format off @@ -25,12 +26,20 @@ struct WarpGemmMfmaDispatcher; // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaF16F16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaF16F16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M4N64K16; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M64N4K16; }; @@ -46,12 +55,20 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; }; @@ -80,10 +97,18 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; // int8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity @@ -102,8 +127,9 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false, + WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single> using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher::Type; + UseStructuredSparsity, + AttrNumAccess>::Type; } // namespace ck_tile diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 450a3a538f..7b519760b9 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -333,8 +333,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::HostTensor c_m_n_dev_result( f_host_tensor_descriptor(M, N, stride_C, CLayout{})); - ck_tile::FillUniformDistributionIntegerValue{-5, 5}(a_m_k); - ck_tile::FillUniformDistributionIntegerValue{-5, 5}(b_k_n); + ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11939}(a_m_k); + ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11940}(b_k_n); ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());