[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder

* hacky poc

* fix tile engine

* kill the lambda

* maybe fix test build

* more fixes

* clang-format

* save temp

* clang-format

* mostly fix examples

* clang-format

* remove dead code

* more cleanup

* fix fmha bwd build (default epilogue set/add appears to be broken)

* fix default epilogue tests but not correctness

* clang-format

* fix bquant

* clang-format

* cleanup dead code

* rearrange make windows for readability

* restore changes to IsSupportedArgument

* fix smoke-builder

* clang-format

* fixup rename class

* build fixes

* clang-format

* fix builder

* fixup

* remove set from builder tests

* fix test

* clang-format

* re-refactor the kernels

* clang-format

* fix header license

* remove memory operation from conv bwd test

* clang-format

* clang-format example,include

* clang-format test

* build fixes

* clang-format

* solve compilation error

* fix the CI

* solve compilation error

* clang format

* solve merge conflict

* solve merge conflict

* solve the gfx11 error

* solve test error

* moar build fixes

* remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Max Podkorytov
2026-01-04 03:28:14 -08:00
committed by GitHub
parent ec23be0b9d
commit e339101e9c
68 changed files with 4198 additions and 4298 deletions

View File

@@ -361,6 +361,7 @@ struct GroupedGemmKernel
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param ds_ptr input Ds pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
@@ -381,49 +382,54 @@ struct GroupedGemmKernel
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
// Create block windows using specialized methods
const auto& a_block_window =
Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m)
.at(Base::I0);
const auto& b_block_window =
Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n)
.at(Base::I0);
const auto& d_block_window =
Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
const auto& b_block_window = gemm_tile_windows.at(Base::I1);
const auto& d_block_window = gemm_tile_windows.at(Base::I2);
// Get hot-loop and tail configuration
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
a_block_window, b_block_window, num_loop, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
if(kargs.k_batch == 1)
{
auto c_block_window = Base::template MakeCBlockWindows<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else
{
auto c_block_window =
Base::template MakeCBlockWindows<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
}
/**
* @brief Runs a single GEMM problem cooperatively by the whole workgroup.
*
* @note The GEMM pipeline is selected in-kernel based on the number of K-loops
* and the tail-number. This is needed for the persistent tile-loop when
* we didn't have access to the K dimension on the host.
* @note RunGEMM2LDS with two shared memory buffers using the ping pong buffer mechanism.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param smem_ptr_1 The second start memory pointer of the shared memory block.
* @param ds_ptr input Ds pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k
* batch.
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
@@ -440,54 +446,39 @@ struct GroupedGemmKernel
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
// Create block windows using specialized methods
const auto& a_block_window =
Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m)
.at(Base::I0);
const auto& b_block_window =
Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n)
.at(Base::I0);
const auto& d_block_window =
Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
const auto& b_block_window = gemm_tile_windows.at(Base::I1);
const auto& d_block_window = gemm_tile_windows.at(Base::I2);
// Get hot-loop and tail configuration
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline with compile-time branching
const auto& c_block_tile = [&]() {
if constexpr(GemmPipeline::Preshuffle)
{
// Preshuffle version - without has_hot_loop parameter
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
}
else
{
// Regular version - with has_hot_loop parameter
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
}
}();
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
if(kargs.k_batch == 1)
{
auto c_block_window = Base::template MakeCBlockWindows<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else
{
auto c_block_window =
Base::template MakeCBlockWindows<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
}
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg<NumDTensor_>* gemm_desc_ptr,

View File

@@ -222,19 +222,13 @@ struct StreamKKernel
const index_t block_idx_n,
const index_t k_size)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Run GEMM cooperatively by whole workgroup.
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
// Create block windows using specialized methods
const auto& as_block_window =
UniversalGemmKernel::MakeABlockWindows(as_ptr, kargs, k_size, block_idx_m);
const auto& bs_block_window =
UniversalGemmKernel::MakeBBlockWindows(bs_ptr, kargs, k_size, block_idx_n);
const auto& ds_block_window =
UniversalGemmKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
// has_hot_loop and tail_num here. This is similar to the pattern used by grouped GEMM. In this
@@ -243,6 +237,7 @@ struct StreamKKernel
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
bs_block_window[UniversalGemmKernel::I0],
num_loop,
@@ -253,7 +248,9 @@ struct StreamKKernel
if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
auto c_block_window =
UniversalGemmKernel::template MakeCBlockWindows<TilePartitioner::MemoryOperation>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
@@ -525,21 +522,13 @@ struct StreamKKernel
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
UniversalGemmKernel::template MakeGemmTensorViews<
EpiloguePipeline::MemoryOperation>(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size);
const auto& gemm_pad_views =
UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n);
// Run GEMM cooperatively by whole workgroup.
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
// Create block windows using specialized methods
const auto& as_block_window =
UniversalGemmKernel::MakeABlockWindows({a_ptr}, kargs, k_size, i_m);
const auto& bs_block_window =
UniversalGemmKernel::MakeBBlockWindows({b_ptr}, kargs, k_size, i_n);
const auto& ds_block_window =
UniversalGemmKernel::MakeDBlockWindows({/*ds_ptr*/}, kargs, i_m, i_n);
// Since num_loop can vary per WG and per iteration of the Stream-K while loop,
// we compute has_hot_loop and tail_num here. This is a similar pattern used by
@@ -548,6 +537,7 @@ struct StreamKKernel
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
bs_block_window[UniversalGemmKernel::I0],
num_loop_sk,
@@ -594,7 +584,8 @@ struct StreamKKernel
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
@@ -617,7 +608,8 @@ struct StreamKKernel
// tensor.
if(tile_started && !partner_in_tile)
{
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
break;

View File

@@ -27,6 +27,9 @@ struct StreamKTilePartitionerBase
static constexpr index_t NPerBlock = BlockGemmShapeType::kN;
static constexpr index_t KPerBlock = BlockGemmShapeType::kK;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
static constexpr auto MemoryOperation = (ReductionStrategy == StreamKReductionStrategy::Atomic)
? memory_operation_enum::atomic_add
: memory_operation_enum::set;
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);

