diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 4f77176be6..18049b1757 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -274,27 +274,51 @@ struct CShuffleEpilogue } } - template + template + CK_TILE_DEVICE void print_tensor_matrix_format( + const static_distributed_tensor& tensor, + const char* /*name = "tensor_matrix"*/) + { + const auto spans = tensor.get_distributed_spans(); + //static_assert(spans.size() == 2, "This function is for 2D tensors only"); + + const auto dim0_span = spans[number<0>{}]; + const auto dim1_span = spans[number<1>{}]; + + //printf("%s matrix format (tid %u):\n", name, threadIdx.x); + + sweep_tile_span(dim0_span, [&](auto row) { + printf(" "); + sweep_tile_span(dim1_span, [&](auto col) { + constexpr auto distributed_indices = make_tuple(row, col); + const auto value = tensor[distributed_indices]; + printf("tid %u: %.7f\n", threadIdx.x, static_cast(value)); + }); + //printf("\n"); + }); + //printf("\n"); + } + + template + //template CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, const DsDramWindows& ds_dram_windows, void* p_smem) { - if constexpr (NumGroupsToMerge > 1) + if constexpr (NumGroupsToMerge == 1) { - // When NumGroupsToMerge > 1, we want to write out only the diagonal blocks. - // Hence, we configure the shuffle such that it iterates one merge group block at a time. - return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem); - } - else - { - // When NumGroupsToMerge == 1, we want to write out all the blocks. return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem); } + else + { + return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem); + //return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem); + } } - template + template CK_TILE_DEVICE auto merged_op(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, const DsDramWindows& ds_dram_windows, @@ -359,8 +383,8 @@ struct CShuffleEpilogue block_sync_lds(); constexpr index_t Gs = NumGroupsToMerge; - constexpr index_t MPerGroup = kMPerBlock / Gs; - constexpr index_t NPerGroup = kNPerBlock / Gs; + constexpr index_t MBlockWidth = kMPerBlock / Gs; + constexpr index_t NBlockWidth = kNPerBlock / Gs; // Tile enconding for a single group (diagonal block in LDS) constexpr auto dram_tile_encoding = tile_distribution_encoding< @@ -374,20 +398,31 @@ struct CShuffleEpilogue constexpr auto dram_tile_distribution = make_static_tile_distribution(dram_tile_encoding); // The LDS data has the following 4D layout in the row-major case. - // linear_index = c + Gs * m + Gs * MPerGroup * n + Gs * MPerGroup * NPerGroup * r + // linear_index = c + Gs * n + Gs * NBlockWidth * m + Gs * MBlockWidth * NBlockWidth * r // for 4D coordinates (r,c,m,n) where (r,c) is the group index and (m,n) is the index within the group. + // Within the sub-block, we have column-major layout (n is the faster index). // We pick-up only the diagonal blocks where r == c. // For each block, the tile distribution and the tensor descriptors are the same. // The only thing that changes is the p_smem offset. constexpr auto lds_block_desc_2d = make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number{})); + make_tuple(number{}, number{}), + make_tuple(number{}, number{})); + + if (threadIdx.x == 0) + { + printf("""CShuffleEpilogue::merged_op(): MPerGroup=%d, NPerGroup=%d, Gs=%d, MBlockWidth=%d, NBlockWidth=%d\n", + static_cast(MPerGroup), static_cast(NPerGroup), static_cast(Gs), + static_cast(MBlockWidth), static_cast(NBlockWidth)); + } // Loop over the groups (diagonal blocks in LDS) static_for<0, Gs, 1>{}([&](auto g) { block_sync_lds(); - constexpr index_t group_offset = g * (1 + Gs* MPerGroup * NPerGroup); + // With to the single diagonal block of LDS. + // This block may have more elements that the actual output groups contains + // because we have MPerGroup <= MBlockWidth and NPerGroup <= NBlockWidth. + constexpr index_t group_offset = g * (1 + Gs* MBlockWidth * NBlockWidth * MPerGroup); auto lds_view = make_tensor_view( static_cast(p_smem) + group_offset, lds_block_desc_2d); @@ -405,6 +440,10 @@ struct CShuffleEpilogue auto c_out_tensor = load_tile(lds_window); + // DEBUG: Print out the c_out_tensor contents for debugging + //print_tensor_matrix_format(c_out_tensor, "c_out_tensor"); + //__syncthreads(); + const auto ds_tensor = generate_tuple( [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 51d3b891f3..178a9e3bb7 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -99,12 +99,14 @@ struct GroupedConvBwdWeightKernelArgs GemmN = b_grid_desc_n_k.get_length(number<0>{}); GemmK = a_grid_desc_m_k.get_length(number<1>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + ZYX = conv_to_gemm_transformer.ZYX_; if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch + << ", ZYX: " << ZYX << std::endl; } } @@ -185,12 +187,14 @@ struct GroupedConvBwdWeightKernelArgs GemmN = b_grid_desc_n_k.get_length(number<0>{}); GemmK = a_grid_desc_m_k.get_length(number<1>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + ZYX = conv_to_gemm_transformer.ZYX_; if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch + << ", ZYX: " << ZYX << std::endl; } } @@ -278,12 +282,14 @@ struct GroupedConvBwdWeightKernelArgs GemmN = b_grid_desc_n_k.get_length(number<0>{}); GemmK = a_grid_desc_m_k.get_length(number<1>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + ZYX = conv_to_gemm_transformer.ZYX_; if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch + << ", ZYX: " << ZYX << std::endl; } } @@ -310,6 +316,7 @@ struct GroupedConvBwdWeightKernelArgs index_t GemmK; index_t GemmBatch; index_t NumGroupsPerBatch; + index_t ZYX; const void* out_ptr; const void* in_ptr; @@ -797,8 +804,114 @@ struct GroupedConvolutionBackwardWeightKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + constexpr index_t Gs = GroupedConvTraitsType_::NumGroupsToMerge; + const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}]; + const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}]; + const index_t ZYX = kargs.ZYX; + const index_t MPerGroup = ConvK; + const index_t NPerGroup = ZYX * ConvC; + + if (threadIdx.x == 0) + { + printf("MPerGroup: %d, NPerGroup: %d \n", MPerGroup, NPerGroup); + } + + // Check that MPerGroup and NPerGroup map to the existing options + if (MPerGroup == 1 && NPerGroup == 16) + { + EpiloguePipeline{}.template operator()<1, 16, decltype(c_block_window), decltype(c_block_tile)>( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else if (MPerGroup == 2 && NPerGroup == 16) + { + EpiloguePipeline{}.template operator()<2, 16, decltype(c_block_window), decltype(c_block_tile)>( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else if (MPerGroup == 1 && NPerGroup == 32) + { + EpiloguePipeline{}.template operator()<1, 32, decltype(c_block_window), decltype(c_block_tile)>( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else if (MPerGroup == 2 && NPerGroup == 32) + { + EpiloguePipeline{}.template operator()<2, 32, decltype(c_block_window), decltype(c_block_tile)>( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else if (MPerGroup == 1 && NPerGroup == 4) + { + EpiloguePipeline{}.template operator()<1, 4, decltype(c_block_window), decltype(c_block_tile)>( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else if (MPerGroup == 1 && NPerGroup == 8) + { + EpiloguePipeline{}.template operator()<1, 8, decltype(c_block_window), decltype(c_block_tile)>( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else if (MPerGroup == 2 && NPerGroup == 4) + { + EpiloguePipeline{}.template operator()<2, 4, decltype(c_block_window), decltype(c_block_tile)>( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else if (MPerGroup == 2 && NPerGroup == 8) + { + EpiloguePipeline{}.template operator()<2, 8, decltype(c_block_window), decltype(c_block_tile)>( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + auto print_lds_per_wg = [&]() + { + // Print out LDS contents. + // The LDS corresponds TilePartitioner_::MPerBlock * TilePartitioner_::NPerBlock matrix. + // Print LDS contents as matrix + printf("LDS Contents (%d x %d) for thread block %u:\n", TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, blockIdx.x); + OutDataType* lds_data = reinterpret_cast(smem_ptr_0); + + const int MBlockWidth = TilePartitioner::MPerBlock / Gs; + const int NBlockWidth = TilePartitioner::NPerBlock / Gs; + const int Ncols = Gs; + const int Nrows = Gs; + + for(int c = 0; c < Ncols; ++c) { + printf("Block %d:\n", c); + for(int r = 0; r < Nrows; ++r) { + printf("Row %d:\n", r); + for (int m = 0; m < MBlockWidth; ++m) + { + for(int n = 0; n < NBlockWidth; ++n) + { + int idx = c + Ncols * n + Ncols * NBlockWidth * m + Ncols * MBlockWidth * NBlockWidth * r; + printf("%.7f ", static_cast(lds_data[idx])); + } + printf(" \n"); + } + } + printf("\n\n"); + } + + + // Print out the c_block_window contents for debugging + printf("C Ptr Contents (%d x %d):\n", TilePartitioner::MPerBlock, NPerGroup); + for(int m = 0; m < TilePartitioner::MPerBlock; ++m) { + for(int n = 0; n < NPerGroup; ++n) { + int idx = m * NPerGroup + n; + printf("%.7f ", static_cast(c_ptr[idx])); + if((n + 1) % NPerGroup == 0) printf("\n "); // Line break every NBlockWidth elements for readability + } + printf("\n"); + } + }; + + __syncthreads(); + if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) + { + print_lds_per_wg(); + } + __syncthreads(); + // if (blockIdx.x == 1 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) + // { + // print_lds_per_wg(); + // } + // __syncthreads(); } /**