diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index dc386bac38..a7e437ed44 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -313,9 +313,9 @@ struct CShuffleEpilogue } template - CK_TILE_DEVICE auto merged_op(ODramWindow&, //out_dram_window, + CK_TILE_DEVICE auto merged_op(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, - const DsDramWindows&, //ds_dram_windows, + const DsDramWindows& ds_dram_windows, void* p_smem) { constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); @@ -383,13 +383,12 @@ struct CShuffleEpilogue }); block_sync_lds(); -#if 0 // Set-up LDS to global memory copy. // We copy only the diagonal blocks from LDS to global memory. // Hence, we must configure the tile distrinbution // on the group size blocks. - constexpr index_t MPerGroup = kMPerBlock / NumGroupsToMerge; - constexpr index_t NPerGroup = kNPerBlock / NumGroupsToMerge; + // constexpr index_t MPerGroup = kMPerBlock / NumGroupsToMerge; + // constexpr index_t NPerGroup = kNPerBlock / NumGroupsToMerge; //constexpr index_t NumThreads = get_total_number_of_threads(LdsTileDistr.get_static_tile_distribution_encoding()); // TODO: Remove this debug assert @@ -412,10 +411,49 @@ struct CShuffleEpilogue // constexpr index_t NumThreadsDram = get_total_number_of_threads(dram_tile_encoding); // static_assert(NumThreadsDram == NumThreads, "NumThreadsDram must be equal to NumThreads."); + // auto lds_window = make_tile_window( + // lds_view, + // make_tuple(number{}, number{}), + // make_tuple(number<0>{}, group_index * NBlockWidth), + // sequential_distribution); + + // 4D tensor view of LDS memory + // with (g_i, g_j, i, j) where (g_i, g_j) is the group index + // and (i, j) is the index within the group. + constexpr index_t Gs = NumGroupsToMerge; + constexpr index_t MPerGroup = kMPerBlock / Gs; + constexpr index_t NPerGroup = kNPerBlock / Gs; + constexpr auto lds_desc_4d = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}, number{}), + make_tuple(number{}, number<1>{}, number{}, number{})); + + // We must merge (r,m) and (c,n) dimensions together to make a 2D tensor descriptor. + constexpr auto lds_desc = transform_tensor_descriptor( + lds_desc_4d, + make_tuple( + make_merge_transform(make_tuple(Gs, MPerGroup)), + make_merge_transform(make_tuple(Gs, NPerGroup)) + ), + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{}) + ); + + auto lds_view = make_tensor_view( + static_cast(p_smem), lds_desc); + + // This is hard-coded for a specific MPerBlock, NPerBlock, NumGroupsToMerge case. + // constexpr auto dram_tile_encoding = tile_distribution_encoding< + // sequence<>, + // tuple, + // sequence<1, 1, 8, 16>>, + // tuple, sequence<1,2>>, + // tuple, sequence<2,2>>, + // sequence<1, 1, 2, 2>, + // sequence<0, 3, 0, 3>>{}; constexpr auto dram_tile_encoding = tile_distribution_encoding< sequence<>, - tuple, - sequence<1, 1, 8, 16>>, + tuple, + sequence<1, Gs, NPerGroup, 1>>, tuple, sequence<1,2>>, tuple, sequence<2,2>>, sequence<1, 1, 2, 2>, @@ -429,40 +467,34 @@ struct CShuffleEpilogue }, number{}); + // Calculate which block in the Gs x Gs space we are located at. + const auto x_space_coord = dram_tile_distribution.calculate_index(); + const index_t m_block = x_space_coord[0] / MPerGroup; + const index_t n_block = x_space_coord[1] / NPerGroup; + auto current_lds_window = make_tile_window( - o_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - dram_tile_distribution); + lds_view, + make_tuple(number{}, number{}, number{}, number{}), + {0, 0, 0, 0}, + dram_tile_distribution); - block_sync_lds(); - - // Calculate which block in the Gm x Gm space we are located at. - auto get_block_number = [&]() -> ck_tile::tuple - { - const auto x_space_coord = dram_tile_distribution.calculate_index(); - const index_t m_block = x_space_coord[0] / MPerGroup; - const index_t n_block = x_space_coord[1] / NPerGroup; - return make_tuple(m_block, n_block); - }; - - auto mask = [&]() -> bool - { - // Return true only for the diagonal blocks. - const auto blockId = get_block_number(); - return blockId[number<0>{}] == blockId[number<1>{}]; - }; - - if (mask()) + // Copy only the diagonal blocks. + if (m_block == n_block) { // Load static_distributed_tensor from LDS. auto c_out_tensor = load_tile(current_lds_window); + // Debug: Print out the c_out_tensor contents for debugging + __syncthreads(); if (threadIdx.x == 0 && blockIdx.x == 0) { // Print out the c_out_tensor contents for debugging print_tensor_matrix_format(c_out_tensor, "c_out_tensor"); } + __syncthreads(); + // End debug + + // TODO: We must move the d_dram_windows and the out_dram_windows to the correct group position. const auto ds_tensor = generate_tuple( [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); @@ -483,7 +515,6 @@ struct CShuffleEpilogue update_tile(out_dram_window, c_out_tensor); } } -#endif } template diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 7d31b4d68b..82fb3e0786 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -754,11 +754,11 @@ struct GroupedConvolutionBackwardWeightKernel return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); } - // template - // CK_TILE_DEVICE void transfer_lds_to_global_simple(OutDataType* c_ptr, void* smem_ptr_0) + // template + // CK_TILE_DEVICE static void transfer_lds_to_global(OutDataType* c_ptr, void* smem_ptr_0) // { - // constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType::NumGroupsToMerge; - // constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType::NumGroupsToMerge; + // constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType_::NumGroupsToMerge; + // constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType_::NumGroupsToMerge; // // Create a single-thread tile distribution for sequential access // constexpr auto sequential_encoding = tile_distribution_encoding< @@ -772,18 +772,18 @@ struct GroupedConvolutionBackwardWeightKernel // constexpr auto sequential_distribution = make_static_tile_distribution(sequential_encoding); - // if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType::NumGroupsToMerge) { + // if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType_::NumGroupsToMerge) { // const auto group_index = threadIdx.x; - // // LDS tensor view + // // LDS tensor view, column-major ordering // constexpr auto lds_desc = make_naive_tensor_descriptor( // make_tuple(number{}, number{}), - // make_tuple(number{}, number<1>{})); + // make_tuple(number<1>{}, number{})); // auto lds_view = make_tensor_view( // static_cast(smem_ptr_0), lds_desc); - // // Global memory tensor view + // // Global memory tensor view, row-major ordering // constexpr auto global_desc = make_naive_tensor_descriptor( // make_tuple(number{}, number{}), // make_tuple(number{}, number<1>{})); @@ -864,59 +864,61 @@ struct GroupedConvolutionBackwardWeightKernel EpiloguePipeline{}.template operator()( c_block_window, c_block_tile, d_block_window, smem_ptr_0); + //constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType_::NumGroupsToMerge; + //constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType_::NumGroupsToMerge; + //Run LDS to global memory manually, one thread per convolution group. - constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType_::NumGroupsToMerge; - constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType_::NumGroupsToMerge; - if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType_::NumGroupsToMerge) - { - const auto group_index = threadIdx.x; - const index_t c_ptr_offset = group_index * MBlockWidth * NBlockWidth; - OutDataType* lds_data = reinterpret_cast(smem_ptr_0); - for (auto i_loc = 0; i_loc < NBlockWidth; ++i_loc) - { - const auto lds_index = (group_index * NBlockWidth + i_loc) * TilePartitioner::MPerBlock + group_index; - c_ptr[c_ptr_offset + i_loc] = lds_data[lds_index]; - } - } - - // __syncthreads(); - // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) + // if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType_::NumGroupsToMerge) // { - // // Print out LDS contents. - // // The LDS corresponds TilePartitioner_::MPerBlock * TilePartitioner_::NPerBlock matrix. - // // Print LDS contents as matrix - // printf("LDS Contents (%d x %d):\n", TilePartitioner::MPerBlock, TilePartitioner::NPerBlock); + // const auto group_index = threadIdx.x; + // const index_t c_ptr_offset = group_index * MBlockWidth * NBlockWidth; // OutDataType* lds_data = reinterpret_cast(smem_ptr_0); - - // // Print LDS as a grid with each row being 16 elements wide - // const int total_rows = TilePartitioner::MPerBlock / MBlockWidth; - // const int cols_per_row = TilePartitioner::NPerBlock / NBlockWidth; - - // for(int row = 0; row < total_rows; ++row) { - // printf("Block %d:\n", row); - // for(int col_block = 0; col_block < cols_per_row; ++col_block) { - // printf("Row %d: ", col_block); - // for(int elem = 0; elem < NBlockWidth; ++elem) { - // int idx = (col_block * NBlockWidth + elem) * TilePartitioner::MPerBlock + row * MBlockWidth; - // printf("%.7f ", static_cast(lds_data[idx])); - // } - // printf(" \n"); - // } - // printf("\n\n"); - // } - - // // Print out the c_block_window contents for debugging - // printf("C Ptr Contents (%d x %d):\n", TilePartitioner::MPerBlock, NBlockWidth); - // for(int m = 0; m < TilePartitioner::MPerBlock; ++m) { - // for(int n = 0; n < NBlockWidth; ++n) { - // int idx = m * NBlockWidth + n; - // printf("%.7f ", static_cast(c_ptr[idx])); - // if((n + 1) % NBlockWidth == 0) printf("\n "); // Line break every NBlockWidth elements for readability - // } - // printf("\n"); + // for (auto i_loc = 0; i_loc < NBlockWidth; ++i_loc) + // { + // const auto lds_index = (group_index * NBlockWidth + i_loc) * TilePartitioner::MPerBlock + group_index; + // c_ptr[c_ptr_offset + i_loc] = lds_data[lds_index]; // } // } - // __syncthreads(); + //transfer_lds_to_global(c_ptr, smem_ptr_0); + + __syncthreads(); + if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) + { + constexpr index_t Gs = GroupedConvTraitsType_::NumGroupsToMerge; + constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / Gs; + + // Print out LDS contents. + // The LDS corresponds TilePartitioner_::MPerBlock * TilePartitioner_::NPerBlock matrix. + // Print LDS contents as matrix + printf("LDS Contents (%d x %d):\n", TilePartitioner::MPerBlock, TilePartitioner::NPerBlock); + OutDataType* lds_data = reinterpret_cast(smem_ptr_0); + + for(int c = 0; c < Gs; ++c) { + printf("Block %d:\n", c); + for(int r = 0; r < Gs; ++r) { + printf("Row %d: ", r); + for(int n = 0; n < NBlockWidth; ++n) + { + int idx = (r * NBlockWidth + n) * TilePartitioner::MPerBlock + c; + printf("%.7f ", static_cast(lds_data[idx])); + } + printf(" \n"); + } + printf("\n\n"); + } + + // Print out the c_block_window contents for debugging + printf("C Ptr Contents (%d x %d):\n", TilePartitioner::MPerBlock, NBlockWidth); + for(int m = 0; m < TilePartitioner::MPerBlock; ++m) { + for(int n = 0; n < NBlockWidth; ++n) { + int idx = m * NBlockWidth + n; + printf("%.7f ", static_cast(c_ptr[idx])); + if((n + 1) % NBlockWidth == 0) printf("\n "); // Line break every NBlockWidth elements for readability + } + printf("\n"); + } + } + __syncthreads(); } /**