diff --git a/CMakeLists.txt b/CMakeLists.txt index e90f893de0..3558666e5d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -246,13 +246,6 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000) add_compile_options("SHELL: -mllvm --lsr-drop-solution=1") endif() endif() -if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090) - check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED) - if(HAS_ENABLE_POST_MISCHED) - message("Adding the enable-post-misched=0 compiler flag") - add_compile_options("SHELL: -mllvm -enable-post-misched=0") - endif() -endif() set(check-coerce) check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce) if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp index 9b7849a654..b54ba5ddfb 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp @@ -55,7 +55,7 @@ using CDEElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr ck::index_t Scale_Block_M = 128; +static constexpr ck::index_t Scale_Block_M = 1; static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; @@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_ A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, - 128, 128, - 128, 16, 16, + 16, 128, + 256, 16, 16, 16, 16, - 4, 4, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; + 1, 2, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 2, S<1, 16, 1, 16>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; // clang-format on int main(int argc, char* argv[]) @@ -80,11 +80,12 @@ int main(int argc, char* argv[]) bool do_verification = true; int init_method = 1; bool time_kernel = false; + bool flush_cache = true; // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; ck::index_t StrideA = K; ck::index_t StrideB = K; @@ -100,7 +101,7 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); } - else if(argc == 10) + else if(argc == 8) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -110,16 +111,19 @@ int main(int argc, char* argv[]) N = std::stoi(argv[5]); K = std::stoi(argv[6]); - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideE = std::stoi(argv[9]); + flush_cache = std::stoi(argv[7]); + + StrideA = K; + StrideB = K; + StrideE = N; } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: 
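Note on the Scale_Block_M change above: with Scale_Block_M = 1 the A-scale tensor carries one scale per output row per 128-wide K block (per-row, per-K-block scaling) instead of one per 128x128 tile. A host-side sketch of the scale-tensor extents this implies, using the example's new M/N/K defaults (the ceil_div helper is illustrative, not a library function):

#include <cstdio>

// Ceiling division, as used by the example's debug loop
// ((M + Scale_Block_M - 1) / Scale_Block_M).
static int ceil_div(int x, int y) { return (x + y - 1) / y; }

int main()
{
    // Values from this example after the change.
    const int M = 128, N = 1024, K = 1024;
    const int Scale_Block_M = 1, Scale_Block_N = 128, Scale_Block_K = 128;

    // A scales: one per (row block, K block); with Scale_Block_M == 1
    // that is one scale per output row per K block.
    const int a_scale_rows = ceil_div(M, Scale_Block_M); // 128
    const int a_scale_cols = ceil_div(K, Scale_Block_K); // 8

    // B scales: one per (column block, K block).
    const int b_scale_rows = ceil_div(N, Scale_Block_N); // 8
    const int b_scale_cols = ceil_div(K, Scale_Block_K); // 8

    std::printf("A scales: %d x %d, B scales: %d x %d\n",
                a_scale_rows, a_scale_cols, b_scale_rows, b_scale_cols);
}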
flush both I$ and L2$ (0=no, 1=yes)\n"); exit(0); } @@ -182,9 +186,15 @@ int main(int argc, char* argv[]) b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 4: - a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); break; default: @@ -194,6 +204,16 @@ int main(int argc, char* argv[]) b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); } #endif +#if 0 + for(int im =0; im< (M + Scale_Block_M - 1) / Scale_Block_M; im++){ + float row_sum = .0; + for(int ik =0; ik< (K + Scale_Block_K - 1) / Scale_Block_K; ik++){ + printf("%lf ",a1_m_k(im, ik)); + row_sum += a1_m_k(im, ik); + } + printf("sum: %lf\n", row_sum * 128); + } +#endif DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); @@ -239,12 +259,24 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); - std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + float ave_time = .0; + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp index 821bbb0051..8375e81fa0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp @@ -7,10 +7,10 @@ namespace ck { -// Naive pipeline with lowest resource request per WGP -// GlobalPrefetchStages: 1 +// Compute optimized pipeline +// GlobalPrefetchStages: 2 // LocalPreFillStages: 1 -// LocalPreFetchStages: 0 +// LocalPreFetchStages: 1 // LocalSharedMemoryBuffer: 1 template + KPack, + true> { using Base = BlockwiseGemmXdlops_pipeline_base; + KPack, + true>; + using Base::A_K1; + using Base::B_K1; using Base::I0; + using Base::I1; using Base::KRepeat; using Base::xdlops_gemm; + using typename Base::HotLoopInstList; using Base::CalculateCThreadOriginDataIndex; using Base::CalculateCThreadOriginDataIndex8D; @@ -131,19 +137,43 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale PrefetchStages; @@ -151,11 +181,116 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? 
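The rotating-buffer count in the example above follows directly from the working-set size: enough disjoint copies of the A/B/E buffers are cycled through that roughly 512 MiB is touched before any buffer repeats, so each timed launch reads data L2 has already evicted. A standalone sketch of the same computation (the 512 MiB target and the byte count mirror the example's formula; helper names are illustrative):

#include <cstddef>
#include <cstdio>

// How many disjoint copies of the A/B/E working set are needed so that
// roughly 512 MiB is touched before the same copy is reused.
static std::size_t rotating_copies(std::size_t bytes_per_launch)
{
    const std::size_t target = 512ull * 1024 * 1024; // assumed L2-flushing footprint
    return (target + bytes_per_launch - 1) / bytes_per_launch;
}

int main()
{
    // num_btype for M=128, N=1024, K=1024 with 1-byte A/B and 2-byte E
    // (fp8 in, bf16 out), mirroring the example's formula.
    const std::size_t num_btype =
        1ull * 128 * 1024 + 1ull * 1024 * 1024 + 2ull * 128 * 1024;
    std::printf("rotating copies: %zu\n", rotating_copies(num_btype));
}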
sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); } template ( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -223,6 +359,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale( b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); // Global prefetch 1 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); @@ -231,11 +369,26 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -243,17 +396,101 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}); + constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + constexpr 
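The scheduler-hint math above partitions one hot-loop iteration into issue groups: each buffer_load is paired with its share of ds_writes and MFMAs (stage 1), and the MFMAs reserved up front are interleaved with ds_reads (stage 2). A host-side model of the stage-1 partition, with instruction counts chosen only for illustration (in the pipeline they come from HotLoopInstList at compile time):

#include <cassert>
#include <cstdio>

int main()
{
    // Illustrative counts standing in for HotLoopInstList values.
    const int num_mfma_inst          = 32;
    const int num_buffer_load_inst_a = 4;
    const int num_buffer_load_inst_b = 4;
    const int num_ds_write_inst_a    = 4;
    const int num_dsread_a_mfma      = 4; // ds_read_a interleave slots
    const int num_dsread_b_mfma      = 4; // ds_read_b interleave slots

    // Stage 1: MFMAs left after reserving one MFMA per ds_read slot.
    const int num_mfma_stage1 =
        num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
    const int num_mfma_per_issue =
        num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
    const int num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;

    // One stage-1 issue group for A, matching the sched_group_barrier
    // sequence (0x200 DS write, 0x008 MFMA, 0x020 VMEM read):
    //   num_dswrite_per_issue_a x (ds_write, mfma), one buffer_load,
    //   then (num_mfma_per_issue - num_dswrite_per_issue_a) mfma.
    std::printf("per A load: %d ds_write/mfma pairs + 1 vmem + %d mfma\n",
                num_dswrite_per_issue_a,
                num_mfma_per_issue - num_dswrite_per_issue_a);
    assert(num_mfma_per_issue >= num_dswrite_per_issue_a);
}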
auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + // Initialize C c_thread_buf.Clear(); - auto c_thread_buf_per_scale = remove_cvref_t(); + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); // main body if constexpr(HasMainLoop) @@ -261,13 +498,85 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + 
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = + CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -289,19 +598,70 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + i += 1; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { vector_type a_thread_vec; vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; + make_tuple(m0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; + make_tuple(n0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; }); using mfma_input_type = @@ -311,46 +671,41 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale( a_thread_vec.template AsType(), b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = 
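Functionally, the per-scale-block accumulation above reduces to: accumulate raw products within one K scale block into a temporary, then fold the temporary into C with the precomputed c_scale = a_scale * b_scale. A scalar reference of that contract (toy sizes, plain float standing in for the fp8/f32 mixed types):

#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 2, K = 8, ScaleBlockK = 4; // illustrative sizes
    const int KBlocks = K / ScaleBlockK;

    std::vector<float> a(M * K, 1.0f), b(N * K, 2.0f);
    std::vector<float> a_scale(M * KBlocks, 0.5f), b_scale(N * KBlocks, 0.25f);
    std::vector<float> c(M * N, 0.0f);

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int kb = 0; kb < KBlocks; ++kb)
            {
                float partial = 0.0f; // c_thread_buf_per_scale analogue
                for(int k = kb * ScaleBlockK; k < (kb + 1) * ScaleBlockK; ++k)
                    partial += a[m * K + k] * b[n * K + k];
                // c_scale = a_scale * b_scale, computed once per block
                c[m * N + n] += partial * a_scale[m * KBlocks + kb] *
                                b_scale[n * KBlocks + kb];
            }

    // two K blocks, each contributing 4 * (1*2) * (0.5*0.25) = 1
    std::printf("c[0][0] = %f\n", c[0]);
}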
CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * - type_convert(b_scale_thread_buf[I0]); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); }); }); }); + }); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(I0, I0), - a_scale_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(I0, I0), - b_scale_thread_buf); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step); - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); - - block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - i += 1; - - } while(i < (num_loop - 1)); - } - - // tail - if constexpr(TailNum == TailNumber::Full) - { block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -371,49 +726,143 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - using mfma_input_type = - typename vector_type::type; + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * - type_convert(b_scale_thread_buf[I0]); + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + 
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); + }); }); }); }); + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); + }); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); } } protected: - using Base::a_thread_copy_; using Base::a_thread_desc_; - using Base::b_thread_copy_; using Base::b_thread_desc_; using Base::c_thread_desc_; + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp index 40fa776484..c8ad9c5b02 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale + KPack, + true> { using Base = BlockwiseGemmXdlops_pipeline_base; + KPack, + true>; using Base::I0; using Base::KRepeat; using Base::xdlops_gemm; @@ -270,11 +272,26 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + 
a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -282,7 +299,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}) += c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * + type_convert(a_scale_thread_buf[m0]) * type_convert(b_scale_thread_buf[I0]); }); }); }); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(I0, I0), - a_scale_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -378,8 +409,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}) += c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * + type_convert(a_scale_thread_buf[m0]) * type_convert(b_scale_thread_buf[I0]); }); }); }); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(I0, I0), - a_scale_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -471,7 +515,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}) += c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * + type_convert(a_scale_thread_buf[m0]) * type_convert(b_scale_thread_buf[I0]); }); }); @@ -586,7 +629,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}) += c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * + type_convert(a_scale_thread_buf[m0]) * type_convert(b_scale_thread_buf[I0]); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp index de542866a6..fc0075b196 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale + KPack, + true> { using Base = BlockwiseGemmXdlops_pipeline_base; + KPack, + true>; using Base::I0; using Base::KRepeat; using Base::xdlops_gemm; @@ -177,11 +179,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}) == 1, + "Pipeline v3 only support scaleblocksliceK=1"); + static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1, + "Pipeline v3 only support scaleblocksliceN=1"); // assume kperblock = scaleblockk - ignore = num_loop_per_scale; auto a_thread_buf = make_static_buffer( 
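The MRepeat-strided A-scale loads above walk the scale window with three step vectors (defined in the gridwise header later in this diff): +MWaves*MPerXdl rows after each m0 repeat, a rewind of -MPerBlock rows once per K block when the scale block is not yet exhausted, and the same rewind plus a K advance when it is. A host model of that walk (values are illustrative, chosen so MRepeat * MWaves * MPerXdl == MPerBlock):

#include <cstdio>

int main()
{
    const int MRepeat = 2, MWaves = 2, MPerXdl = 16, MPerBlock = 64;
    const int SliceK = 1, NumKBlockPerScale = 1;

    int row = 0, col = 0;
    for(int iter = 0; iter < 3; ++iter) // three K-block iterations
    {
        for(int m0 = 0; m0 < MRepeat; ++m0)
        {
            std::printf("iter %d: load A scale at (row %d, k %d)\n", iter, row, col);
            row += MWaves * MPerXdl; // step 0: next MRepeat row group
        }
        row -= MPerBlock;            // steps 1/2: rewind the M window
        if(NumKBlockPerScale == 1)
            col += SliceK;           // step 2 only: advance the K scale block
    }
}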
a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -330,6 +337,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale( b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); // Global prefetch 1 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); @@ -338,11 +347,26 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -350,8 +374,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; + }); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -363,10 +391,44 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + // Initialize C c_thread_buf.Clear(); - auto c_thread_buf_per_scale = remove_cvref_t(); + StaticBufferTupleOfVector + c_thread_buf_per_scale; // Local prefetch 1 block_sync_lds(); @@ -409,7 +471,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); static_for<0, KRepeat, 1>{}([&](auto k0) { vector_type a_thread_vec; vector_type b_thread_vec; @@ -430,19 +495,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale( a_thread_vec.template AsType(), b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * - type_convert(b_scale_thread_buf[I0]); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[m0]); }); }); }); + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = 
a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; + }); + block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -462,11 +531,27 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -474,7 +559,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); static_for<0, KRepeat, 1>{}([&](auto k0) { vector_type a_thread_vec; vector_type b_thread_vec; @@ -507,15 +594,15 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale( a_thread_vec.template AsType(), b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * - type_convert(b_scale_thread_buf[I0]); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[m0]); }); }); }); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp index 480402b7e1..d5fec7201a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" namespace ck { namespace tensor_operation { @@ -177,14 +178,57 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); const auto Run = [&](const auto& kernel) { - if(arg.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); + if(stream_config.flush_cache) + { + Argument arg_ = arg; - ave_time = launch_and_time_kernel( - stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + 
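launch_and_time_kernel_with_preprocess, used in the device op below, separates per-iteration preparation (icache flush, rotating to the next buffer copy, clearing C for split-K) from the timed region. The shape of that contract, reduced to a host-only stand-in with plain callables (std::chrono here stands in for the event-based GPU timing the real launcher uses; all names are illustrative):

#include <chrono>
#include <cstdio>
#include <functional>

// Run `preprocess` before every timed invocation of `kernel`; only the
// kernel time contributes to the reported average.
static float time_with_preprocess(const std::function<void()>& preprocess,
                                  const std::function<void()>& kernel,
                                  int nrepeat)
{
    double total_ms = 0.0;
    for(int i = 0; i < nrepeat; ++i)
    {
        preprocess(); // not timed: flush/rotate/clear analogue
        const auto t0 = std::chrono::steady_clock::now();
        kernel();
        const auto t1 = std::chrono::steady_clock::now();
        total_ms += std::chrono::duration<double, std::milli>(t1 - t0).count();
    }
    return static_cast<float>(total_ms / nrepeat);
}

int main()
{
    volatile long sink = 0;
    const float ave = time_with_preprocess(
        [&] { sink = 0; },                                      // "preprocess"
        [&] { for(long i = 0; i < 1000000; ++i) sink = sink + i; }, // "kernel"
        10);
    std::printf("average: %f ms\n", ave);
}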
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } }; constexpr index_t minimum_occupancy = @@ -195,7 +239,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 if(has_main_k_block_loop) { - // Tail number always 1 + // Tail number always full if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { @@ -208,127 +252,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 Run(kernel); } } - // Tail number could be One to Seven - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) - { - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Three) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Four) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Five) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Seven) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - } - } } else { // Tail number always 1 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) { const auto kernel = kernel_gemm_xdl_cshuffle_v3; Run(kernel); } + else 
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } } } return ave_time; @@ -363,10 +303,11 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 return false; } - if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock) - { - return false; - } + // if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != + // KPerBlock) + // { + // return false; + // } if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 813acfa656..25be9bebb7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -225,7 +225,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); } - __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) { const auto a_grid_desc_mraw_kraw = [&]() { @@ -307,7 +307,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 } } - __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) { const auto b_grid_desc_nraw_kraw = [&]() { @@ -422,6 +422,13 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 } }(); + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::MNPadding || @@ -459,6 +466,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 // not pad M or N return c_grid_desc_mraw_nraw; } +#endif } __host__ __device__ static auto MakeDsGridDescriptor_M_N( @@ -656,40 +664,19 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 // in some cases. else if constexpr(is_same::value) { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 - ? 
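The C grid descriptor below now right-pads M and N unconditionally instead of branching on GemmSpecialization. A sketch of the padded extents (pad_to is an illustrative helper; tile sizes are taken from this example's config):

#include <cstdio>

// Round x up to the next multiple of per_block.
static int pad_to(int x, int per_block)
{
    return (x + per_block - 1) / per_block * per_block;
}

int main()
{
    const int M = 100, N = 1000;               // raw problem sizes (illustrative)
    const int MPerBlock = 16, NPerBlock = 128; // tile sizes from this config
    const int MPad = pad_to(M, MPerBlock);
    const int NPad = pad_to(N, NPerBlock);
    // make_right_pad_transform(M, MPad - M) keeps indices [0, M) untouched
    // and appends MPad - M padding rows that are masked on store.
    std::printf("MPad=%d (+%d), NPad=%d (+%d)\n", MPad, MPad - M, NPad, NPad - N);
}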
1 - : 32 * 4 / KPerBlock / sizeof(LDSTypeA); - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; + return a_lds_block_desc_permuted; } else // ColumnMajor A { @@ -791,42 +778,19 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 } else if constexpr(is_same::value) { - // NLdsLayer * K0 as logical Bank - constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1 - ? 
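The simplified LDS descriptors keep the xor-with-modulo permutation while dropping the MLdsLayer/NLdsLayer unmerge/merge pair. The xor stage exists to avoid LDS bank conflicts; a generic host illustration of the idea (this models a simplified rotate-style permutation, not CK's exact index transform):

#include <cstdio>

int main()
{
    const int rows = 8, cols = 8, banks = 8;

    // Column-0 access pattern: without a swizzle, every row hits bank 0;
    // with a row-dependent permutation the accesses fan out across banks.
    for(int r = 0; r < rows; ++r)
    {
        const int linear_plain   = r * cols + 0;
        const int swizzled_col   = (0 + r) % cols; // xor/modulo-style permute
        const int linear_swizzle = r * cols + swizzled_col;
        std::printf("row %d: bank %d -> bank %d\n",
                    r, linear_plain % banks, linear_swizzle % banks);
    }
}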
1 - : 32 * 4 / KPerBlock / sizeof(LDSTypeB); - ; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); + constexpr auto b_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; + return b_lds_block_desc_permuted; } else // RowMajor B { @@ -992,7 +956,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) { if(!(karg.M % MPerBlock == 0)) { @@ -1009,7 +974,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) { if(!(karg.N % NPerBlock == 0)) { @@ -1357,28 +1323,39 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - const index_t ScaleSliceSizeM = 1; - const index_t ScaleSliceSizeN = 1; - const index_t ScaleSliceSizeK = 1; + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto 
a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2, + Sequence<1, ScaleSliceSizeK>, Sequence<0, 1>, 1, - 1, + ScaleSliceSizeK, 1, false>( - a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM, 0)); + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0)); auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2, Sequence<0, 1>, 1, - 1, + ScaleSliceSizeK, 1, false>( b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); - constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); - constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1); + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(MWaves * MPerXdl, 0), + make_multi_index(-MPerBlock, 0), + make_multi_index(-MPerBlock, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); - const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock; + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); - blockwise_gemm_pipeline.template Run( + blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_blockwise_copy, @@ -1411,6 +1392,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 b_grid_buf, b_block_buf, b_block_slice_copy_step, + + c_scale_thread_desc, c_thread_buf, a_scale_grid_desc_am_ak, @@ -1425,8 +1408,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 b_scale_grid_buf, b_scale_thread_slice_copy_step, - num_k_block_main_loop, - num_k_block_per_scale); + num_k_block_main_loop); // shuffle C and write out { @@ -1437,23 +1419,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + // transposed XDL + // // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + // // TODO: hacky, fix it! 
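The new scale slice sizes and the per-thread A-scale row offset are all derivable from the tile parameters. Plugging in this example's config (assumed below) shows why ScaleSliceSizeK becomes 2 once KPerBlock = 256 spans two 128-wide scale blocks:

#include <cstdio>
#include <initializer_list>

static int ceil_div(int x, int y) { return (x + y - 1) / y; }

int main()
{
    // Tile parameters assumed from this example's device-op configuration.
    const int NPerBlock = 128, KPerBlock = 256;
    const int MPerXdl = 16, NPerXdl = 16, MXdlPerWave = 1, NXdlPerWave = 2;
    const int ScaleBlockN = 128, ScaleBlockK = 128;

    const int ScaleSliceSizeM = MXdlPerWave;                         // 1
    const int ScaleSliceSizeN = ceil_div(NPerBlock, ScaleBlockN);    // 1
    const int ScaleSliceSizeK = ceil_div(KPerBlock, ScaleBlockK);    // 2
    const int NWaves          = NPerBlock / (NXdlPerWave * NPerXdl); // 4

    std::printf("scale slice M/N/K = %d/%d/%d\n",
                ScaleSliceSizeM, ScaleSliceSizeN, ScaleSliceSizeK);

    // Per-thread A-scale row, mirroring a_thread_offset above: the lane's
    // row inside the XDL tile plus the wave's M offset within the block.
    for(int tid : {0, 5, 64, 69, 130, 255})
    {
        const int a_thread_offset =
            tid % MPerXdl + (tid / 64) / NWaves * MPerXdl;
        std::printf("tid %3d -> A-scale row offset %d\n", tid, a_thread_offset);
    }
}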
+ // only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -1462,24 +1445,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 static_cast(p_shared), c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, make_tuple( make_freeze_transform(I0), make_unmerge_transform(make_tuple( Number{}, // M0 (MXdlPerWave) per shuffle M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), + M2)), // M2 = MPerXdl make_freeze_transform(I0), make_unmerge_transform(make_tuple( Number{}, // N0 (NXdlPerWave) per shuffle N1, // N1 = NWave - N2))), // N2 = NPerXdl + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -1489,57 +1472,57 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + 
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0>{})); + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3, + N2, + I1, + N4>, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, 7, 1, InMemoryDataOperationEnum::Set, 1, true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, make_multi_index(0, 0, m_thread_data_on_block_idx[I1], n_thread_data_on_block_idx[I1], m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; using EDataType = CDataType; @@ -1621,18 +1604,17 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), c_element_op}; - // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = - SpaceFillingCurve, + SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence>{}; + N2, + 1, + N4>>{}; constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); @@ -1652,10 +1634,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 block_sync_lds(); // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, c_shuffle_block_buf); // make sure it's safe to read from LDS diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp index 7553d5e76e..3fa82ae53a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp @@ -17,7 +17,7 @@ namespace tensor_operation { namespace device { namespace instance { #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( std::vector, @@ -28,14 +28,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i F32, Tuple<>, BF16, - 128, + 1, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( std::vector, @@ -46,14 +46,14 @@ void 
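The single-stage adaptors above invert a mixed-radix merge: a linear on-block coordinate is decomposed back into (M0, M1, M2) or (N0, N1, N2, N3, N4) digits. A host-side equivalent of CalculateBottomIndex (the digit lengths are illustrative):

#include <array>
#include <cstddef>
#include <cstdio>

// Decompose `linear` into mixed-radix digits with the given lengths,
// least-significant digit last, as the merge adaptor's inverse does.
template <std::size_t D>
std::array<int, D> decompose(int linear, const std::array<int, D>& lengths)
{
    std::array<int, D> idx{};
    for(int d = static_cast<int>(D) - 1; d >= 0; --d)
    {
        idx[d] = linear % lengths[d];
        linear /= lengths[d];
    }
    return idx;
}

int main()
{
    // Illustrative lengths for (N0, N1, N2, N3, N4).
    const std::array<int, 5> n_lengths{2, 4, 2, 1, 4};
    const auto n_idx = decompose(37, n_lengths);
    std::printf("n idx: %d %d %d %d %d\n",
                n_idx[0], n_idx[1], n_idx[2], n_idx[3], n_idx[4]);
}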
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_ F32, Tuple<>, BF16, - 128, + 1, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( std::vector, @@ -64,14 +64,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding F32, Tuple<>, BF16, - 128, + 1, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( std::vector, @@ -82,61 +82,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpaddin F32, Tuple<>, BF16, - 128, - 128, - 128, - PassThrough, - PassThrough, - PassThrough>>>& instances); - -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances( - std::vector, - Row, - F8, - F32, - F8, - F32, - Tuple<>, - BF16, - 128, - 128, - 128, - PassThrough, - PassThrough, - PassThrough>>>& instances); - -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances( - std::vector, - Row, - F8, - F32, - F8, - F32, - Tuple<>, - BF16, - 128, - 128, - 128, - PassThrough, - PassThrough, - PassThrough>>>& instances); - -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances( - std::vector, - Row, - F8, - F32, - F8, - F32, - Tuple<>, - BF16, - 128, + 1, 128, 128, PassThrough, @@ -163,7 +109,7 @@ struct DeviceOperationInstanceFactory, CDataType, - 128, + 1, 128, 128, ck::tensor_operation::element_wise::PassThrough, @@ -180,7 +126,7 @@ struct DeviceOperationInstanceFactory, CDataType, - 128, + 1, 128, 128, ck::tensor_operation::element_wise::PassThrough, @@ -198,20 +144,14 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances( + add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances( - op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances( - op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances( + add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances( + add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances( - op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances( + add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( op_ptrs); } } diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt index aab1c4e86e..d572862884 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt @@ -4,16 +4,13 @@ set(GEMM_AB_SCALE_INSTANCES) list(APPEND GEMM_AB_SCALE_INSTANCES 
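For context, these renamed 1x128x128 instances are typically consumed through the instance factory. A hypothetical consumer sketch: the DeviceGemmMultiD_ABScale template-argument order is assumed from the declarations above, and ck::f8_t / ck::bhalf_t are the library's usual spellings of the F8 / BF16 aliases; building it requires the CK library headers:

#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp"

// Enumerate the registered f8 x f8 -> bf16, 1x128x128 scale-block instances.
int count_ab_scale_instances()
{
    using ck::tensor_operation::device::instance::DeviceOperationInstanceFactory;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;

    // Argument order assumed: layouts, A/A-scale/B/B-scale/Ds/E data types,
    // ScaleBlockM/N/K, then the three elementwise ops.
    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultiD_ABScale<
        Row, Col, ck::Tuple<>, Row,
        ck::f8_t, float, ck::f8_t, float, ck::Tuple<>, ck::bhalf_t,
        1, 128, 128,
        PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    std::cout << "registered instances: " << op_ptrs.size() << "\n";
    return static_cast<int>(op_ptrs.size());
}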
     device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp
     device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp
-    device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp
-    device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp
     device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp
     device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp
-    device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
 )
 
 set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
 set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
 
 add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES})
diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp
index 3a7df8d974..eba9cfcb7c 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp
@@ -34,49 +34,50 @@ static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
 
 static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 
 template
-using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances = std::tuple<
+using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple<
 // clang-format off
-        //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
-        //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
-        //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
-        //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
+        //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
+        //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
+        //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // Compute friendly
-        // Spill in current compiler
-        // DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
-        // DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
 // clang-format on
     >;
 
 template
-using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances = std::tuple<
+using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple<
 // clang-format off
-        //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
-        //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
-        //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
-        //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        //################################| ALayout| BLayout| DsLayout| ELayout|AData | BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
+        //################################| | | | | Type | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
+        //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
+        //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        // Latency friendly
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
-        // Memory friendly
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
-        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>
+        // Memory friendly
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>
 // clang-format on
     >;
 
 } // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp
index ab83c7eb3e..aebffc01f2 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp
@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
    std::vector,
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
                       F32,
                       Tuple<>,
                       BF16,
-                      128,
+                      1,
                       128,
                       128,
                       PassThrough,
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
 {
     add_device_operation_instances(
         instances,
-        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances{});
+        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{});
 }
 
 } // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp
index dfb1bb6e2d..31fffae080 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp
@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
    std::vector,
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
                       F32,
                       Tuple<>,
                       BF16,
-                      128,
+                      1,
                       128,
                       128,
                       PassThrough,
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
 {
     add_device_operation_instances(
         instances,
-        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances{});
+        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{});
 }
 
 } // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp
deleted file mode 100644
index d2d3ebe81e..0000000000
--- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-
-#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
-
-namespace ck {
-namespace tensor_operation {
-namespace device {
-namespace instance {
-
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
-    std::vector,
-                      Row,
-                      F8,
-                      F32,
-                      F8,
-                      F32,
-                      Tuple<>,
-                      BF16,
-                      128,
-                      128,
-                      128,
-                      PassThrough,
-                      PassThrough,
-                      PassThrough>>>& instances)
-{
-    add_device_operation_instances(
-        instances,
-        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances{});
-}
-
-} // namespace instance
-} // namespace device
-} // namespace tensor_operation
-} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp
deleted file mode 100644
index f6ce77a751..0000000000
--- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-
-#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
-
-namespace ck {
-namespace tensor_operation {
-namespace device {
-namespace instance {
-
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
-    std::vector,
-                      Row,
-                      F8,
-                      F32,
-                      F8,
-                      F32,
-                      Tuple<>,
-                      BF16,
-                      128,
-                      128,
-                      128,
-                      PassThrough,
-                      PassThrough,
-                      PassThrough>>>& instances)
-{
-    add_device_operation_instances(
-        instances,
-        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances{});
-}
-
-} // namespace instance
-} // namespace device
-} // namespace tensor_operation
-} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp
index e2205ad728..569911e3de 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp
@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
    std::vector,
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
                       F32,
                       Tuple<>,
                       BF16,
-                      128,
+                      1,
                       128,
                       128,
                       PassThrough,
@@ -28,8 +28,8 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
 {
     add_device_operation_instances(
         instances,
-        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances{});
+        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{});
 }
 
 } // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp
index 5c0a6eb00d..d1e5b6b535 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp
@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
+void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
    std::vector,
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
                       F32,
                       Tuple<>,
                       BF16,
-                      128,
+                      1,
                       128,
                       128,
                       PassThrough,
@@ -28,8 +28,8 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
 {
     add_device_operation_instances(
         instances,
-        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances{});
+        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{});
 }
 
 } // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
deleted file mode 100644
index cc1a03b060..0000000000
--- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
+++ /dev/null
@@ -1,38 +0,0 @@
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-
-#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
-
-namespace ck {
-namespace tensor_operation {
-namespace device {
-namespace instance {
-
-void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
-    std::vector,
-                      Row,
-                      F8,
-                      F32,
-                      F8,
-                      F32,
-                      Tuple<>,
-                      BF16,
-                      128,
-                      128,
-                      128,
-                      PassThrough,
-                      PassThrough,
-                      PassThrough>>>& instances)
-{
-    add_device_operation_instances(
-        instances,
-        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances{});
-}
-
-} // namespace instance
-} // namespace device
-} // namespace tensor_operation
-} // namespace ck
diff --git a/profiler/src/profile_gemm_ab_scale.cpp b/profiler/src/profile_gemm_ab_scale.cpp
index 56c8b5e7a1..3956038a30 100644
--- a/profiler/src/profile_gemm_ab_scale.cpp
+++ b/profiler/src/profile_gemm_ab_scale.cpp
@@ -32,6 +32,7 @@ enum struct GemmDataType
 enum struct ScaleBlockTile
 {
     Tile_128_128_128, // 0
+    Tile_1_128_128,   // 1
 };
 
 #define OP_NAME "gemm_ab_scale"
@@ -49,7 +50,8 @@ int profile_gemm_ab_scale(int argc, char* argv[])
     printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
     printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
     printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
-    printf("arg4: scale block tile (0: ScaleBlockM/N/K = [128, 128, 128];\n");
+    printf("arg4: scale block tile (0: ScaleBlockM/N/K = [128, 128, 128]; 1: ScaleBlockM/N/K = "
+           "[1, 128, 128];\n");
     printf("arg5: verification (0: no; 1: yes)\n");
     printf("arg6: initialization (0: no init; 1: integer value; 2: decimal value)\n");
     printf("arg7: print tensor value (0: no; 1: yes)\n");
@@ -155,7 +157,7 @@ int profile_gemm_ab_scale(int argc, char* argv[])
     };
 
     if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN &&
-       scale_block_tile == ScaleBlockTile::Tile_128_128_128)
+       scale_block_tile == ScaleBlockTile::Tile_1_128_128)
     {
         return profile(F8{},
                        F32{},
                        F8{},
                        F32{},
                        BF16{},
-                       ck::Number<128>{},
+                       ck::Number<1>{},
                        ck::Number<128>{},
                        ck::Number<128>{},
                        Row{},