diff --git a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp
index e05b02ad18..0748131340 100644
--- a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp
+++ b/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp
@@ -55,26 +55,26 @@ using DeviceInstance =
                          YElementOp,
                          Rank,
                          NumReduceDim,
-                         256, // BlockSize
-                         8,   // ClusterM
-                         32,  // ClusterK
-                         1,   // SliceM
-                         8,   // SliceK
-                         1,   // SrcVecDim (0=M, 1=K)
-                         8,   // SrcScalarPerVector
-                         1,   // GammaVecDim (0=M, 1=K)
-                         8,   // GammaScalarPerVector
-                         1,   // BetaVecDim (0=M, 1=K)
-                         8,   // BetaScalarPerVector
-                         8>;  // OutScalarPerVector
+                         1024, // BlockSize
+                         1,    // ClusterM
+                         1024, // ClusterK
+                         1,    // SliceM
+                         32,   // SliceK
+                         1,    // SrcVecDim (0=M, 1=K)
+                         2,    // SrcScalarPerVector
+                         1,    // GammaVecDim (0=M, 1=K)
+                         2,    // GammaScalarPerVector
+                         1,    // BetaVecDim (0=M, 1=K)
+                         2,    // BetaScalarPerVector
+                         2>;   // OutScalarPerVector

 int main(int argc, char* argv[])
 {
-    ck::index_t N = 128;
-    ck::index_t H = 16;
-    ck::index_t W = 16;
+    ck::index_t N = 2;
+    ck::index_t H = 32;
+    ck::index_t W = 32;
     ck::index_t G = 32;
-    ck::index_t C = 40;
+    ck::index_t C = 30;

     if(argc == 1)
     {
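The config change above retunes the groupnorm example for a channel count that an 8-wide vector load cannot serve: with G = 32 and C = 30, the innermost contiguous length is 30, which is divisible by 2 but not by 8. The larger 1024-thread block with SliceK = 32 also lets one block tile cover the whole 32 * 32 * 30 = 30720-element per-group reduction. Below is a minimal standalone sanity check of those divisibility constraints (illustrative only; the constant names mirror the template-argument comments above, and the real validation lives in the device op's IsSupportedArgument):

```cpp
// Compile-time sanity check for the tuning parameters above.
int main()
{
    constexpr int BlockSize          = 1024; // threads per workgroup
    constexpr int ClusterM           = 1;    // threads along M
    constexpr int ClusterK           = 1024; // threads along K
    constexpr int SliceK             = 32;   // K elements per thread
    constexpr int SrcScalarPerVector = 2;    // vector load width

    constexpr int H = 32, W = 32, C = 30; // per-group reduced lengths

    // The thread cluster must exactly tile the workgroup.
    static_assert(ClusterM * ClusterK == BlockSize, "cluster shape != block size");

    // The vector width must divide the contiguous innermost length C;
    // 30 % 8 != 0 is why the example dropped from 8-wide to 2-wide loads.
    static_assert(C % SrcScalarPerVector == 0, "vector width must divide C");

    // One block tile covers ClusterK * SliceK = 32768 >= 30720 elements,
    // so the whole per-group reduction fits in a single sweep.
    static_assert(ClusterK * SliceK >= H * W * C, "tile smaller than reduction");

    return 0;
}
```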
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp
index 8d17178649..094c79c6f8 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp
@@ -57,7 +57,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

     using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
-        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

@@ -73,8 +73,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};

-    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
-    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+    static constexpr index_t M_BlockTileSize     = MThreadClusterSize * MThreadSliceSize;
+    static constexpr index_t K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
+    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
+
+    static constexpr auto XThreadBufferNumber     = Number<KThreadSliceSize / XSrcVectorSize>{};
+    static constexpr auto GammaThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
+    static constexpr auto BetaThreadBufferNumber  = Number<KThreadSliceSize / XSrcVectorSize>{};
+    static constexpr auto YThreadBufferNumber     = Number<KThreadSliceSize / XSrcVectorSize>{};

     __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
                                         int thread_k_cluster_id)
@@ -87,10 +93,13 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk

         if(kPerBlockTail > 0)
         {
-            int thread_max_len = (thread_k_cluster_id + 1) * KThreadSliceSize;
-            int delta          = thread_max_len - kPerBlockTail;
-            delta              = math::clamp(thread_max_len - kPerBlockTail, 0, KThreadSliceSize);
-            kPerThread += KThreadSliceSize - delta;
+            static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
+                int thread_max_len =
+                    (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
+                int delta = thread_max_len - kPerBlockTail;
+                delta     = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
+                kPerThread += XSrcVectorSize - delta;
+            });
         }

         return kPerThread;
@@ -116,19 +125,41 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * KThreadSliceSize, true>
-            x_thread_buf;
+        auto x_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * XSrcVectorSize,
+                                    true>{};
+            },
+            Number<XThreadBufferNumber>{});

-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * KThreadSliceSize, true>
-            gamma_thread_buf;
+        auto gamma_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * XSrcVectorSize,
+                                    true>{};
+            },
+            Number<GammaThreadBufferNumber>{});

-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * KThreadSliceSize, true>&
-            beta_thread_buf = gamma_thread_buf;
+        auto beta_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * XSrcVectorSize,
+                                    true>{};
+            },
+            Number<BetaThreadBufferNumber>{});

-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * KThreadSliceSize, true>
-            y_thread_buf;
+        auto y_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * XSrcVectorSize,
+                                    true>{};
+            },
+            Number<YThreadBufferNumber>{});

         StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
             mean_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
             var_thread_buf;
@@ -142,9 +173,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         const auto thread_m_cluster_id = thread_cluster_idx[I0];
         const auto thread_k_cluster_id = thread_cluster_idx[I1];

-        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
+        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
         constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
+            make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2
[...]
+            static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_x_load.Run(x_grid_desc_m_k,
+                                      x_global_val_buf,
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
+                                      x_thread_buf(i));
+                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
+                threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
+            });
         }

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -256,7 +285,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
             BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
         });

-        auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
+        auto thread_copy_tail_m_k =
+            (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k;

         threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
         threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
@@ -267,62 +297,86 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         {
             if constexpr(!SweepOnce)
             {
-                threadwise_x_load.Run(x_grid_desc_m_k,
-                                      x_global_val_buf,
-                                      thread_buffer_desc_m_k,
-                                      make_tuple(I0, I0),
-                                      x_thread_buf);
+                static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_x_load.Run(x_grid_desc_m_k,
+                                          x_global_val_buf,
+                                          thread_buffer_desc_m_k,
+                                          make_tuple(I0, I0),
+                                          x_thread_buf(i));
+                    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k,
+                                                         thread_copy_fwd_step_m_k);
+                });
             }

-            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
-                                      gamma_global_val_buf,
-                                      thread_buffer_desc_m_k,
-                                      make_tuple(I0, I0),
-                                      gamma_thread_buf);
+            static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
+                                          gamma_global_val_buf,
+                                          thread_buffer_desc_m_k,
+                                          make_tuple(I0, I0),
+                                          gamma_thread_buf(i));
+
+                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
+                                                         thread_copy_fwd_step_m_k);
+            });

             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+                static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
+                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                        constexpr auto offset_m_k =
+                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

-                    // normalize
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
-                        sqrt(var_thread_buf(iM) + epsilon);
+                        // normalize
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                            (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
+                            divisor;

-                    // gamma
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
+                        // gamma
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                            gamma_thread_buf(iK0)(Number<offset_m_k>{});
+                    });
                 });
             });

-            threadwise_beta_load.Run(beta_grid_desc_m_k,
-                                     beta_global_val_buf,
-                                     thread_buffer_desc_m_k,
-                                     make_tuple(I0, I0),
-                                     beta_thread_buf);
+            static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_beta_load.Run(beta_grid_desc_m_k,
+                                         beta_global_val_buf,
+                                         thread_buffer_desc_m_k,
+                                         make_tuple(I0, I0),
+                                         beta_thread_buf(i));
+                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                        thread_copy_fwd_step_m_k);
+            });

             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
+                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                        constexpr auto offset_m_k =
+                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

-                    // beta
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
+                        // beta
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) +
+                            beta_thread_buf(iK0)(Number<offset_m_k>{});
+                    });
                 });
             });

-            threadwise_y_store.Run(thread_buffer_desc_m_k,
-                                   make_tuple(I0, I0),
-                                   y_thread_buf,
-                                   y_grid_desc_m_k,
-                                   y_global_val_buf);
+            static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_y_store.Run(thread_buffer_desc_m_k,
+                                       make_tuple(I0, I0),
+                                       y_thread_buf(i),
+                                       y_grid_desc_m_k,
+                                       y_global_val_buf);
+                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
+            });

-            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
+                                                     2 * thread_copy_bwd_step_m_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                    2 * thread_copy_bwd_step_m_k);
+            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
         }
     }
 };
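The kernel change splits each thread's KThreadSliceSize-long slice into XThreadBufferNumber = KThreadSliceSize / XSrcVectorSize chunks of XSrcVectorSize elements, spaced K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize apart, so every global access is exactly one vector wide; the inner loop also hoists 1 / sqrt(var + eps) out of the per-element math as `divisor`. The host-side model below traces the reworked GetKPerThread tail logic for the example's config (a sketch, with the kernel's names copied over; it assumes kPerThread starts from the thread's full-tile element count and that the K grid length is padded to a multiple of the block tile):

```cpp
// Host-side scalar model of the chunked per-thread element count.
#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int KThreadClusterSize  = 1024; // ClusterK
    constexpr int KThreadSliceSize    = 32;   // SliceK
    constexpr int XSrcVectorSize      = 2;    // SrcScalarPerVector
    constexpr int XThreadBufferNumber = KThreadSliceSize / XSrcVectorSize; // 16 chunks
    constexpr int K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
    constexpr int K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

    const int kRaw          = 32 * 32 * 30; // H * W * C from the example
    const int numIterations = (kRaw + K_BlockTileSize - 1) / K_BlockTileSize;
    const int kPerBlockTail = kRaw - (numIterations - 1) * K_BlockTileSize;

    const int thread_k_cluster_id = 0; // trace one lane
    int kPerThread = (numIterations - 1) * KThreadSliceSize; // full tiles

    if(kPerBlockTail > 0)
    {
        // Each of a thread's chunks sits K_BlockTileStepSize apart, so the
        // tail is clamped per vector-sized chunk, not per whole K slice.
        for(int i = 0; i < XThreadBufferNumber; ++i)
        {
            int thread_max_len =
                (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
            int delta = std::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
            kPerThread += XSrcVectorSize - delta;
        }
    }

    std::printf("thread %d reduces %d elements\n", thread_k_cluster_id, kPerThread);
    return 0;
}
```

With these numbers every lane reports 30 elements, matching 30720 / 1024 exactly, which is the per-thread count the Welford mean/variance update needs to stay correct at the tail.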
diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp
index bf0f7a3d2c..89bdf9438c 100644
--- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp
@@ -31,7 +31,9 @@ using device_layernorm_f16_instances = std::tuple<
     DeviceLayernormImpl,
     DeviceLayernormImpl,
     DeviceLayernormImpl,
-    DeviceLayernormImpl
+    DeviceLayernormImpl,
+    DeviceLayernormImpl,
+    DeviceLayernormImpl
     // clang-format on
     >;
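The instance list grows by two pre-built layernorm configurations. (The `DeviceLayernormImpl` template parameter lists were lost when this patch was extracted, so they are not reproduced here; judging from the rest of the patch they most likely cover narrower vector widths.) The motivation is that the library can only serve a given reduced length with a vector width that divides it. A hypothetical helper showing that selection rule (not a CK API; CK instead iterates its instance list and asks each instance whether it supports the argument):

```cpp
// Pick the widest vector access that legally divides the innermost length.
#include <cstdio>

int largest_valid_vector_size(int inner_len)
{
    for(int v : {8, 4, 2})
        if(inner_len % v == 0)
            return v;
    return 1; // scalar fallback
}

int main()
{
    for(int c : {40, 30, 20, 10, 9})
        std::printf("inner length %2d -> vector width %d\n", c, largest_valid_vector_size(c));
    return 0;
}
```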
diff --git a/test/layernorm/test_groupnorm_fp16.cpp b/test/layernorm/test_groupnorm_fp16.cpp
index 235ebca3d1..550813323b 100644
--- a/test/layernorm/test_groupnorm_fp16.cpp
+++ b/test/layernorm/test_groupnorm_fp16.cpp
@@ -26,6 +26,8 @@ class TestGroupnorm : public ::testing::Test
         {256, 9, 9, 9, 9},
         {1, 64, 64, 32, 10},
         {1, 32, 32, 32, 20},
+        {2, 32, 32, 32, 30},
+        {2, 32, 32, 32, 40},
         {1, 16, 16, 32, 40}};

     for(auto length : lengths)
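The two new test shapes exercise N > 1 and the C = 30 case that motivated the 2-wide loads. For reference, this is what the test compares against conceptually: groupnorm over an NHWGC tensor with per-(n, g) statistics. The sketch below is not the CK test harness; it uses a naive two-pass mean/variance rather than the kernel's Welford update and fixes gamma = 1, beta = 0:

```cpp
// Minimal host reference for the {2, 32, 32, 32, 30} shape added above.
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int N = 2, H = 32, W = 32, G = 32, C = 30; // C = channels per group
    const float eps = 1e-5f;

    const size_t len = size_t(N) * H * W * G * C;
    std::vector<float> x(len, 1.0f), y(len);

    auto idx = [&](int n, int h, int w, int g, int c) {
        return (((size_t(n) * H + h) * W + w) * G + g) * C + c;
    };

    for(int n = 0; n < N; ++n)
        for(int g = 0; g < G; ++g)
        {
            // accumulate mean/variance over the H * W * C elements of one group
            double sum = 0, sq = 0;
            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                    for(int c = 0; c < C; ++c)
                    {
                        double v = x[idx(n, h, w, g, c)];
                        sum += v;
                        sq += v * v;
                    }
            const double cnt  = double(H) * W * C;
            const double mean = sum / cnt;
            const double var  = sq / cnt - mean * mean;
            const float inv   = 1.0f / std::sqrt(float(var) + eps);

            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                    for(int c = 0; c < C; ++c)
                    {
                        size_t i = idx(n, h, w, g, c);
                        y[i]     = (x[i] - float(mean)) * inv; // gamma = 1, beta = 0
                    }
        }

    std::printf("y[0] = %f\n", y[0]); // all-ones input -> all-zeros output
    return 0;
}
```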