[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder

* hacky poc

* fix tile engine

* kill the lambda

* maybe fix test build

* more fixes

* clang-format

* save temp

* clang-format

* mostly fix examples

* clang-format

* remove dead code

* more cleanup

* fix fmha bwd build (default epilogue set/add appears to be broken)

* fix default epilogue tests but not correctness

* clang-format

* fix bquant

* clang-format

* cleanup dead code

* rearrange make windows for readability

* restore changes to IsSupportedArgument

* fix smoke-builder

* clang-format

* fixup rename class

* build fixes

* clang-format

* fix builder

* fixup

* remove set from builder tests

* fix test

* clang-format

* re-refactor the kernels

* clang-format

* fix header license

* remove memory operation from conv bwd test

* clang-format

* clang-format example,include

* clang-format test

* build fixes

* clang-format

* solve compilation error

* fix the CI

* solve compilation error

* clang format

* solve merge conflict

* solve merge conflict

* solve the gfx11 error

* solve test error

* moar build fixes

* remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Max Podkorytov
2026-01-04 03:28:14 -08:00
committed by GitHub
parent ec23be0b9d
commit e339101e9c
68 changed files with 4198 additions and 4298 deletions

View File

@@ -558,21 +558,19 @@ struct FlatmmKernel
return DTesnorIsValid;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t block_idx_m)
{
// Step 1: Create tensor view
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.M, k_size),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
@@ -581,25 +579,81 @@ struct FlatmmKernel
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(k_size, kargs.M),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
// Step 2: Create padded view
const auto& a_pad_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();
// Step 3: Create tile window
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{block_idx_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, block_idx_m});
}
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
const KernelArgs& kargs,
const index_t block_idx_n)
{
// Step 1: Create tensor view
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
// Step 2: No padding needed for b_flat
// Step 3: Create tile window
return make_tile_window(
b_flat_tensor_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const KernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor views
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
@@ -625,7 +679,56 @@ struct FlatmmKernel
},
number<NumDTensor>{});
// TODO: enable vector write for C in ColMajor
// Step 2: Create padded views
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// Step 3: Create tile windows
return generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{block_idx_m, block_idx_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{block_idx_n, block_idx_m});
}
},
number<NumDTensor>{});
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
const KernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor view
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
@@ -647,98 +750,8 @@ struct FlatmmKernel
}
}();
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
: 1; // per-token scale
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
: 1; // per-channel scale
static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
"only support per-tensor or per-row scaling");
static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
"only support per-tensor or per-column scaling");
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m_ptr.ptr,
make_tuple(kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0
? 1
: splitk_batch_offset.splitted_k /
(ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)),
make_tuple(scale_stride_m, 0),
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
number<1>{});
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_n_ptr.ptr,
make_tuple(ScaleGranularityKB == 0
? 1
: (splitk_batch_offset.splitted_k /
(ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)),
kargs.N / ScaleGranularityN),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<1>{});
return make_tuple(a_tensor_view,
b_flat_tensor_view,
ds_tensor_view,
e_tensor_view,
scale_m_view,
scale_n_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();
const auto& b_flat_tensor_view = views.at(I1);
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
const auto& d_tensor_view = views.at(I2);
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// TODO vector write in for C in ColMajor
// Step 2: Create padded view
const auto& e_pad_view = [&]() {
const auto& e_tensor_view = views.at(I3);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(e_tensor_view,
@@ -755,93 +768,72 @@ struct FlatmmKernel
}
}();
return make_tuple(a_pad_view,
b_flat_tensor_view,
ds_pad_view,
e_pad_view,
views.at(number<4>{}),
views.at(number<5>{}));
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& b_flat_pad_view = views.at(I1);
const auto& ds_pad_view = views.at(I2);
const auto& e_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
const auto ds_block_window = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{i_n, i_m});
}
},
number<NumDTensor>{});
auto e_block_window = make_tile_window(
// Step 3: Create tile window
return make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
{block_idx_m, block_idx_n});
}
constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeScaleMWindow(const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m)
{
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
auto scale_m_window = make_tile_window(views.at(number<4>{}),
make_tuple(number<TilePartitioner::MPerBlock>{},
number < ScaleGranularityKA == 0
? TilePartitioner::NPerBlock
: TilePartitioner::KPerBlock > {}),
{i_m, 0});
auto scale_n_window = make_tile_window(views.at(number<5>{}),
make_tuple(number < ScaleGranularityKB == 0
? TilePartitioner::MPerBlock
: TilePartitioner::KPerBlock > {},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
: 1; // per-token scale
return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_m_window,
scale_n_window);
// Step 1: Create tensor view
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m_ptr.ptr,
make_tuple(kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0
? 1
: (splitk_batch_offset.splitted_k / ScaleGranularityKA)),
make_tuple(scale_stride_m, 0),
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
number<1>{});
// Step 2: Create tile window
return make_tile_window(scale_m_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number < ScaleGranularityKA == 0
? TilePartitioner::NPerBlock
: TilePartitioner::KPerBlock > {}),
{block_idx_m, 0});
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeScaleNWindow(const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_n)
{
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
: 1; // per-channel scale
// Step 1: Create tensor view
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_n_ptr.ptr,
make_tuple(
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
kargs.N / ScaleGranularityN),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<1>{});
// Step 2: Create tile window
return make_tile_window(scale_n_view,
make_tuple(number < ScaleGranularityKB == 0
? TilePartitioner::MPerBlock
: TilePartitioner::KPerBlock > {},
number<TilePartitioner::NPerBlock>{}),
{0, block_idx_n});
}
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
@@ -857,45 +849,74 @@ struct FlatmmKernel
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
auto scale_m_window = gemm_tile_windows.at(number<4>{});
auto scale_n_window = gemm_tile_windows.at(number<5>{});
// Run Epilogue Pipeline
// Run Epilogue Pipeline with k_batch dispatching
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
scale_m_window,
scale_n_window);
if(kargs.k_batch == 1)
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(e_block_window,
c_block_tile,
ds_block_window,
smem_ptr_ping,
scale_m_window,
scale_n_window);
}
else
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(e_block_window,
c_block_tile,
ds_block_window,
smem_ptr_ping,
scale_m_window,
scale_n_window);
}
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
if(kargs.k_batch == 1)
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(
e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
}
else
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(
e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
}
}
}
@@ -924,8 +945,7 @@ struct FlatmmKernel
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);