diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp index fadf4eb0cc..510a1fd1dc 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -121,7 +121,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale; using Base::I0; - using Base::I1; using Base::KRepeat; using Base::xdlops_gemm; using typename Base::HotLoopInstList; @@ -147,6 +146,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale( b_thread_desc_.GetElementSpaceSize()); - auto a_scale_thread_buf = make_static_buffer( - a_scale_thread_desc.GetElementSpaceSize()); - auto b_scale_thread_buf = make_static_buffer( - b_scale_thread_desc.GetElementSpaceSize()); + using AScaleBufferType = decltype(make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize())); + using BScaleBufferType = decltype(make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize())); + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; auto c_scale_thread_buf = make_static_buffer( c_scale_thread_desc.GetElementSpaceSize()); - StaticallyIndexedArray{}> a_scale_thread_bufs; - StaticallyIndexedArray{}> b_scale_thread_bufs; - // Global prefetch 1 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); @@ -381,12 +380,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { - c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; }); // Local prefill 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0)); // Global prefetch 2 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); @@ -400,7 +399,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{})); }); @@ -420,7 +419,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + + // Fill first MFMA buffer + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + __builtin_amdgcn_sched_barrier(0); // main body @@ -468,8 +494,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { - vector_type c_scale_thread_vec; - c_scale_thread_vec.template AsType()(Number<0>{}) = c_scale_thread_buf[m0]; - c_scale_thread_vec.template AsType()(Number<1>{}) = c_scale_thread_buf[m0]; - - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto mfma_buf_offset = - ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); - constexpr auto scale_buf_offset = - ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); - - // Clear buffer for new MFMA computation - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number{}) - .template AsType()(Number{}) = 0; - }); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number{})); - }); - - // Apply scaling with packed FMA and accumulate to main buffer - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale - .GetVectorTypeReference(Number{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); - }); - }); - - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); - }); - - HotLoopScheduler(); - __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, @@ -584,15 +528,114 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { - c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0]; + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + static_for<0, NRepeat, 1>{}([&](auto n0) { + // Compute offsets + constexpr auto mfma_buf_offset = + ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + constexpr auto scale_buf_offset = + ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + + constexpr auto a_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat; + constexpr auto b_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat; + + constexpr auto local_buf_id = + Number{}; + + // Clear the current mfma output buffer + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + constexpr auto m0_offset = (a_local_buf_offset + HotloopLocalBufSwitch * mfma_reg_buf) % 2; + a_thread_vec.template AsType()(ik) = + a_thread_buf[local_buf_id][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[local_buf_id][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number{})); + }); + // Apply scaling + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + // .template AsType()[Number{}] * + // type_convert(c_scale_thread_buf[m0]); + // }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); }); - }; + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(mfma_reg_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_buf[mfma_reg_buf][m0] * + b_scale_thread_buf[mfma_reg_buf][I0]; + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + } LoopFunc(I0, I1); LoopFunc(I1, I0); @@ -602,65 +645,45 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { - vector_type c_scale_thread_vec; - c_scale_thread_vec.template AsType()(Number<0>{}) = c_scale_thread_buf[m0]; - c_scale_thread_vec.template AsType()(Number<1>{}) = c_scale_thread_buf[m0]; + // if constexpr(TailNum == TailNumber::Full) + // { + // static_for<0, MRepeat, 1>{}([&](auto m0) { + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + // .template AsType()(Number{}) = 0; + // }); + // static_for<0, KRepeat, 1>{}([&](auto k0) { + // vector_type a_thread_vec; + // vector_type b_thread_vec; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto mfma_buf_offset = - ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); - constexpr auto scale_buf_offset = - ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + // static_for<0, KPack, 1>{}([&](auto ik) { + // a_thread_vec.template AsType()(ik) = + // a_thread_buf[Number{}]; + // b_thread_vec.template AsType()(ik) = + // b_thread_buf[Number{}]; + // }); - // Clear buffer for new MFMA computation - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - c_thread_buf_per_scale.GetVectorTypeReference(Number{}) - .template AsType()(Number{}) = 0; - }); + // using mfma_input_type = + // typename vector_type::type; - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(Number{})); - }); - - // Apply scaling with packed FMA and accumulate to main buffer - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { - using pk_fma_type = typename vector_type::type; - - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()(t) = __builtin_elementwise_fma( - c_thread_buf_per_scale - .GetVectorTypeReference(Number{}) - .template AsType()[t], - c_scale_thread_vec.template AsType()[Number<0>{}], - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[t]); - }); - }); - }); + // xdlops_gemm.template Run<>( + // a_thread_vec.template AsType(), + // b_thread_vec.template AsType(), + // c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + // }); + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + // .template AsType()[Number{}] * + // type_convert(c_scale_thread_buf[m0]); + // }); + // }); + // }); __builtin_amdgcn_sched_barrier(0); } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index e80a3702fb..bf19eed977 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -973,8 +973,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) + - b_block_space_size_aligned * sizeof(LDSTypeB)), + return math::max((2 * a_block_space_size_aligned * sizeof(LDSTypeA) + + 2 * b_block_space_size_aligned * sizeof(LDSTypeB)), c_block_size * sizeof(CShuffleDataType)); } @@ -1327,15 +1327,26 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); // Cast after lds - auto a_block_buf = make_dynamic_buffer( + auto a_block_buf_1 = make_dynamic_buffer( static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_2 = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_1, a_block_buf_2); - auto b_block_buf = make_dynamic_buffer( + auto b_block_buf_1 = make_dynamic_buffer( static_cast(p_shared) + - a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + 2 * a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_2 = make_dynamic_buffer( + static_cast(p_shared) + + 2 * a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB) + + b_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_1, b_block_buf_2); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); @@ -1410,13 +1421,13 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf, - a_block_buf, + a_block_bufs, a_block_slice_copy_step, b_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_blockwise_copy, b_grid_buf, - b_block_buf, + b_block_bufs, b_block_slice_copy_step, c_scale_thread_desc,