View File

@@ -254,6 +254,8 @@ struct UniversalGemmKernel
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
using KernelArgs =
UniversalGemmKernelArgs<AsLayout::size(), BsLayout::size(), DsLayout::size()>;
@@ -609,17 +611,13 @@ struct UniversalGemmKernel
return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const std::array<const ADataType*, NumATensor>& as_ptr,
const std::array<const BDataType*, NumBTensor>& bs_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
const KernelArgs& kargs,
const index_t k_size)
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t i_m)
{
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
// Step 1: Create tensor views for A tensors (from MakeGemmTensorViews)
const auto& as_tensor_view = generate_tuple(
[&](auto i) {
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
@@ -645,6 +643,58 @@ struct UniversalGemmKernel
},
number<NumATensor>{});
// Step 2: Create padded views (from MakeGemmPadViews)
const auto& as_pad_view = generate_tuple(
[&](auto i) {
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(as_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(as_tensor_view[i],
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
},
number<NumATensor>{});
// Step 3: Create tile windows (from MakeGemmTileWindows)
const auto& as_block_window = generate_tuple(
[&](auto i) {
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(as_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(as_pad_view[i],
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
},
number<NumATensor>{});
return as_block_window;
}
CK_TILE_DEVICE static auto
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t i_n)
{
// Step 1: Create tensor views for B tensors (from MakeGemmTensorViews)
const auto& bs_tensor_view = generate_tuple(
[&](auto i) {
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
@@ -733,96 +783,20 @@ struct UniversalGemmKernel
},
number<NumBTensor>{});
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
},
number<NumDTensor>{});
// TODO: enable vector write for C in ColMajor
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm.
make_tuple(kargs.stride_E, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_E),
number<1>{},
number<1>{});
}
}();
return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& as_pad_view = generate_tuple(
[&](auto i) {
const auto& a_tensor_view = views.at(I0);
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view[i],
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
},
number<NumATensor>{});
const auto& b_flat_pad_view = views.at(I1);
// Step 2: Create padded views (from MakeGemmPadViews)
const auto& bs_pad_view = generate_tuple(
[&](auto i) {
const auto& b_tensor_view = views.at(I1);
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(b_tensor_view[i],
return pad_tensor_view(bs_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view[i],
return pad_tensor_view(bs_tensor_view[i],
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
@@ -830,86 +804,7 @@ struct UniversalGemmKernel
},
number<NumBTensor>{});
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
const auto& d_tensor_view = views.at(I2);
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// TODO vector write in for C in ColMajor
const auto& e_pad_view = [&]() {
const auto& e_tensor_view = views.at(I3);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
if constexpr(GemmPipeline::Preshuffle)
{
// For flatmm, we need to use the flat B tensor view
return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view);
}
else
{
return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view);
}
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& as_pad_view = views.at(I0);
const auto& bs_pad_view = views.at(I1);
const auto& ds_pad_view = views.at(I2);
const auto& e_pad_view = views.at(I3);
const auto& as_block_window = generate_tuple(
[&](auto i) {
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(as_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(as_pad_view[i],
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
},
number<NumATensor>{});
// Step 3: Create tile windows (from MakeGemmTileWindows)
const auto& bs_block_window = generate_tuple(
[&](auto i) {
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
@@ -942,7 +837,63 @@ struct UniversalGemmKernel
},
number<NumBTensor>{});
const auto ds_block_window = generate_tuple(
return bs_block_window;
}
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const KernelArgs& kargs,
const index_t i_m,
const index_t i_n)
{
// Step 1: Create tensor views for D tensors (from MakeGemmTensorViews)
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
},
number<NumDTensor>{});
// Step 2: Create padded views (from MakeGemmPadViews)
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// Step 3: Create tile windows (from MakeGemmTileWindows)
const auto& ds_block_window = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
@@ -962,12 +913,62 @@ struct UniversalGemmKernel
},
number<NumDTensor>{});
return ds_block_window;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
const KernelArgs& kargs,
const index_t i_m,
const index_t i_n)
{
// Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews)
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_E, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_E),
number<1>{},
number<1>{});
}
}();
// Step 2: Create padded view (from MakeGemmPadViews)
const auto& e_pad_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
// Step 3: Create tile window (from MakeGemmTileWindows)
auto e_block_window = make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
return e_block_window;
}
/**
@@ -995,30 +996,32 @@ struct UniversalGemmKernel
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Create block windows using specialized methods
const auto& as_block_window =
MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& bs_block_window =
MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& as_block_window = gemm_tile_windows.at(I0);
const auto& bs_block_window = gemm_tile_windows.at(I1);
const auto& ds_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
if(UseDefaultScheduler || (get_warp_id() == 0))
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
// Run Epilogue Pipeline
if(k_batch == 1)
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
}
@@ -1051,22 +1054,17 @@ struct UniversalGemmKernel
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Create block windows using specialized methods
const auto& as_block_window =
MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& bs_block_window =
MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& as_block_window = gemm_tile_windows.at(I0);
const auto& bs_block_window = gemm_tile_windows.at(I1);
const auto& ds_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
AElementWise{},
bs_block_window,
@@ -1076,9 +1074,20 @@ struct UniversalGemmKernel
smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
if(kargs.k_batch == 1)
{
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
}
// Non-persistent kernel entry point
@@ -1119,39 +1128,30 @@ struct UniversalGemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
RunGemm2LDS(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
RunGemm<scheduler_type>(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
RunGemm<scheduler_type>(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
@@ -1204,40 +1204,28 @@ struct UniversalGemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
RunGemm(as_ptr,
RunGemm2LDS(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
else
{
RunGemm(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
// Advance to the next work item
block_id += grid_size;