[CK-Tile] move the memory operation out of the cshuffle epilogue class (#3359)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder

* hacky poc

* fix tile engine

* kill the lambda

* maybe fix test build

* more fixes

* clang-format

* save temp

* clang-format

* mostly fix examples

* clang-format

* remove dead code

* more cleanup

* fix fmha bwd build (default epilogue set/add appears to be broken)

* fix default epilogue tests but not correctness

* clang-format

* fix bquant

* clang-format

* cleanup dead code

* rearrange make windows for readability

* restore changes to IsSupportedArgument

* fix smoke-builder

* clang-format

* fixup rename class

* build fixes

* clang-format

* fix builder

* fixup

* remove set from builder tests

* fix test

* clang-format

* re-refactor the kernels

* clang-format

* fix header license

* remove memory operation from conv bwd test

* clang-format

* clang-format example,include

* clang-format test

* build fixes

* clang-format

* solve compilation error

* fix the CI

* solve compilation error

* clang-format

* solve merge conflict

* solve merge conflict

* solve the gfx11 error

* solve test error

* moar build fixes

* remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Author: Max Podkorytov (committed by GitHub)
Date: 2026-01-04 03:28:14 -08:00
Parent: ec23be0b9d
Commit: e339101e9c
68 changed files with 4198 additions and 4298 deletions
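Across the kernels below the change follows one pattern: the destination memory operation is no longer a compile-time property of the epilogue pipeline; each kernel now builds its C block window with the store mode chosen at the call site and dispatches on k_batch at runtime (set when k_batch == 1, atomic_add otherwise). Below is a minimal standalone C++17 sketch of that dispatch; memory_operation_enum, MakeCBlockWindow, and Epilogue here are simplified stand-ins for the CK-Tile entities in the diffs, not the real APIs.

// Standalone model of the k_batch dispatch adopted in the kernels below.
// Only the dispatch structure is the point; types are placeholders.
#include <cstdio>

enum class memory_operation_enum { set, atomic_add };

// The destination memory operation is a property of the C block window
// (chosen per call), not of the epilogue pipeline type.
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
struct CBlockWindow
{
    static constexpr memory_operation_enum mem_op = DstInMemOp;
    float* c_ptr;
};

template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CBlockWindow<DstInMemOp> MakeCBlockWindow(float* c_ptr)
{
    return CBlockWindow<DstInMemOp>{c_ptr};
}

struct Epilogue
{
    // The epilogue reads the memory operation from the window it is given.
    template <typename CWindow>
    void operator()(CWindow& c_window, float partial) const
    {
        if constexpr(CWindow::mem_op == memory_operation_enum::set)
            *c_window.c_ptr = partial;   // k_batch == 1: plain store
        else
            *c_window.c_ptr += partial;  // k_batch > 1: accumulate (atomic on device)
    }
};

void RunEpilogue(float* c_ptr, float partial, int k_batch)
{
    // Runtime dispatch replaces the old compile-time
    // EpiloguePipeline::MemoryOperation trait.
    if(k_batch == 1)
    {
        auto c_window = MakeCBlockWindow<memory_operation_enum::set>(c_ptr);
        Epilogue{}(c_window, partial);
    }
    else
    {
        auto c_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(c_ptr);
        Epilogue{}(c_window, partial);
    }
}

int main()
{
    float c = 0.0f;
    RunEpilogue(&c, 1.5f, /*k_batch=*/1); // set
    RunEpilogue(&c, 2.0f, /*k_batch=*/4); // accumulate
    std::printf("c = %f\n", c);           // prints 3.500000
}

The same shape appears in the RunGemm and RunGemm2LDS bodies of each kernel in the diffs: two branches on k_batch, each constructing its own c_block_window with the appropriate memory_operation_enum before invoking the epilogue.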


@@ -617,6 +617,117 @@ struct GroupedConvolutionBackwardDataKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE static auto
MakeABlockWindow(const OutDataType* a_ptr,
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
const index_t group_id,
const index_t i_m,
const index_t i_k)
{
// Step 1: Create tensor view for A (Out tensor)
const auto& a_tensor_view =
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_descs_m_k[group_id]);
// Step 2: Create padded view
const auto& a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<true, true>{});
// Step 3: Create tile window
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_m, i_k});
return a_block_window;
}
CK_TILE_DEVICE static auto
MakeBBlockWindow(const InDataType* b_ptr,
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
const index_t group_id,
const index_t i_n,
const index_t i_k)
{
// Step 1: Create tensor view for B (Weight tensor)
const auto& b_tensor_view =
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_descs_n_k[group_id]);
// Step 2: Create padded view
const auto& b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
// Step 3: Create tile window
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_k, i_n});
return b_block_window;
}
CK_TILE_DEVICE static auto
MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
const index_t group_id,
const index_t i_m,
const index_t i_n)
{
// Create D tensor block windows
const auto ds_block_window = generate_tuple(
[&](auto i) {
// Step 1: Create tensor view for D
const auto& d_tensor_view = make_tensor_view<address_space_enum::global>(
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
// Step 2: Create padded view
const auto& d_pad_view =
pad_tensor_view(d_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
// Step 3: Create tile window
return make_tile_window(d_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
},
number<NumDTensor>{});
return ds_block_window;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto
MakeCBlockWindow(WeiDataType* c_ptr,
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
const index_t group_id,
const index_t i_m,
const index_t i_n)
{
// Step 1: Create tensor view for C (Input tensor)
const auto& c_tensor_view = make_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr, kargs.c_grid_descs_m_n[group_id]);
// Step 2: Create padded view
const auto& c_pad_view = pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
// Step 3: Create tile window
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return c_block_window;
}
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
{
@@ -895,38 +1006,49 @@ struct GroupedConvolutionBackwardDataKernel
const index_t block_idx_k,
const index_t group_id)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k);
const auto& b_block_window =
MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k);
const auto& d_block_window =
MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
auto gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
// Run Epilogue Pipeline with k_batch dispatch
if(k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
* @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
@@ -951,23 +1073,19 @@ struct GroupedConvolutionBackwardDataKernel
const index_t block_idx_k,
const index_t group_id)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k);
const auto& b_block_window =
MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k);
const auto& d_block_window =
MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
auto gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
b_block_window,
num_loop,
@@ -976,11 +1094,27 @@ struct GroupedConvolutionBackwardDataKernel
smem_ptr_0,
smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
// Run Epilogue Pipeline with k_batch dispatch
if(k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
}
CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized& kargs,
@@ -1066,8 +1200,7 @@ struct GroupedConvolutionBackwardDataKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
@@ -1086,8 +1219,7 @@ struct GroupedConvolutionBackwardDataKernel
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,


@@ -518,25 +518,6 @@ struct GroupedConvolutionBackwardWeightKernel
return false;
}
#if defined(__gfx11__)
if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set)
{
return false;
}
#endif
if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add)
{
if(kargs.k_batch == 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1.");
}
return false;
}
}
if constexpr(!std::is_same_v<typename EpiloguePipeline::ODataType, float> &&
!std::is_same_v<typename EpiloguePipeline::ODataType, double>)
{
@@ -704,29 +685,31 @@ struct GroupedConvolutionBackwardWeightKernel
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const OutDataType* a_ptr,
const InDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
WeiDataType* c_ptr,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
MakeCBlockWindow(WeiDataType* c_ptr,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
const auto& a_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(a_ptr,
kargs.a_grid_desc_k_m); // A: out
}();
const auto& c_tensor_view =
make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr, kargs.c_grid_desc_m_n);
const auto& b_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(b_ptr,
kargs.b_grid_desc_k_n); // B: in
}();
const auto& c_pad_view = pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
const auto& c_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr,
kargs.c_grid_desc_m_n);
}();
return make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{block_idx_m, block_idx_n});
}
CK_TILE_DEVICE static auto
MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
@@ -741,30 +724,7 @@ struct GroupedConvolutionBackwardWeightKernel
},
number<NumDTensor>{});
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<true, true>{});
}();
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I1);
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
}();
const auto& ds_tensor_view = views.at(I2);
const auto& ds_pad_view = generate_tuple(
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
@@ -773,67 +733,58 @@ struct GroupedConvolutionBackwardWeightKernel
},
number<NumDTensor>{});
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I3);
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
}();
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
}
/**
* @brief Create views to the data that each workgroup will process.
*
* @param views padded views of A, B, D and C tensors
* @param i_m block m-index
* @param i_n block n-index
* @param i_k block k-index
*
* @return tuple of tile windows for A, B, D and C tensors
*/
template <typename PadView>
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
const index_t i_m,
const index_t i_n,
const index_t i_k)
{
const auto& a_pad_view = views.at(I0);
const auto& b_pad_view = views.at(I1);
const auto& ds_pad_view = views.at(I2);
const auto& c_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{i_k, i_m});
}();
const auto& b_block_window = [&]() {
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_k, i_n});
}();
const auto ds_block_window = generate_tuple(
return generate_tuple(
[&](auto i) {
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
{block_idx_m, block_idx_n});
},
number<NumDTensor>{});
}
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
CK_TILE_DEVICE static auto
MakeBBlockWindow(const InDataType* b_ptr,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const index_t block_idx_n,
const index_t block_idx_k)
{
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
const auto& b_tensor_view =
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_k_n);
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
const auto& b_pad_view =
pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
return make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{block_idx_k, block_idx_n});
}
CK_TILE_DEVICE static auto
MakeABlockWindow(const OutDataType* a_ptr,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const index_t block_idx_m,
const index_t block_idx_k)
{
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view =
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_k_m);
const auto& a_pad_view =
pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
number<TilePartitioner::MPerBlock>{}),
sequence<true, true>{});
return make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
{block_idx_k, block_idx_m});
}
/**
@@ -859,28 +810,30 @@ struct GroupedConvolutionBackwardWeightKernel
const index_t block_idx_n,
const index_t block_idx_k)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
// Create block windows using helper methods
const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k);
const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k);
const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
// Run Epilogue Pipeline with k_batch dispatching
if(kargs.k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
}
/**
@@ -910,27 +863,33 @@ struct GroupedConvolutionBackwardWeightKernel
const index_t block_idx_n,
const index_t block_idx_k)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
// Create block windows using helper methods
const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k);
const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k);
const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
// Run Epilogue Pipeline with k_batch dispatching
if(kargs.k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else
{
#if defined(__gfx11__)
return;
#endif
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
}
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
@@ -960,12 +919,6 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
{
#if defined(__gfx11__)
if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set)
{
return;
}
#endif
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
{
CallExplicitGemm(kargs);
@@ -1001,9 +954,7 @@ struct GroupedConvolutionBackwardWeightKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
@@ -1021,9 +972,7 @@ struct GroupedConvolutionBackwardWeightKernel
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,


@@ -794,34 +794,53 @@ struct GroupedConvolutionForwardKernel
return true;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set,
typename ADescType,
typename BDescType,
typename CDescType>
template <typename ADescType>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const InDataType* a_ptr,
const WeiDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
OutDataType* c_ptr,
const ADescType& a_desc,
const BDescType& b_desc,
const CDescType& c_desc)
MakeABlockWindow(const InDataType* a_ptr, const ADescType& a_desc, const index_t block_idx_m)
{
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
const auto& a_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
}();
// Step 1: Create tensor view
const auto& a_tensor_view = make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
const auto& b_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
}();
// Step 2: Create padded view
const auto& a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<true, true>{});
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(c_ptr, c_desc);
}();
// Step 3: Create tile window
return make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{block_idx_m, 0});
}
template <typename BDescType>
CK_TILE_DEVICE static auto
MakeBBlockWindow(const WeiDataType* b_ptr, const BDescType& b_desc, const index_t block_idx_n)
{
// Step 1: Create tensor view
const auto& b_tensor_view = make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
// Step 2: Create padded view
const auto& b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<true, true>{});
// Step 3: Create tile window
return make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{block_idx_n, 0});
}
template <typename CDescType>
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const CDescType& c_desc,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor views
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
@@ -836,30 +855,8 @@ struct GroupedConvolutionForwardKernel
},
number<NumDTensor>{});
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<true, true>{});
}();
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I1);
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<true, true>{});
}();
const auto& ds_tensor_view = views.at(I2);
const auto& ds_pad_view = generate_tuple(
// Step 2: Create padded views
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
@@ -868,55 +865,38 @@ struct GroupedConvolutionForwardKernel
},
number<NumDTensor>{});
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I3);
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
}();
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& b_pad_view = views.at(I1);
const auto& ds_pad_view = views.at(I2);
const auto& c_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}();
const auto& b_block_window = [&]() {
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}();
const auto ds_block_window = generate_tuple(
// Step 3: Create tile windows
return generate_tuple(
[&](auto i) {
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
{block_idx_m, block_idx_n});
},
number<NumDTensor>{});
}
auto c_block_window = make_tile_window(
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename CDescType>
CK_TILE_DEVICE static auto MakeCBlockWindow(OutDataType* c_ptr,
const CDescType& c_desc,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor view
const auto& c_tensor_view =
make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr, c_desc);
// Step 2: Create padded view
const auto& c_pad_view = pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
// Step 3: Create tile window
return make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
{block_idx_m, block_idx_n});
}
/**
@@ -931,6 +911,7 @@ struct GroupedConvolutionForwardKernel
* @param b_desc Weight tensor B descriptor
* @param c_desc Output tensor C descriptor
* @param gemm_k The GEMM K dimension
* @param k_batch The K batch parameter for split-K
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
@@ -945,34 +926,41 @@ struct GroupedConvolutionForwardKernel
const BDescType& b_desc,
const CDescType& c_desc,
const index_t gemm_k,
const index_t k_batch,
const index_t block_idx_m,
const index_t block_idx_n,
const CDElementwise& elfunc)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Create block windows using specialized methods
const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m);
const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
// Run Epilogue Pipeline with k_batch dispatching
if(k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, c_desc, block_idx_m, block_idx_n);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, c_desc, block_idx_m, block_idx_n);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
}
/**
@@ -990,6 +978,7 @@ struct GroupedConvolutionForwardKernel
* @param b_desc Weight tensor B descriptor
* @param c_desc Output tensor C descriptor
* @param gemm_k The GEMM K dimension
* @param k_batch The K batch parameter for split-K
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
@@ -1005,33 +994,41 @@ struct GroupedConvolutionForwardKernel
const BDescType& b_desc,
const CDescType& c_desc,
const index_t gemm_k,
const index_t k_batch,
const index_t block_idx_m,
const index_t block_idx_n,
const CDElementwise& elfunc)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Create block windows using specialized methods
const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m);
const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
// Run Epilogue Pipeline with k_batch dispatching
if(k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, c_desc, block_idx_m, block_idx_n);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, c_desc, block_idx_m, block_idx_n);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
}
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const
@@ -1185,9 +1182,7 @@ struct GroupedConvolutionForwardKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
@@ -1200,6 +1195,7 @@ struct GroupedConvolutionForwardKernel
b_desc,
c_desc,
kargs.GemmK,
kargs.k_batch,
i_m,
i_n,
kargs.elfunc);
@@ -1207,9 +1203,7 @@ struct GroupedConvolutionForwardKernel
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,
@@ -1221,6 +1215,7 @@ struct GroupedConvolutionForwardKernel
b_desc,
c_desc,
kargs.GemmK,
kargs.k_batch,
i_m,
i_n,
kargs.elfunc);