WIP: LDS to global mem transfer using CK tile tensor descriptor and tile distribution encoding.

This commit is contained in:
Ville Pietilä
2025-09-24 15:08:01 +00:00
parent 7280df1bc3
commit 625a78b17b
2 changed files with 120 additions and 87 deletions

View File

@@ -313,9 +313,9 @@ struct CShuffleEpilogue
}
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto merged_op(ODramWindow&, //out_dram_window,
CK_TILE_DEVICE auto merged_op(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows&, //ds_dram_windows,
const DsDramWindows& ds_dram_windows,
void* p_smem)
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
@@ -383,13 +383,12 @@ struct CShuffleEpilogue
});
block_sync_lds();
#if 0
// Set-up LDS to global memory copy.
// We copy only the diagonal blocks from LDS to global memory.
// Hence, we must configure the tile distrinbution
// on the group size blocks.
constexpr index_t MPerGroup = kMPerBlock / NumGroupsToMerge;
constexpr index_t NPerGroup = kNPerBlock / NumGroupsToMerge;
// constexpr index_t MPerGroup = kMPerBlock / NumGroupsToMerge;
// constexpr index_t NPerGroup = kNPerBlock / NumGroupsToMerge;
//constexpr index_t NumThreads = get_total_number_of_threads(LdsTileDistr.get_static_tile_distribution_encoding());
// TODO: Remove this debug assert
@@ -412,10 +411,49 @@ struct CShuffleEpilogue
// constexpr index_t NumThreadsDram = get_total_number_of_threads(dram_tile_encoding);
// static_assert(NumThreadsDram == NumThreads, "NumThreadsDram must be equal to NumThreads.");
// auto lds_window = make_tile_window(
// lds_view,
// make_tuple(number<MBlockWidth>{}, number<NBlockWidth>{}),
// make_tuple(number<0>{}, group_index * NBlockWidth),
// sequential_distribution);
// 4D tensor view of LDS memory
// with (g_i, g_j, i, j) where (g_i, g_j) is the group index
// and (i, j) is the index within the group.
constexpr index_t Gs = NumGroupsToMerge;
constexpr index_t MPerGroup = kMPerBlock / Gs;
constexpr index_t NPerGroup = kNPerBlock / Gs;
constexpr auto lds_desc_4d = make_naive_tensor_descriptor(
make_tuple(number<Gs>{}, number<Gs>{}, number<MPerGroup>{}, number<NPerGroup>{}),
make_tuple(number<Gs * MPerGroup * NPerGroup>{}, number<1>{}, number<Gs>{}, number<Gs * MPerGroup>{}));
// We must merge (r,m) and (c,n) dimensions together to make a 2D tensor descriptor.
constexpr auto lds_desc = transform_tensor_descriptor(
lds_desc_4d,
make_tuple(
make_merge_transform(make_tuple(Gs, MPerGroup)),
make_merge_transform(make_tuple(Gs, NPerGroup))
),
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{})
);
auto lds_view = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_desc);
// This is hard-coded for a specific MPerBlock, NPerBlock, NumGroupsToMerge case.
// constexpr auto dram_tile_encoding = tile_distribution_encoding<
// sequence<>,
// tuple<sequence<1, 1, 8, 1>,
// sequence<1, 1, 8, 16>>,
// tuple<sequence<1,2>, sequence<1,2>>,
// tuple<sequence<1,1>, sequence<2,2>>,
// sequence<1, 1, 2, 2>,
// sequence<0, 3, 0, 3>>{};
constexpr auto dram_tile_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<1, 1, 8, 1>,
sequence<1, 1, 8, 16>>,
tuple<sequence<1, Gs, MPerGroup, 1>,
sequence<1, Gs, NPerGroup, 1>>,
tuple<sequence<1,2>, sequence<1,2>>,
tuple<sequence<1,1>, sequence<2,2>>,
sequence<1, 1, 2, 2>,
@@ -429,40 +467,34 @@ struct CShuffleEpilogue
},
number<NumDTensor>{});
// Calculate which block in the Gs x Gs space we are located at.
const auto x_space_coord = dram_tile_distribution.calculate_index();
const index_t m_block = x_space_coord[0] / MPerGroup;
const index_t n_block = x_space_coord[1] / NPerGroup;
auto current_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{0, 0},
dram_tile_distribution);
lds_view,
make_tuple(number<Gs>{}, number<Gs>{}, number<MPerGroup>{}, number<NPerGroup>{}),
{0, 0, 0, 0},
dram_tile_distribution);
block_sync_lds();
// Calculate which block in the Gm x Gm space we are located at.
auto get_block_number = [&]() -> ck_tile::tuple<index_t, index_t>
{
const auto x_space_coord = dram_tile_distribution.calculate_index();
const index_t m_block = x_space_coord[0] / MPerGroup;
const index_t n_block = x_space_coord[1] / NPerGroup;
return make_tuple(m_block, n_block);
};
auto mask = [&]() -> bool
{
// Return true only for the diagonal blocks.
const auto blockId = get_block_number();
return blockId[number<0>{}] == blockId[number<1>{}];
};
if (mask())
// Copy only the diagonal blocks.
if (m_block == n_block)
{
// Load static_distributed_tensor from LDS.
auto c_out_tensor = load_tile(current_lds_window);
// Debug: Print out the c_out_tensor contents for debugging
__syncthreads();
if (threadIdx.x == 0 && blockIdx.x == 0)
{
// Print out the c_out_tensor contents for debugging
print_tensor_matrix_format(c_out_tensor, "c_out_tensor");
}
__syncthreads();
// End debug
// TODO: We must move the d_dram_windows and the out_dram_windows to the correct group position.
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
@@ -483,7 +515,6 @@ struct CShuffleEpilogue
update_tile(out_dram_window, c_out_tensor);
}
}
#endif
}
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>

View File

@@ -754,11 +754,11 @@ struct GroupedConvolutionBackwardWeightKernel
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
}
// template <typename OutDataType, typename TilePartitioner, typename GroupedConvTraitsType>
// CK_TILE_DEVICE void transfer_lds_to_global_simple(OutDataType* c_ptr, void* smem_ptr_0)
// template <typename OutDataType>
// CK_TILE_DEVICE static void transfer_lds_to_global(OutDataType* c_ptr, void* smem_ptr_0)
// {
// constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType::NumGroupsToMerge;
// constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType::NumGroupsToMerge;
// constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
// constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
// // Create a single-thread tile distribution for sequential access
// constexpr auto sequential_encoding = tile_distribution_encoding<
@@ -772,18 +772,18 @@ struct GroupedConvolutionBackwardWeightKernel
// constexpr auto sequential_distribution = make_static_tile_distribution(sequential_encoding);
// if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType::NumGroupsToMerge) {
// if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType_::NumGroupsToMerge) {
// const auto group_index = threadIdx.x;
// // LDS tensor view
// // LDS tensor view, column-major ordering
// constexpr auto lds_desc = make_naive_tensor_descriptor(
// make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
// make_tuple(number<TilePartitioner::NPerBlock>{}, number<1>{}));
// make_tuple(number<1>{}, number<TilePartitioner::NPerBlock>{}));
// auto lds_view = make_tensor_view<address_space_enum::lds>(
// static_cast<OutDataType*>(smem_ptr_0), lds_desc);
// // Global memory tensor view
// // Global memory tensor view, row-major ordering
// constexpr auto global_desc = make_naive_tensor_descriptor(
// make_tuple(number<MBlockWidth>{}, number<NBlockWidth>{}),
// make_tuple(number<NBlockWidth>{}, number<1>{}));
@@ -864,59 +864,61 @@ struct GroupedConvolutionBackwardWeightKernel
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
//constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
//constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
//Run LDS to global memory manually, one thread per convolution group.
constexpr index_t MBlockWidth = TilePartitioner::MPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / GroupedConvTraitsType_::NumGroupsToMerge;
if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType_::NumGroupsToMerge)
{
const auto group_index = threadIdx.x;
const index_t c_ptr_offset = group_index * MBlockWidth * NBlockWidth;
OutDataType* lds_data = reinterpret_cast<OutDataType*>(smem_ptr_0);
for (auto i_loc = 0; i_loc < NBlockWidth; ++i_loc)
{
const auto lds_index = (group_index * NBlockWidth + i_loc) * TilePartitioner::MPerBlock + group_index;
c_ptr[c_ptr_offset + i_loc] = lds_data[lds_index];
}
}
// __syncthreads();
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0)
// if (blockIdx.x == 0 && threadIdx.x < GroupedConvTraitsType_::NumGroupsToMerge)
// {
// // Print out LDS contents.
// // The LDS corresponds TilePartitioner_::MPerBlock * TilePartitioner_::NPerBlock matrix.
// // Print LDS contents as matrix
// printf("LDS Contents (%d x %d):\n", TilePartitioner::MPerBlock, TilePartitioner::NPerBlock);
// const auto group_index = threadIdx.x;
// const index_t c_ptr_offset = group_index * MBlockWidth * NBlockWidth;
// OutDataType* lds_data = reinterpret_cast<OutDataType*>(smem_ptr_0);
// // Print LDS as a grid with each row being 16 elements wide
// const int total_rows = TilePartitioner::MPerBlock / MBlockWidth;
// const int cols_per_row = TilePartitioner::NPerBlock / NBlockWidth;
// for(int row = 0; row < total_rows; ++row) {
// printf("Block %d:\n", row);
// for(int col_block = 0; col_block < cols_per_row; ++col_block) {
// printf("Row %d: ", col_block);
// for(int elem = 0; elem < NBlockWidth; ++elem) {
// int idx = (col_block * NBlockWidth + elem) * TilePartitioner::MPerBlock + row * MBlockWidth;
// printf("%.7f ", static_cast<float>(lds_data[idx]));
// }
// printf(" \n");
// }
// printf("\n\n");
// }
// // Print out the c_block_window contents for debugging
// printf("C Ptr Contents (%d x %d):\n", TilePartitioner::MPerBlock, NBlockWidth);
// for(int m = 0; m < TilePartitioner::MPerBlock; ++m) {
// for(int n = 0; n < NBlockWidth; ++n) {
// int idx = m * NBlockWidth + n;
// printf("%.7f ", static_cast<float>(c_ptr[idx]));
// if((n + 1) % NBlockWidth == 0) printf("\n "); // Line break every NBlockWidth elements for readability
// }
// printf("\n");
// for (auto i_loc = 0; i_loc < NBlockWidth; ++i_loc)
// {
// const auto lds_index = (group_index * NBlockWidth + i_loc) * TilePartitioner::MPerBlock + group_index;
// c_ptr[c_ptr_offset + i_loc] = lds_data[lds_index];
// }
// }
// __syncthreads();
//transfer_lds_to_global(c_ptr, smem_ptr_0);
__syncthreads();
if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0)
{
constexpr index_t Gs = GroupedConvTraitsType_::NumGroupsToMerge;
constexpr index_t NBlockWidth = TilePartitioner::NPerBlock / Gs;
// Print out LDS contents.
// The LDS corresponds TilePartitioner_::MPerBlock * TilePartitioner_::NPerBlock matrix.
// Print LDS contents as matrix
printf("LDS Contents (%d x %d):\n", TilePartitioner::MPerBlock, TilePartitioner::NPerBlock);
OutDataType* lds_data = reinterpret_cast<OutDataType*>(smem_ptr_0);
for(int c = 0; c < Gs; ++c) {
printf("Block %d:\n", c);
for(int r = 0; r < Gs; ++r) {
printf("Row %d: ", r);
for(int n = 0; n < NBlockWidth; ++n)
{
int idx = (r * NBlockWidth + n) * TilePartitioner::MPerBlock + c;
printf("%.7f ", static_cast<float>(lds_data[idx]));
}
printf(" \n");
}
printf("\n\n");
}
// Print out the c_block_window contents for debugging
printf("C Ptr Contents (%d x %d):\n", TilePartitioner::MPerBlock, NBlockWidth);
for(int m = 0; m < TilePartitioner::MPerBlock; ++m) {
for(int n = 0; n < NBlockWidth; ++n) {
int idx = m * NBlockWidth + n;
printf("%.7f ", static_cast<float>(c_ptr[idx]));
if((n + 1) % NBlockWidth == 0) printf("\n "); // Line break every NBlockWidth elements for readability
}
printf("\n");
}
}
__syncthreads();
}
/**