diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index b25aec101b..47c47334e7 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -459,7 +459,7 @@ struct PipelineTypeTraits ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; }; -auto create_args() +inline auto create_args() { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3840", "m dimension") diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index c4f100b36b..78f3a9b0b3 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -197,8 +197,8 @@ bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, return pass; } -std::tuple -parse_gemm_size(ck_tile::ArgParser& arg_parser) +std::tuple inline parse_gemm_size( + ck_tile::ArgParser& arg_parser) { ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 3445f063f5..52b2b86574 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -986,6 +986,8 @@ struct MoeSortingKernel p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id; } } + __syncthreads(); + smem_cumdup(num_experts) = smem_cumsum(num_experts); // fill the p_sorted_token_ids/p_sorted_weights diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 63993c5eb6..838fc236d2 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -561,6 +561,7 @@ struct GroupedGemmKernel const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( 0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d); Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d); + block_sync_lds(); block_id = block_id + grid_size; // advance to next block // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR if(block_id >= cum_grid_size) diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index 91f1358321..6130107cfe 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -631,6 +631,7 @@ struct StreamKKernel tile_idx += kargs.tile_partitioner.get_grid()) { BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0); + block_sync_lds(); } // Stream-K section @@ -679,8 +680,8 @@ struct StreamKKernel { hipDeviceProp_t dev_prop; hipDevice_t dev; - hip_check_error(hipGetDevice(&dev)); - hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + ck_tile::hip_check_error(hipGetDevice(&dev)); + ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); int num_cu = dev_prop.multiProcessorCount; return num_cu; @@ -700,7 +701,7 @@ struct StreamKKernel constexpr int min_block_per_cu = 1; const auto kernel = kentry; - hip_check_error( + ck_tile::hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); return max(occupancy, 1); diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 4b28ac3f12..866a4cc693 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -280,7 +280,7 @@ struct UniversalGemmKernel using Kernel = UniversalGemmKernel; const auto kernel = kentry<1, Kernel, KernelArgs>; int occupancy; - hip_check_error( + ck_tile::hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0)); const int grid_size = get_available_compute_units(s) * occupancy; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 16ed8de22f..936c38ddf3 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -9,11 +9,35 @@ namespace ck_tile { +template +struct BaseGemmPipelineAGmemBGmemCRegV1 +{ + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr bool UsePersistentKernel = false; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; } + + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t) + { + return TailNumber::Empty; + } + + template + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber) + { + return run_func(bool_constant{}, integral_constant{}); + } +}; + // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register template -struct GemmPipelineAGmemBGmemCRegV1 +struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 { using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -48,14 +72,14 @@ struct GemmPipelineAGmemBGmemCRegV1 template static constexpr index_t GetVectorSizeA() { - return Problem::VectorSizeA; + return Policy::template GetVectorSizeA(); } template static constexpr index_t GetVectorSizeB() { - return Problem::VectorSizeB; + return Policy::template GetVectorSizeB(); } - static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index 5dbcde80a6..c711c768ec 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -9,11 +9,34 @@ namespace ck_tile { +template +struct BaseGemmPipelineAGmemBGmemCRegV2 +{ + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; } + + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t) + { + return TailNumber::Empty; + } + + template + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber) + { + return run_func(bool_constant{}, integral_constant{}); + } +}; // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register template -struct GemmPipelineAGmemBGmemCRegV2 +struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2 { using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index d76fd6dc0f..47607a40f5 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -43,13 +43,14 @@ template + bool Preshuffle_ = false, + int VectorSize_ = 16> struct TileGemmUniversalTraits { static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; static constexpr bool kPadK = kPadK_; - static constexpr int _VectorSize = 16; + static constexpr int _VectorSize = VectorSize_; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; using AsLayout = AsLayout_; diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index cbf4afefb2..ba6ed27651 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -425,7 +425,7 @@ struct BlockReduce2dCrossWarpSync if constexpr(num_reduce_warps == 1) return; - + block_sync_lds(); // Each warp's lane 0 writes its partial results to shared memory const index_t smem_offset = warp_id; if(lane_id == 0) diff --git a/include/ck_tile/utility/json_dump.hpp b/include/ck_tile/utility/json_dump.hpp index b5bab28cac..03e97c0b76 100644 --- a/include/ck_tile/utility/json_dump.hpp +++ b/include/ck_tile/utility/json_dump.hpp @@ -160,23 +160,23 @@ void dump_gemm_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_batched_gemm_json_results(const std::string& json_filename, - const std::string& op_name, - int M, - int N, - int K, - int stride_A, - int stride_B, - int stride_C, - int batch_stride_A, - int batch_stride_B, - int batch_stride_C, - int batch_count, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "batched_gemm_basic") +inline void dump_batched_gemm_json_results(const std::string& json_filename, + const std::string& op_name, + int M, + int N, + int K, + int stride_A, + int stride_B, + int stride_C, + int batch_stride_A, + int batch_stride_B, + int batch_stride_C, + int batch_count, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "batched_gemm_basic") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -218,20 +218,20 @@ void dump_grouped_gemm_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_flatmm_json_results(const std::string& json_filename, - const std::string& datatype, - int M, - int N, - int K, - int stride_A, - int stride_B, - int stride_C, - int kbatch, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "flatmm_basic") +inline void dump_flatmm_json_results(const std::string& json_filename, + const std::string& datatype, + int M, + int N, + int K, + int stride_A, + int stride_B, + int stride_C, + int kbatch, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "flatmm_basic") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -248,21 +248,22 @@ void dump_flatmm_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename, - const std::string& op_name, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideD0, - int StrideD1, - int StrideE, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "gemm_multi_d_fp16") +inline void +dump_gemm_multi_d_fp16_json_results(const std::string& json_filename, + const std::string& op_name, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideE, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "gemm_multi_d_fp16") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -280,14 +281,14 @@ void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_elementwise_json_results(const std::string& json_filename, - const std::string& prec, - int grid_size, - int block_size, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "elementwise") +inline void dump_elementwise_json_results(const std::string& json_filename, + const std::string& prec, + int grid_size, + int block_size, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "elementwise") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -298,22 +299,22 @@ void dump_elementwise_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_layernorm2d_fwd_json_results(const std::string& json_filename, - const std::string& prec_i, - const std::string& prec_o, - const std::string& prec_sm, - const std::string& prec_sy, - int m, - int n, - int x_stride, - int xr_stride, - int y_stride, - int yr_stride, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "layernorm2d_fwd") +inline void dump_layernorm2d_fwd_json_results(const std::string& json_filename, + const std::string& prec_i, + const std::string& prec_o, + const std::string& prec_sm, + const std::string& prec_sy, + int m, + int n, + int x_stride, + int xr_stride, + int y_stride, + int yr_stride, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "layernorm2d_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -357,13 +358,13 @@ void dump_reduce_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_permute_json_results(const std::string& json_filename, - const std::string& data_type, - bool pass, - float ave_time, - float tflop, - float gb_per_sec, - const std::string& kernel_name = "permute") +inline void dump_permute_json_results(const std::string& json_filename, + const std::string& data_type, + bool pass, + float ave_time, + float tflop, + float gb_per_sec, + const std::string& kernel_name = "permute") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -373,19 +374,19 @@ void dump_permute_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_topk_softmax_json(const std::string& json_filename, - const std::string& input_prec, - const std::string& weight_prec, - int tokens, - int experts, - int topk, - int stride_input, - int stride_output, - float ave_time, - float tflop, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "topk_softmax") +inline void dump_topk_softmax_json(const std::string& json_filename, + const std::string& input_prec, + const std::string& weight_prec, + int tokens, + int experts, + int topk, + int stride_input, + int stride_output, + float ave_time, + float tflop, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "topk_softmax") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -401,20 +402,20 @@ void dump_topk_softmax_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_rmsnorm2d_fwd_json(const std::string& json_filename, - const std::string& prec_str, - int m, - int n, - int x_stride, - int xr_stride, - int y_stride, - int yr_stride, - int use_model_sensitive_rmsnorm, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "rmsnorm2d_fwd") +inline void dump_rmsnorm2d_fwd_json(const std::string& json_filename, + const std::string& prec_str, + int m, + int n, + int x_stride, + int xr_stride, + int y_stride, + int yr_stride, + int use_model_sensitive_rmsnorm, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "rmsnorm2d_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -431,19 +432,19 @@ void dump_rmsnorm2d_fwd_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_add_rmsnorm2d_rdquant_fwd_json( - const std::string& json_filename, - const std::string& input_data_type, - const std::string& quantized_data_type, - int m, - int n, - int stride, - float epsilon, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd") +inline void +dump_add_rmsnorm2d_rdquant_fwd_json(const std::string& json_filename, + const std::string& input_data_type, + const std::string& quantized_data_type, + int m, + int n, + int stride, + float epsilon, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -458,17 +459,17 @@ void dump_add_rmsnorm2d_rdquant_fwd_json( END_JSON_DUMP_FILE(); } -void dump_smoothquant_json(const std::string& json_filename, - const std::string& prec_str, - int m, - int n, - int x_stride, - int y_stride, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "smoothquant") +inline void dump_smoothquant_json(const std::string& json_filename, + const std::string& prec_str, + int m, + int n, + int x_stride, + int y_stride, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "smoothquant") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -482,19 +483,19 @@ void dump_smoothquant_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_moe_sorting_json(const std::string& json_filename, - const std::string& index_prec, - const std::string& weight_prec, - const std::string& workspace_size, - int dispatch_policy, - int tokens, - int num_experts, - int topk, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "moe_sorting") +inline void dump_moe_sorting_json(const std::string& json_filename, + const std::string& index_prec, + const std::string& weight_prec, + const std::string& workspace_size, + int dispatch_policy, + int tokens, + int num_experts, + int topk, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "moe_sorting") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -510,19 +511,19 @@ void dump_moe_sorting_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_batched_transpose_json(const std::string& json_filename, - int N, - int C, - int H, - int W, - const std::string& layout_in, - const std::string& layout_out, - const std::string& prec, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "batched_transpose") +inline void dump_batched_transpose_json(const std::string& json_filename, + int N, + int C, + int H, + int W, + const std::string& layout_in, + const std::string& layout_out, + const std::string& prec, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "batched_transpose") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -538,19 +539,19 @@ void dump_batched_transpose_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_moe_smoothquant_json(const std::string& json_filename, - const std::string& prec_i, - const std::string& prec_o, - int tokens, - int hidden_size, - int stride, - int experts, - int topk, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "moe_smoothquant") +inline void dump_moe_smoothquant_json(const std::string& json_filename, + const std::string& prec_i, + const std::string& prec_o, + int tokens, + int hidden_size, + int stride, + int experts, + int topk, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "moe_smoothquant") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -566,26 +567,26 @@ void dump_moe_smoothquant_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_fused_moe_json(const std::string& json_filename, - const std::string& api_str, - const std::string& prec_str, - int tokens, - bool is_local_token, - int local_tokens, - int experts, - int topk, - int hidden_size, - int intermediate_size, - int stride, - int block_m, - int activation, - bool gate_only, - bool fused_quant, - bool pass, - float ave_time, - float tflops, - float tb_per_sec, - const std::string& kernel_name = "fused_moe") +inline void dump_fused_moe_json(const std::string& json_filename, + const std::string& api_str, + const std::string& prec_str, + int tokens, + bool is_local_token, + int local_tokens, + int experts, + int topk, + int hidden_size, + int intermediate_size, + int stride, + int block_m, + int activation, + bool gate_only, + bool fused_quant, + bool pass, + float ave_time, + float tflops, + float tb_per_sec, + const std::string& kernel_name = "fused_moe") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -610,29 +611,29 @@ void dump_fused_moe_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_fmha_fwd_json_results(const std::string& json_filename, - const std::string& prec, - const std::string& mode, - const std::string& io_layout, - int batch, - int nhead, - int nhead_k, - int seqlen_qs, - int seqlen_ks, - int seqlen_kpads, - int hdim_q, - int hdim_v, - float scale_s, - float p_drop, - bool lse, - const std::string& qscale, - const std::string& bias, - const std::string& vlayout, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "fmha_fwd") +inline void dump_fmha_fwd_json_results(const std::string& json_filename, + const std::string& prec, + const std::string& mode, + const std::string& io_layout, + int batch, + int nhead, + int nhead_k, + int seqlen_qs, + int seqlen_ks, + int seqlen_kpads, + int hdim_q, + int hdim_v, + float scale_s, + float p_drop, + bool lse, + const std::string& qscale, + const std::string& bias, + const std::string& vlayout, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "fmha_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -658,33 +659,33 @@ void dump_fmha_fwd_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_fmha_bwd_json_results(const std::string& json_filename, - const std::string& data_type, - const std::string& mode, - const std::string& i_perm, - const std::string& o_perm, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - float scale, - const std::string& bias, - bool use_dbias, - float p_drop, - bool s_randval, - bool deterministic, - const std::string& mask, - int mask_left, - int mask_right, - int workspace_size, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "fmha_bwd") +inline void dump_fmha_bwd_json_results(const std::string& json_filename, + const std::string& data_type, + const std::string& mode, + const std::string& i_perm, + const std::string& o_perm, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + const std::string& bias, + bool use_dbias, + float p_drop, + bool s_randval, + bool deterministic, + const std::string& mask, + int mask_left, + int mask_right, + int workspace_size, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "fmha_bwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp index 4fdbf23864..9b90110c07 100644 --- a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp @@ -130,7 +130,7 @@ auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None) constexpr index_t kMPerBlock = Problem::kMPerBlock; constexpr index_t kNPerBlock = Problem::kNPerBlock; - constexpr index_t kBlockSize = Problem::kBlockSize; + index_t kBlockSize = ck_tile::is_wave32() ? Problem::kBlockSize / 2 : Problem::kBlockSize; std::cout << "Running CShuffleEpilogue test with M=" << M << ", N=" << N << ", MPerBlock=" << kMPerBlock << ", NPerBlock=" << kNPerBlock diff --git a/test/ck_tile/gemm_multi_abd/CMakeLists.txt b/test/ck_tile/gemm_multi_abd/CMakeLists.txt index 2dccf9cd60..03759652cd 100644 --- a/test/ck_tile/gemm_multi_abd/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_abd/CMakeLists.txt @@ -7,7 +7,7 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") add_gtest_executable(test_ck_tile_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp) add_gtest_executable(test_ck_tile_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp) target_compile_definitions(test_ck_tile_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp index 08997529b2..ab00f16632 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp @@ -20,20 +20,21 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using KernelTypes = ::testing::Types< // Has cshuffle epilogue enabled // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, +#endif std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> - >; + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type> + >; // clang-format on TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes); diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp index dac33b4656..c4bfc3e7cb 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp @@ -20,17 +20,19 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using KernelTypes = ::testing::Types< // Has cshuffle epilogue disabled // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, +#endif std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp index ee045c7f48..8cee050db2 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -23,6 +23,28 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +#endif +} + template & args, const ck_tile::stream_config& s) { - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; +#if CK_TILE_USE_WMMA + using ADataType = + ck_tile::remove_cvref_t{}, AsDataType>>; + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); +#else constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; +#endif constexpr bool DoubleSmemBuffer = false; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 43a73738d9..7c085b5098 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -13,6 +13,28 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +#endif +} + template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, @@ -80,7 +102,7 @@ struct config_wmma static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; template diff --git a/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt index f86da3c4d5..5363e365fc 100644 --- a/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt @@ -9,7 +9,7 @@ endif() # Use standard asm for rtn bf16 conversion instead of turncate list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) -if(GPU_TARGETS MATCHES "gfx94|gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm_multi_d test_grouped_gemm_multi_d.cpp) target_compile_options(test_ck_tile_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp index 65c662199b..8d56c274aa 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp @@ -29,9 +29,6 @@ template ; - static constexpr int M_Tile_ = M_Tile_val_; - static constexpr int N_Tile_ = N_Tile_val_; - static constexpr int K_Tile_ = K_Tile_val_; - static constexpr int M_Warp_ = M_Warp_val_; - static constexpr int N_Warp_ = N_Warp_val_; - static constexpr int K_Warp_ = K_Warp_val_; - static constexpr int M_Warp_Tile_ = M_Warp_Tile_val_; - static constexpr int N_Warp_Tile_ = N_Warp_Tile_val_; - static constexpr int K_Warp_Tile_ = K_Warp_Tile_val_; + static constexpr int M_Tile_ = M_Tile_val_; + static constexpr int N_Tile_ = N_Tile_val_; + static constexpr int K_Tile_ = K_Tile_val_; + static constexpr int M_Warp_ = M_Warp_val_; + static constexpr int N_Warp_ = N_Warp_val_; + static constexpr int K_Warp_ = K_Warp_val_; +#if CK_TILE_USE_WMMA + static constexpr int M_Warp_Tile_ = 16; + static constexpr int N_Warp_Tile_ = 16; + static constexpr int K_Warp_Tile_ = 16; +#else + static constexpr int M_Warp_Tile_ = 32; + static constexpr int N_Warp_Tile_ = 32; + static constexpr int K_Warp_Tile_ = (M_Warp_val_ == 2) ? 16 : 8; +#endif static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_; static constexpr auto Scheduler_ = Scheduler_val_; static constexpr PipelineType Pipeline_ = Pipeline_val_; @@ -68,21 +71,21 @@ struct KernelConfig // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent + // ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, DoubleSmemBuffer, Scheduler, Pipeline, Persistent // FP16 A/B/D/E - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4 // BF16 A/B/D/E - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4 >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt index 08b413aea9..3a230aed0c 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt @@ -6,7 +6,7 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx94|gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm_preshuffle test_grouped_gemm_preshuffle.cpp) target_compile_options(test_ck_tile_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp index 623d0152d6..450b7b8f24 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp @@ -50,16 +50,16 @@ struct KernelConfig // clang-format off using KernelTypes = ::testing::Types< // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent ,M_Tile, N_Tile, K_Tile, BlockPerCu - KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>, +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 16, 64, 256, 1>, - KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>, KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 128, 128, 128, 2>, - - KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>, KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>, - KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>, KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>, - +#endif + KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>, KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>, KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>, KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 128, 128, 128, 2>, diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index 0eb388082b..5628b6feae 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -14,6 +14,9 @@ template constexpr ck_tile::index_t get_k_warp_tile_flatmm() { +#if CK_TILE_USE_WMMA + return 16; +#else #if defined(CK_GFX950_SUPPORT) if constexpr(M_Warp_Tile == 32) return sizeof(PrecType) == 2 ? 16 : 64; @@ -25,6 +28,7 @@ constexpr ck_tile::index_t get_k_warp_tile_flatmm() else return sizeof(PrecType) == 2 ? 32 : 64; #endif +#endif } template @@ -101,13 +105,40 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view( - {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / N_Warp_Tile, + N_Warp_Tile, + k_ / K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view( + {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template @@ -115,6 +146,11 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test const ck_tile::stream_config& s, void* kargs_ptr) { + constexpr ck_tile::index_t WaveSize = 32; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile); + constexpr bool SupportVectorSize16 = + (M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0; + constexpr int VectorSize = SupportVectorSize16 ? 16 : 8; using GemmShape = ck_tile::TileGemmShape, @@ -137,7 +173,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test /*UseStructuredSparsity*/ false, /*Persistent*/ false, /*NumWaveGroups*/ 1, - /*Preshuffle*/ true>; + /*Preshuffle*/ true, + VectorSize>; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem, ck_tile::sequence, @@ -230,7 +273,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test /*UseStructuredSparsity*/ false, /*Persistent*/ true, // Enable persistent mode /*NumWaveGroups*/ 1, - /*Preshuffle*/ true>; + /*Preshuffle*/ true, + VectorSize>; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem