mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Fix case k > 1 and c=1.
This commit is contained in:
@@ -274,27 +274,51 @@ struct CShuffleEpilogue
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
// Debug helper: prints the elements of a distributed tensor held by the calling thread.
// Sweeps distributed span 0 (outer) and span 1 (inner) of `tensor` and emits one printf
// line per visited element, tagged with threadIdx.x.
// The second parameter (a label) is unnamed and unused; the header printf that would
// have consumed it is commented out below.
template <typename DataType, typename StaticTileDistribution>
|
||||
CK_TILE_DEVICE void print_tensor_matrix_format(
|
||||
const static_distributed_tensor<DataType, StaticTileDistribution>& tensor,
|
||||
const char* /*name = "tensor_matrix"*/)
|
||||
{
|
||||
const auto spans = tensor.get_distributed_spans();
|
||||
// The 2D-only check is disabled; only spans 0 and 1 are read below.
//static_assert(spans.size() == 2, "This function is for 2D tensors only");
|
||||
|
||||
const auto dim0_span = spans[number<0>{}];
|
||||
const auto dim1_span = spans[number<1>{}];
|
||||
|
||||
//printf("%s matrix format (tid %u):\n", name, threadIdx.x);
|
||||
|
||||
// Outer sweep over span 0; `row` is the distributed index along that span.
sweep_tile_span(dim0_span, [&](auto row) {
|
||||
// Blank prefix emitted once per outer step, before the per-element lines.
printf(" ");
|
||||
// Inner sweep over span 1; `col` is the distributed index along that span.
sweep_tile_span(dim1_span, [&](auto col) {
|
||||
constexpr auto distributed_indices = make_tuple(row, col);
|
||||
const auto value = tensor[distributed_indices];
|
||||
// The element is converted to float and printed with 7 decimal places.
printf("tid %u: %.7f\n", threadIdx.x, static_cast<float>(value));
|
||||
});
|
||||
//printf("\n");
|
||||
});
|
||||
//printf("\n");
|
||||
}
|
||||
|
||||
template <index_t MPerGroup, index_t NPerGroup, typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
//template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
|
||||
{
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
if constexpr (NumGroupsToMerge == 1)
|
||||
{
|
||||
// When NumGroupsToMerge > 1, we want to write out only the diagonal blocks.
|
||||
// Hence, we configure the shuffle such that it iterates one merge group block at a time.
|
||||
return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
// When NumGroupsToMerge == 1, we want to write out all the blocks.
|
||||
return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
return merged_op<MPerGroup, NPerGroup>(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
|
||||
//return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
template <index_t MPerGroup, index_t NPerGroup, typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto merged_op(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
@@ -359,8 +383,8 @@ struct CShuffleEpilogue
|
||||
block_sync_lds();
|
||||
|
||||
constexpr index_t Gs = NumGroupsToMerge;
|
||||
constexpr index_t MPerGroup = kMPerBlock / Gs;
|
||||
constexpr index_t NPerGroup = kNPerBlock / Gs;
|
||||
constexpr index_t MBlockWidth = kMPerBlock / Gs;
|
||||
constexpr index_t NBlockWidth = kNPerBlock / Gs;
|
||||
|
||||
// Tile encoding for a single group (diagonal block in LDS)
|
||||
constexpr auto dram_tile_encoding = tile_distribution_encoding<
|
||||
@@ -374,20 +398,31 @@ struct CShuffleEpilogue
|
||||
constexpr auto dram_tile_distribution = make_static_tile_distribution(dram_tile_encoding);
|
||||
|
||||
// The LDS data has the following 4D layout in the row-major case.
|
||||
// linear_index = c + Gs * m + Gs * MPerGroup * n + Gs * MPerGroup * NPerGroup * r
|
||||
// linear_index = c + Gs * n + Gs * NBlockWidth * m + Gs * MBlockWidth * NBlockWidth * r
|
||||
// for 4D coordinates (r,c,m,n) where (r,c) is the group index and (m,n) is the index within the group.
|
||||
// Within the sub-block, we have row-major layout (n is the faster index).
|
||||
// We pick-up only the diagonal blocks where r == c.
|
||||
// For each block, the tile distribution and the tensor descriptors are the same.
|
||||
// The only thing that changes is the p_smem offset.
|
||||
constexpr auto lds_block_desc_2d = make_naive_tensor_descriptor(
|
||||
make_tuple(number<MPerGroup>{}, number<NPerGroup>{}),
|
||||
make_tuple(number<Gs>{}, number<Gs * MPerGroup>{}));
|
||||
make_tuple(number<MBlockWidth>{}, number<NBlockWidth>{}),
|
||||
make_tuple(number<Gs * NBlockWidth>{}, number<Gs>{}));
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
printf("""CShuffleEpilogue::merged_op(): MPerGroup=%d, NPerGroup=%d, Gs=%d, MBlockWidth=%d, NBlockWidth=%d\n",
|
||||
static_cast<int>(MPerGroup), static_cast<int>(NPerGroup), static_cast<int>(Gs),
|
||||
static_cast<int>(MBlockWidth), static_cast<int>(NBlockWidth));
|
||||
}
|
||||
|
||||
// Loop over the groups (diagonal blocks in LDS)
|
||||
static_for<0, Gs, 1>{}([&](auto g) {
|
||||
block_sync_lds();
|
||||
|
||||
constexpr index_t group_offset = g * (1 + Gs* MPerGroup * NPerGroup);
|
||||
// Offset to the single diagonal block of LDS.
|
||||
// This block may have more elements than the actual output group contains
|
||||
// because we have MPerGroup <= MBlockWidth and NPerGroup <= NBlockWidth.
|
||||
constexpr index_t group_offset = g * (1 + Gs* MBlockWidth * NBlockWidth * MPerGroup);
|
||||
auto lds_view = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem) + group_offset, lds_block_desc_2d);
|
||||
|
||||
@@ -405,6 +440,10 @@ struct CShuffleEpilogue
|
||||
|
||||
auto c_out_tensor = load_tile(lds_window);
|
||||
|
||||
// DEBUG: Print out the c_out_tensor contents for debugging
|
||||
//print_tensor_matrix_format(c_out_tensor, "c_out_tensor");
|
||||
//__syncthreads();
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
|
||||
|
||||
@@ -99,12 +99,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
ZYX = conv_to_gemm_transformer.ZYX_;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch
|
||||
<< ", ZYX: " << ZYX << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,12 +187,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
ZYX = conv_to_gemm_transformer.ZYX_;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch
|
||||
<< ", ZYX: " << ZYX << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -278,12 +282,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
ZYX = conv_to_gemm_transformer.ZYX_;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch
|
||||
<< ", ZYX: " << ZYX << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,6 +316,7 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
index_t GemmK;
|
||||
index_t GemmBatch;
|
||||
index_t NumGroupsPerBatch;
|
||||
index_t ZYX;
|
||||
|
||||
const void* out_ptr;
|
||||
const void* in_ptr;
|
||||
@@ -797,8 +804,114 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
constexpr index_t Gs = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
|
||||
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
|
||||
const index_t ZYX = kargs.ZYX;
|
||||
const index_t MPerGroup = ConvK;
|
||||
const index_t NPerGroup = ZYX * ConvC;
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
printf("MPerGroup: %d, NPerGroup: %d \n", MPerGroup, NPerGroup);
|
||||
}
|
||||
|
||||
// Check that MPerGroup and NPerGroup map to the existing options
|
||||
if (MPerGroup == 1 && NPerGroup == 16)
|
||||
{
|
||||
EpiloguePipeline{}.template operator()<1, 16, decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else if (MPerGroup == 2 && NPerGroup == 16)
|
||||
{
|
||||
EpiloguePipeline{}.template operator()<2, 16, decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else if (MPerGroup == 1 && NPerGroup == 32)
|
||||
{
|
||||
EpiloguePipeline{}.template operator()<1, 32, decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else if (MPerGroup == 2 && NPerGroup == 32)
|
||||
{
|
||||
EpiloguePipeline{}.template operator()<2, 32, decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else if (MPerGroup == 1 && NPerGroup == 4)
|
||||
{
|
||||
EpiloguePipeline{}.template operator()<1, 4, decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else if (MPerGroup == 1 && NPerGroup == 8)
|
||||
{
|
||||
EpiloguePipeline{}.template operator()<1, 8, decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else if (MPerGroup == 2 && NPerGroup == 4)
|
||||
{
|
||||
EpiloguePipeline{}.template operator()<2, 4, decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else if (MPerGroup == 2 && NPerGroup == 8)
|
||||
{
|
||||
EpiloguePipeline{}.template operator()<2, 8, decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
auto print_lds_per_wg = [&]()
|
||||
{
|
||||
// Print out LDS contents.
|
||||
// The LDS corresponds to a TilePartitioner_::MPerBlock x TilePartitioner_::NPerBlock matrix.
|
||||
// Print LDS contents as matrix
|
||||
printf("LDS Contents (%d x %d) for thread block %u:\n", TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, blockIdx.x);
|
||||
OutDataType* lds_data = reinterpret_cast<OutDataType*>(smem_ptr_0);
|
||||
|
||||
const int MBlockWidth = TilePartitioner::MPerBlock / Gs;
|
||||
const int NBlockWidth = TilePartitioner::NPerBlock / Gs;
|
||||
const int Ncols = Gs;
|
||||
const int Nrows = Gs;
|
||||
|
||||
for(int c = 0; c < Ncols; ++c) {
|
||||
printf("Block %d:\n", c);
|
||||
for(int r = 0; r < Nrows; ++r) {
|
||||
printf("Row %d:\n", r);
|
||||
for (int m = 0; m < MBlockWidth; ++m)
|
||||
{
|
||||
for(int n = 0; n < NBlockWidth; ++n)
|
||||
{
|
||||
int idx = c + Ncols * n + Ncols * NBlockWidth * m + Ncols * MBlockWidth * NBlockWidth * r;
|
||||
printf("%.7f ", static_cast<float>(lds_data[idx]));
|
||||
}
|
||||
printf(" \n");
|
||||
}
|
||||
}
|
||||
printf("\n\n");
|
||||
}
|
||||
|
||||
|
||||
// Print out the c_block_window contents for debugging
|
||||
printf("C Ptr Contents (%d x %d):\n", TilePartitioner::MPerBlock, NPerGroup);
|
||||
for(int m = 0; m < TilePartitioner::MPerBlock; ++m) {
|
||||
for(int n = 0; n < NPerGroup; ++n) {
|
||||
int idx = m * NPerGroup + n;
|
||||
printf("%.7f ", static_cast<float>(c_ptr[idx]));
|
||||
if((n + 1) % NPerGroup == 0) printf("\n "); // Line break every NBlockWidth elements for readability
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
};
|
||||
|
||||
__syncthreads();
|
||||
if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0)
|
||||
{
|
||||
print_lds_per_wg();
|
||||
}
|
||||
__syncthreads();
|
||||
// if (blockIdx.x == 1 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0)
|
||||
// {
|
||||
// print_lds_per_wg();
|
||||
// }
|
||||
// __syncthreads();
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user