mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)
* initial poc
* factor out common parts in operator()
* cv4
* rest of the universal gemm pipelines
* fix test
* remove boilerplate from tile engine
* fix example
* fix example
* format
* fix tests build for gemm
* remove base pipeline codegen from gemm instance builder
* unify v3 logic with the rest of universal gemm pipelines
* fix build for multi abd test
* fix test gemm multi d
* fix build for weight preshuffle
* fix grouped gemm test
* fix grouped gemm multi d test
* fix grouped gemm preshuffle
* fix grouped gemm example except for quant
* fix gemm preshuffle
* fix splitk 2 stage example
* fix batched gemm example
* fix multid example
* fix multiabd example
* fix batched gemm test
* fixup
* fix examples build
* fix grouped gemm test build
* fix smoke builder
* hacky poc
* fix tile engine
* kill the lambda
* maybe fix test build
* more fixes
* clang-format
* save temp
* clang-format
* mostly fix examples
* clang-format
* remove dead code
* more cleanup
* fix fmha bwd build (default epilogue set/add appears to be broken)
* fix default epilogue tests but not correctness
* clang-format
* fix bquant
* clang-format
* cleanup dead code
* rearrange make windows for readability
* restore changes to IsSupportedArgument
* fix smoke-builder
* clang-format
* fixup rename class
* build fixes
* clang-format
* fix builder
* fixup
* remove set from builder tests
* fix test
* clang-format
* re-refactor the kernels
* clang-format
* fix header license
* remove memory operation from conv bwd test
* clang-format
* clang-format example,include
* clang-format test
* build fixes
* clang-format
* solve compilation error
* fix the CI
* solve compilation error
* clang format
* solve merge conflict
* solve merge conflict
* solve the gfx11 error
* solve test error
* moar build fixes
* remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope
---------
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -361,6 +361,7 @@ struct GroupedGemmKernel
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
@@ -381,49 +382,54 @@ struct GroupedGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m)
|
||||
.at(Base::I0);
|
||||
const auto& b_block_window =
|
||||
Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n)
|
||||
.at(Base::I0);
|
||||
const auto& d_block_window =
|
||||
Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(Base::I2);
|
||||
|
||||
// Get hot-loop and tail configuration
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = Base::template MakeCBlockWindows<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window =
|
||||
Base::template MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note The GEMM pipeline is selected in-kernel based on the number of K-loops
|
||||
* and the tail-number. This is needed for the persistent tile-loop when
|
||||
* we didn't have access to the K dimension on the host.
|
||||
* @note RunGEMM2LDS with two shared memory buffers using the ping pong buffer mechanism.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param smem_ptr_1 The second start memory pointer of the shared memory block.
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k
|
||||
* batch.
|
||||
* @param splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
@@ -440,54 +446,39 @@ struct GroupedGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m)
|
||||
.at(Base::I0);
|
||||
const auto& b_block_window =
|
||||
Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n)
|
||||
.at(Base::I0);
|
||||
const auto& d_block_window =
|
||||
Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(Base::I2);
|
||||
|
||||
// Get hot-loop and tail configuration
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline with compile-time branching
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
// Preshuffle version - without has_hot_loop parameter
|
||||
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
|
||||
b_block_window[Base::I0],
|
||||
num_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Regular version - with has_hot_loop parameter
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
|
||||
b_block_window[Base::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
}
|
||||
}();
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = Base::template MakeCBlockWindows<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window =
|
||||
Base::template MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg<NumDTensor_>* gemm_desc_ptr,
|
||||
|
||||
@@ -222,19 +222,13 @@ struct StreamKKernel
|
||||
const index_t block_idx_n,
|
||||
const index_t k_size)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
UniversalGemmKernel::MakeABlockWindows(as_ptr, kargs, k_size, block_idx_m);
|
||||
const auto& bs_block_window =
|
||||
UniversalGemmKernel::MakeBBlockWindows(bs_ptr, kargs, k_size, block_idx_n);
|
||||
const auto& ds_block_window =
|
||||
UniversalGemmKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
|
||||
// has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
|
||||
@@ -243,6 +237,7 @@ struct StreamKKernel
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop,
|
||||
@@ -253,7 +248,9 @@ struct StreamKKernel
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
auto c_block_window =
|
||||
UniversalGemmKernel::template MakeCBlockWindows<TilePartitioner::MemoryOperation>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
@@ -525,21 +522,13 @@ struct StreamKKernel
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<
|
||||
EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views =
|
||||
UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
UniversalGemmKernel::MakeABlockWindows({a_ptr}, kargs, k_size, i_m);
|
||||
const auto& bs_block_window =
|
||||
UniversalGemmKernel::MakeBBlockWindows({b_ptr}, kargs, k_size, i_n);
|
||||
const auto& ds_block_window =
|
||||
UniversalGemmKernel::MakeDBlockWindows({/*ds_ptr*/}, kargs, i_m, i_n);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop,
|
||||
// we compute has_hot_loop and tail_num here. This is a similar pattern used by
|
||||
@@ -548,6 +537,7 @@ struct StreamKKernel
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop_sk,
|
||||
@@ -594,7 +584,8 @@ struct StreamKKernel
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
|
||||
TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
@@ -617,7 +608,8 @@ struct StreamKKernel
|
||||
// tensor.
|
||||
if(tile_started && !partner_in_tile)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
|
||||
TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
break;
|
||||
|
||||
@@ -27,6 +27,9 @@ struct StreamKTilePartitionerBase
|
||||
static constexpr index_t NPerBlock = BlockGemmShapeType::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShapeType::kK;
|
||||
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
|
||||
// Memory operation chosen at compile time from the reduction strategy:
// atomic_add when ReductionStrategy is StreamKReductionStrategy::Atomic,
// set for every other strategy.
static constexpr auto MemoryOperation = (ReductionStrategy == StreamKReductionStrategy::Atomic)
|
||||
? memory_operation_enum::atomic_add
|
||||
: memory_operation_enum::set;
|
||||
|
||||
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
|
||||
|
||||
|
||||
@@ -254,6 +254,8 @@ struct UniversalGemmKernel
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
|
||||
using KernelArgs =
|
||||
UniversalGemmKernelArgs<AsLayout::size(), BsLayout::size(), DsLayout::size()>;
|
||||
|
||||
@@ -609,17 +611,13 @@ struct UniversalGemmKernel
|
||||
return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size)
|
||||
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_m)
|
||||
{
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
|
||||
// Step 1: Create tensor views for A tensors (from MakeGemmTensorViews)
|
||||
const auto& as_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
@@ -645,6 +643,58 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
// Step 2: Create padded views (from MakeGemmPadViews)
|
||||
const auto& as_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(as_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(as_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
// Step 3: Create tile windows (from MakeGemmTileWindows)
|
||||
const auto& as_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
return as_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor views for B tensors (from MakeGemmTensorViews)
|
||||
const auto& bs_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
@@ -733,96 +783,20 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm.
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_E),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& as_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
const auto& b_flat_pad_view = views.at(I1);
|
||||
|
||||
// Step 2: Create padded views (from MakeGemmPadViews)
|
||||
const auto& bs_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view[i],
|
||||
return pad_tensor_view(bs_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view[i],
|
||||
return pad_tensor_view(bs_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
@@ -830,86 +804,7 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& d_tensor_view = views.at(I2);
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
// For flatmm, we need to use the flat B tensor view
|
||||
return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& as_pad_view = views.at(I0);
|
||||
const auto& bs_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& as_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
// Step 3: Create tile windows (from MakeGemmTileWindows)
|
||||
const auto& bs_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
@@ -942,7 +837,63 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
return bs_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor views for D tensors (from MakeGemmTensorViews)
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 2: Create padded views (from MakeGemmPadViews)
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 3: Create tile windows (from MakeGemmTileWindows)
|
||||
const auto& ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -962,12 +913,62 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return ds_block_window;
|
||||
}
|
||||
|
||||
/**
 * @brief Builds the output (E/C) tile window for a single output tile.
 *
 * Wraps e_ptr in a global-address-space tensor view of shape (kargs.M, kargs.N),
 * pads that view against the (MPerBlock, NPerBlock) tile shape, and creates a
 * tile window of that shape with origin {i_m, i_n}.
 *
 * @tparam DstInMemOp Memory operation forwarded to make_naive_tensor_view
 *                    (defaults to memory_operation_enum::set).
 * @param e_ptr output E/C pointer
 * @param kargs GEMM kernel arguments; M, N and stride_E are read here.
 * @param i_m First coordinate of the window origin passed to make_tile_window.
 * @param i_n Second coordinate of the window origin passed to make_tile_window.
 * @return The tile window created in step 3 (e_block_window).
 *
 * NOTE(review): this span comes from a diff view. The `return make_tuple(...)`
 * line near the end refers to names that are not defined in this function and
 * looks like the removed tail of the old MakeGemmTileWindows; confirm against
 * the actual header that only `return e_block_window;` remains.
 */
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews)
|
||||
const auto& e_tensor_view = [&]() {
|
||||
// RowMajor C: strides are (stride_E, 1); the first number<> argument is taken
|
||||
// from EpiloguePipeline::GetVectorSizeC() (presumably the vector access width).
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
// Any other C layout: same (M, N) shape, strides are (1, stride_E), and both
|
||||
// trailing number<> arguments are fixed to 1.
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_E),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
// Step 2: Create padded view (from MakeGemmPadViews)
|
||||
const auto& e_pad_view = [&]() {
|
||||
// RowMajor C: only the second sequence entry is configurable (GemmPipeline::kPadN).
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
// Other layouts: only the first sequence entry is configurable (GemmPipeline::kPadM).
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
// Step 3: Create tile window (from MakeGemmTileWindows)
|
||||
auto e_block_window = make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
|
||||
return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
|
||||
return e_block_window;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -995,30 +996,32 @@ struct UniversalGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& bs_block_window =
|
||||
MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
// Run Epilogue Pipeline
|
||||
if(k_batch == 1)
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
@@ -1051,22 +1054,17 @@ struct UniversalGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& bs_block_window =
|
||||
MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
|
||||
AElementWise{},
|
||||
bs_block_window,
|
||||
@@ -1076,9 +1074,20 @@ struct UniversalGemmKernel
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
// Non-persistent kernel entry point
|
||||
@@ -1119,39 +1128,30 @@ struct UniversalGemmKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
|
||||
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1204,40 +1204,28 @@ struct UniversalGemmKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(as_ptr,
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
// Advance to the next work item
|
||||
block_id += grid_size;
|
||||
|
||||
Reference in New Issue
Block a user