mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
WIP: LDS to global mem transfer using CK tile tensor descriptor and tile distribution encoding.
This commit is contained in:
@@ -313,9 +313,9 @@ struct CShuffleEpilogue
|
||||
}
|
||||
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto merged_op(ODramWindow&, //out_dram_window,
|
||||
CK_TILE_DEVICE auto merged_op(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows&, //ds_dram_windows,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
@@ -383,13 +383,12 @@ struct CShuffleEpilogue
|
||||
});
|
||||
block_sync_lds();
|
||||
|
||||
#if 0
|
||||
// Set-up LDS to global memory copy.
|
||||
// We copy only the diagonal blocks from LDS to global memory.
|
||||
// Hence, we must configure the tile distrinbution
|
||||
// on the group size blocks.
|
||||
constexpr index_t MPerGroup = kMPerBlock / NumGroupsToMerge;
|
||||
constexpr index_t NPerGroup = kNPerBlock / NumGroupsToMerge;
|
||||
// constexpr index_t MPerGroup = kMPerBlock / NumGroupsToMerge;
|
||||
// constexpr index_t NPerGroup = kNPerBlock / NumGroupsToMerge;
|
||||
//constexpr index_t NumThreads = get_total_number_of_threads(LdsTileDistr.get_static_tile_distribution_encoding());
|
||||
|
||||
// TODO: Remove this debug assert
|
||||
@@ -412,10 +411,49 @@ struct CShuffleEpilogue
|
||||
// constexpr index_t NumThreadsDram = get_total_number_of_threads(dram_tile_encoding);
|
||||
// static_assert(NumThreadsDram == NumThreads, "NumThreadsDram must be equal to NumThreads.");
|
||||
|
||||
// auto lds_window = make_tile_window(
|
||||
// lds_view,
|
||||
// make_tuple(number<MBlockWidth>{}, number<NBlockWidth>{}),
|
||||
// make_tuple(number<0>{}, group_index * NBlockWidth),
|
||||
// sequential_distribution);
|
||||
|
||||
// 4D tensor view of LDS memory
|
||||
// with (g_i, g_j, i, j) where (g_i, g_j) is the group index
|
||||
// and (i, j) is the index within the group.
|
||||
constexpr index_t Gs = NumGroupsToMerge;
|
||||
constexpr index_t MPerGroup = kMPerBlock / Gs;
|
||||
constexpr index_t NPerGroup = kNPerBlock / Gs;
|
||||
constexpr auto lds_desc_4d = make_naive_tensor_descriptor(
|
||||
make_tuple(number<Gs>{}, number<Gs>{}, number<MPerGroup>{}, number<NPerGroup>{}),
|
||||
make_tuple(number<Gs * MPerGroup * NPerGroup>{}, number<1>{}, number<Gs>{}, number<Gs * MPerGroup>{}));
|
||||
|
||||
// We must merge (r,m) and (c,n) dimensions together to make a 2D tensor descriptor.
|
||||
constexpr auto lds_desc = transform_tensor_descriptor(
|
||||
lds_desc_4d,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(Gs, MPerGroup)),
|
||||
make_merge_transform(make_tuple(Gs, NPerGroup))
|
||||
),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{})
|
||||
);
|
||||
|
||||
auto lds_view = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_desc);
|
||||
|
||||
// This is hard-coded for a specific MPerBlock, NPerBlock, NumGroupsToMerge case.
|
||||
// constexpr auto dram_tile_encoding = tile_distribution_encoding<
|
||||
// sequence<>,
|
||||
// tuple<sequence<1, 1, 8, 1>,
|
||||
// sequence<1, 1, 8, 16>>,
|
||||
// tuple<sequence<1,2>, sequence<1,2>>,
|
||||
// tuple<sequence<1,1>, sequence<2,2>>,
|
||||
// sequence<1, 1, 2, 2>,
|
||||
// sequence<0, 3, 0, 3>>{};
|
||||
constexpr auto dram_tile_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<1, 1, 8, 1>,
|
||||
sequence<1, 1, 8, 16>>,
|
||||
tuple<sequence<1, Gs, MPerGroup, 1>,
|
||||
sequence<1, Gs, NPerGroup, 1>>,
|
||||
tuple<sequence<1,2>, sequence<1,2>>,
|
||||
tuple<sequence<1,1>, sequence<2,2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
@@ -429,40 +467,34 @@ struct CShuffleEpilogue
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Calculate which block in the Gs x Gs space we are located at.
|
||||
const auto x_space_coord = dram_tile_distribution.calculate_index();
|
||||
const index_t m_block = x_space_coord[0] / MPerGroup;
|
||||
const index_t n_block = x_space_coord[1] / NPerGroup;
|
||||
|
||||
auto current_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{0, 0},
|
||||
dram_tile_distribution);
|
||||
lds_view,
|
||||
make_tuple(number<Gs>{}, number<Gs>{}, number<MPerGroup>{}, number<NPerGroup>{}),
|
||||
{0, 0, 0, 0},
|
||||
dram_tile_distribution);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// Calculate which block in the Gm x Gm space we are located at.
|
||||
auto get_block_number = [&]() -> ck_tile::tuple<index_t, index_t>
|
||||
{
|
||||
const auto x_space_coord = dram_tile_distribution.calculate_index();
|
||||
const index_t m_block = x_space_coord[0] / MPerGroup;
|
||||
const index_t n_block = x_space_coord[1] / NPerGroup;
|
||||
return make_tuple(m_block, n_block);
|
||||
};
|
||||
|
||||
auto mask = [&]() -> bool
|
||||
{
|
||||
// Return true only for the diagonal blocks.
|
||||
const auto blockId = get_block_number();
|
||||
return blockId[number<0>{}] == blockId[number<1>{}];
|
||||
};
|
||||
|
||||
if (mask())
|
||||
// Copy only the diagonal blocks.
|
||||
if (m_block == n_block)
|
||||
{
|
||||
// Load static_distributed_tensor from LDS.
|
||||
auto c_out_tensor = load_tile(current_lds_window);
|
||||
|
||||
// Debug: Print out the c_out_tensor contents for debugging
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0)
|
||||
{
|
||||
// Print out the c_out_tensor contents for debugging
|
||||
print_tensor_matrix_format(c_out_tensor, "c_out_tensor");
|
||||
}
|
||||
__syncthreads();
|
||||
// End debug
|
||||
|
||||
// TODO: We must move the d_dram_windows and the out_dram_windows to the correct group position.
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
@@ -483,7 +515,6 @@ struct CShuffleEpilogue
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
|
||||
@@ -754,11 +754,11 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
|
||||
}
|
||||
|
||||
// template <typename OutDataType, typename TilePartitioner, typename GroupedConvTraitsType>
|
||||
// CK_TILE_DEVICE void transfer_lds_to_global_simple(OutDataType* c_ptr, void* smem_ptr_0)
|
||||
// template <typename OutDataType>
|
||||
// CK_TILE_DEVICE static void transfer_lds_to_global(OutDataType* c_ptr, void* smem_ptr_0)
|
||||
// {
|
||||
// constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType::NumGroupsToMerge;
|
||||
// constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType::NumGroupsToMerge;
|
||||
// constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
|
||||
// // Create a single-thread tile distribution for sequential access
|
||||
// constexpr auto sequential_encoding = tile_distribution_encoding<
|
||||
@@ -772,18 +772,18 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
// constexpr auto sequential_distribution = make_static_tile_distribution(sequential_encoding);
|
||||
|
||||
// if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType::NumGroupsToMerge) {
|
||||
// if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType_::NumGroupsToMerge) {
|
||||
// const auto group_index = threadIdx.x;
|
||||
|
||||
// // LDS tensor view
|
||||
// // LDS tensor view, column-major ordering
|
||||
// constexpr auto lds_desc = make_naive_tensor_descriptor(
|
||||
// make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
// make_tuple(number<TilePartitioner::NPerBlock>{}, number<1>{}));
|
||||
// make_tuple(number<1>{}, number<TilePartitioner::NPerBlock>{}));
|
||||
|
||||
// auto lds_view = make_tensor_view<address_space_enum::lds>(
|
||||
// static_cast<OutDataType*>(smem_ptr_0), lds_desc);
|
||||
|
||||
// // Global memory tensor view
|
||||
// // Global memory tensor view, row-major ordering
|
||||
// constexpr auto global_desc = make_naive_tensor_descriptor(
|
||||
// make_tuple(number<MBlockWidth>{}, number<NBlockWidth>{}),
|
||||
// make_tuple(number<NBlockWidth>{}, number<1>{}));
|
||||
@@ -864,59 +864,61 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
|
||||
//constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
//constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
|
||||
//Run LDS to global memory manually, one thread per convolution group.
|
||||
constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType_::NumGroupsToMerge)
|
||||
{
|
||||
const auto group_index = threadIdx.x;
|
||||
const index_t c_ptr_offset = group_index * MBlockWidth * NBlockWidth;
|
||||
OutDataType* lds_data = reinterpret_cast<OutDataType*>(smem_ptr_0);
|
||||
for (auto i_loc = 0; i_loc < NBlockWidth; ++i_loc)
|
||||
{
|
||||
const auto lds_index = (group_index * NBlockWidth + i_loc) * TilePartitioner::MPerBlock + group_index;
|
||||
c_ptr[c_ptr_offset + i_loc] = lds_data[lds_index];
|
||||
}
|
||||
}
|
||||
|
||||
// __syncthreads();
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0)
|
||||
// if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType_::NumGroupsToMerge)
|
||||
// {
|
||||
// // Print out LDS contents.
|
||||
// // The LDS corresponds TilePartitioner_::MPerBlock * TilePartitioner_::NPerBlock matrix.
|
||||
// // Print LDS contents as matrix
|
||||
// printf("LDS Contents (%d x %d):\n", TilePartitioner::MPerBlock, TilePartitioner::NPerBlock);
|
||||
// const auto group_index = threadIdx.x;
|
||||
// const index_t c_ptr_offset = group_index * MBlockWidth * NBlockWidth;
|
||||
// OutDataType* lds_data = reinterpret_cast<OutDataType*>(smem_ptr_0);
|
||||
|
||||
// // Print LDS as a grid with each row being 16 elements wide
|
||||
// const int total_rows = TilePartitioner::MPerBlock / MBlockWidth;
|
||||
// const int cols_per_row = TilePartitioner::NPerBlock / NBlockWidth;
|
||||
|
||||
// for(int row = 0; row < total_rows; ++row) {
|
||||
// printf("Block %d:\n", row);
|
||||
// for(int col_block = 0; col_block < cols_per_row; ++col_block) {
|
||||
// printf("Row %d: ", col_block);
|
||||
// for(int elem = 0; elem < NBlockWidth; ++elem) {
|
||||
// int idx = (col_block * NBlockWidth + elem) * TilePartitioner::MPerBlock + row * MBlockWidth;
|
||||
// printf("%.7f ", static_cast<float>(lds_data[idx]));
|
||||
// }
|
||||
// printf(" \n");
|
||||
// }
|
||||
// printf("\n\n");
|
||||
// }
|
||||
|
||||
// // Print out the c_block_window contents for debugging
|
||||
// printf("C Ptr Contents (%d x %d):\n", TilePartitioner::MPerBlock, NBlockWidth);
|
||||
// for(int m = 0; m < TilePartitioner::MPerBlock; ++m) {
|
||||
// for(int n = 0; n < NBlockWidth; ++n) {
|
||||
// int idx = m * NBlockWidth + n;
|
||||
// printf("%.7f ", static_cast<float>(c_ptr[idx]));
|
||||
// if((n + 1) % NBlockWidth == 0) printf("\n "); // Line break every NBlockWidth elements for readability
|
||||
// }
|
||||
// printf("\n");
|
||||
// for (auto i_loc = 0; i_loc < NBlockWidth; ++i_loc)
|
||||
// {
|
||||
// const auto lds_index = (group_index * NBlockWidth + i_loc) * TilePartitioner::MPerBlock + group_index;
|
||||
// c_ptr[c_ptr_offset + i_loc] = lds_data[lds_index];
|
||||
// }
|
||||
// }
|
||||
// __syncthreads();
|
||||
//transfer_lds_to_global(c_ptr, smem_ptr_0);
|
||||
|
||||
__syncthreads();
|
||||
if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0)
|
||||
{
|
||||
constexpr index_t Gs = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / Gs;
|
||||
|
||||
// Print out LDS contents.
|
||||
// The LDS corresponds TilePartitioner_::MPerBlock * TilePartitioner_::NPerBlock matrix.
|
||||
// Print LDS contents as matrix
|
||||
printf("LDS Contents (%d x %d):\n", TilePartitioner::MPerBlock, TilePartitioner::NPerBlock);
|
||||
OutDataType* lds_data = reinterpret_cast<OutDataType*>(smem_ptr_0);
|
||||
|
||||
for(int c = 0; c < Gs; ++c) {
|
||||
printf("Block %d:\n", c);
|
||||
for(int r = 0; r < Gs; ++r) {
|
||||
printf("Row %d: ", r);
|
||||
for(int n = 0; n < NBlockWidth; ++n)
|
||||
{
|
||||
int idx = (r * NBlockWidth + n) * TilePartitioner::MPerBlock + c;
|
||||
printf("%.7f ", static_cast<float>(lds_data[idx]));
|
||||
}
|
||||
printf(" \n");
|
||||
}
|
||||
printf("\n\n");
|
||||
}
|
||||
|
||||
// Print out the c_block_window contents for debugging
|
||||
printf("C Ptr Contents (%d x %d):\n", TilePartitioner::MPerBlock, NBlockWidth);
|
||||
for(int m = 0; m < TilePartitioner::MPerBlock; ++m) {
|
||||
for(int n = 0; n < NBlockWidth; ++n) {
|
||||
int idx = m * NBlockWidth + n;
|
||||
printf("%.7f ", static_cast<float>(c_ptr[idx]));
|
||||
if((n + 1) % NBlockWidth == 0) printf("\n "); // Line break every NBlockWidth elements for readability
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user