Fix case k > 1 and c=1.

This commit is contained in:
Ville Pietilä
2025-09-29 16:02:00 +00:00
parent 558054eadb
commit 193907fd85
2 changed files with 173 additions and 21 deletions

View File

@@ -274,27 +274,51 @@ struct CShuffleEpilogue
}
}
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
/// Debug helper that prints the elements of a distributed tensor with printf.
///
/// Sweeps distributed span 0 and, nested inside it, distributed span 1 of
/// @p tensor. Each visited element is converted to float and printed on its
/// own line, prefixed with threadIdx.x. A single space is printed before
/// every sweep of span 1.
///
/// @param tensor Distributed tensor to dump; only spans 0 and 1 are visited.
/// The second (unnamed) parameter is a label that is currently unused.
template <typename DataType, typename StaticTileDistribution>
CK_TILE_DEVICE void print_tensor_matrix_format(
    const static_distributed_tensor<DataType, StaticTileDistribution>& tensor,
    const char* /*name = "tensor_matrix"*/)
{
    const auto tile_spans = tensor.get_distributed_spans();

    sweep_tile_span(tile_spans[number<0>{}], [&](auto outer_idx) {
        printf(" ");
        sweep_tile_span(tile_spans[number<1>{}], [&](auto inner_idx) {
            // The sweep indices are usable in a constant expression, so the
            // combined element index can be formed at compile time.
            constexpr auto element_idx = make_tuple(outer_idx, inner_idx);
            const auto element         = tensor[element_idx];
            printf("tid %u: %.7f\n", threadIdx.x, static_cast<float>(element));
        });
    });
}
/// Epilogue entry point: forwards to the merged or unmerged write-out path
/// depending on the compile-time value of NumGroupsToMerge.
///
/// @tparam MPerGroup Forwarded to merged_op as its first template argument.
/// @tparam NPerGroup Forwarded to merged_op as its second template argument.
/// @param out_dram_window Output window, forwarded unchanged.
/// @param o_acc_tile      Accumulator tile, forwarded unchanged.
/// @param ds_dram_windows D-tensor windows, forwarded unchanged.
/// @param p_smem          Shared-memory scratch pointer, forwarded unchanged.
/// @return Whatever the selected implementation returns.
template <index_t MPerGroup, index_t NPerGroup, typename ODramWindow, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
                               const OAccTile& o_acc_tile,
                               const DsDramWindows& ds_dram_windows,
                               void* p_smem)
{
    if constexpr(NumGroupsToMerge == 1)
    {
        // When NumGroupsToMerge == 1, we want to write out all the blocks.
        return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
    }
    else
    {
        // When NumGroupsToMerge > 1, we want to write out only the diagonal blocks.
        // Hence, the shuffle iterates one merge group block at a time; the per-group
        // extents are passed on explicitly because merged_op cannot deduce them.
        return merged_op<MPerGroup, NPerGroup>(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
    }
}
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
template <index_t MPerGroup, index_t NPerGroup, typename ODramWindow, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto merged_op(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
@@ -359,8 +383,8 @@ struct CShuffleEpilogue
block_sync_lds();
constexpr index_t Gs = NumGroupsToMerge;
constexpr index_t MPerGroup = kMPerBlock / Gs;
constexpr index_t NPerGroup = kNPerBlock / Gs;
constexpr index_t MBlockWidth = kMPerBlock / Gs;
constexpr index_t NBlockWidth = kNPerBlock / Gs;
// Tile encoding for a single group (diagonal block in LDS)
constexpr auto dram_tile_encoding = tile_distribution_encoding<
@@ -374,20 +398,31 @@ struct CShuffleEpilogue
constexpr auto dram_tile_distribution = make_static_tile_distribution(dram_tile_encoding);
// The LDS data has the following 4D layout in the row-major case.
// linear_index = c + Gs * m + Gs * MPerGroup * n + Gs * MPerGroup * NPerGroup * r
// linear_index = c + Gs * n + Gs * NBlockWidth * m + Gs * MBlockWidth * NBlockWidth * r
// for 4D coordinates (r,c,m,n) where (r,c) is the group index and (m,n) is the index within the group.
// Within the sub-block, we have row-major layout (n is the faster index).
// We pick-up only the diagonal blocks where r == c.
// For each block, the tile distribution and the tensor descriptors are the same.
// The only thing that changes is the p_smem offset.
constexpr auto lds_block_desc_2d = make_naive_tensor_descriptor(
make_tuple(number<MPerGroup>{}, number<NPerGroup>{}),
make_tuple(number<Gs>{}, number<Gs * MPerGroup>{}));
make_tuple(number<MBlockWidth>{}, number<NBlockWidth>{}),
make_tuple(number<Gs * NBlockWidth>{}, number<Gs>{}));
if (threadIdx.x == 0)
{
printf("""CShuffleEpilogue::merged_op(): MPerGroup=%d, NPerGroup=%d, Gs=%d, MBlockWidth=%d, NBlockWidth=%d\n",
static_cast<int>(MPerGroup), static_cast<int>(NPerGroup), static_cast<int>(Gs),
static_cast<int>(MBlockWidth), static_cast<int>(NBlockWidth));
}
// Loop over the groups (diagonal blocks in LDS)
static_for<0, Gs, 1>{}([&](auto g) {
block_sync_lds();
constexpr index_t group_offset = g * (1 + Gs* MPerGroup * NPerGroup);
// Offset to the g-th diagonal block of LDS.
// This block may have more elements than the actual output group contains
// because we have MPerGroup <= MBlockWidth and NPerGroup <= NBlockWidth.
constexpr index_t group_offset = g * (1 + Gs* MBlockWidth * NBlockWidth * MPerGroup);
auto lds_view = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem) + group_offset, lds_block_desc_2d);
@@ -405,6 +440,10 @@ struct CShuffleEpilogue
auto c_out_tensor = load_tile(lds_window);
// DEBUG: Print out the c_out_tensor contents for debugging
//print_tensor_matrix_format(c_out_tensor, "c_out_tensor");
//__syncthreads();
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});

View File

@@ -99,12 +99,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
ZYX = conv_to_gemm_transformer.ZYX_;
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch
<< ", ZYX: " << ZYX << std::endl;
}
}
@@ -185,12 +187,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
ZYX = conv_to_gemm_transformer.ZYX_;
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch
<< ", ZYX: " << ZYX << std::endl;
}
}
@@ -278,12 +282,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
ZYX = conv_to_gemm_transformer.ZYX_;
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch
<< ", ZYX: " << ZYX << std::endl;
}
}
@@ -310,6 +316,7 @@ struct GroupedConvBwdWeightKernelArgs
index_t GemmK;
index_t GemmBatch;
index_t NumGroupsPerBatch;
index_t ZYX;
const void* out_ptr;
const void* in_ptr;
@@ -797,8 +804,114 @@ struct GroupedConvolutionBackwardWeightKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
constexpr index_t Gs = GroupedConvTraitsType_::NumGroupsToMerge;
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
const index_t ZYX = kargs.ZYX;
const index_t MPerGroup = ConvK;
const index_t NPerGroup = ZYX * ConvC;
if (threadIdx.x == 0)
{
printf("MPerGroup: %d, NPerGroup: %d \n", MPerGroup, NPerGroup);
}
// Check that MPerGroup and NPerGroup map to the existing options
if (MPerGroup == 1 && NPerGroup == 16)
{
EpiloguePipeline{}.template operator()<1, 16, decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else if (MPerGroup == 2 && NPerGroup == 16)
{
EpiloguePipeline{}.template operator()<2, 16, decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else if (MPerGroup == 1 && NPerGroup == 32)
{
EpiloguePipeline{}.template operator()<1, 32, decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else if (MPerGroup == 2 && NPerGroup == 32)
{
EpiloguePipeline{}.template operator()<2, 32, decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else if (MPerGroup == 1 && NPerGroup == 4)
{
EpiloguePipeline{}.template operator()<1, 4, decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else if (MPerGroup == 1 && NPerGroup == 8)
{
EpiloguePipeline{}.template operator()<1, 8, decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else if (MPerGroup == 2 && NPerGroup == 4)
{
EpiloguePipeline{}.template operator()<2, 4, decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else if (MPerGroup == 2 && NPerGroup == 8)
{
EpiloguePipeline{}.template operator()<2, 8, decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
auto print_lds_per_wg = [&]()
{
// Print out LDS contents.
// The LDS corresponds TilePartitioner_::MPerBlock * TilePartitioner_::NPerBlock matrix.
// Print LDS contents as matrix
printf("LDS Contents (%d x %d) for thread block %u:\n", TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, blockIdx.x);
OutDataType* lds_data = reinterpret_cast<OutDataType*>(smem_ptr_0);
const int MBlockWidth = TilePartitioner::MPerBlock / Gs;
const int NBlockWidth = TilePartitioner::NPerBlock / Gs;
const int Ncols = Gs;
const int Nrows = Gs;
for(int c = 0; c < Ncols; ++c) {
printf("Block %d:\n", c);
for(int r = 0; r < Nrows; ++r) {
printf("Row %d:\n", r);
for (int m = 0; m < MBlockWidth; ++m)
{
for(int n = 0; n < NBlockWidth; ++n)
{
int idx = c + Ncols * n + Ncols * NBlockWidth * m + Ncols * MBlockWidth * NBlockWidth * r;
printf("%.7f ", static_cast<float>(lds_data[idx]));
}
printf(" \n");
}
}
printf("\n\n");
}
// Print out the c_block_window contents for debugging
printf("C Ptr Contents (%d x %d):\n", TilePartitioner::MPerBlock, NPerGroup);
for(int m = 0; m < TilePartitioner::MPerBlock; ++m) {
for(int n = 0; n < NPerGroup; ++n) {
int idx = m * NPerGroup + n;
printf("%.7f ", static_cast<float>(c_ptr[idx]));
if((n + 1) % NPerGroup == 0) printf("\n "); // Line break every NBlockWidth elements for readability
}
printf("\n");
}
};
__syncthreads();
if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0)
{
print_lds_per_wg();
}
__syncthreads();
// if (blockIdx.x == 1 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0)
// {
// print_lds_per_wg();
// }
// __syncthreads();
}
/**