diff --git a/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp b/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp index 3d93e3a221..07dd47aa26 100644 --- a/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp +++ b/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp @@ -36,7 +36,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio // clang-format off template -using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< +using DeviceOpInstance_128_32_64_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, @@ -52,6 +52,24 @@ using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMul 1, 1, S<1, 16, 1, 8>, S<8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, F16>; +template +using DeviceOpInstance_256_128_128_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< + ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, + DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 256, + 128, 128, 64, + 8, 4, + 32, 32, + 2, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 4, 0, + 1, 1, + S<1, 32, 1, 8>, S<8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F16>; + template using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, @@ -128,14 +146,32 @@ float run_impl(const GemmBiasAddArgs& args, const StreamConfig& config) return true; }; - auto gemm = - DeviceOpInstance_64_16_16_64{}; - if(!Run(gemm)) + do { + if(args.M <= 512) + { + auto gemm = DeviceOpInstance_128_32_64_64{}; + if(Run(gemm)) + break; + } + else + { + auto gemm = DeviceOpInstance_256_128_128_64{}; + if(Run(gemm)) + break; + } auto gemm_def = DeviceOpInstance_default{}; Run(gemm_def); - } + } while(0); return ave_time; }