mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)
* initial poc * factor out common parts in operator() * cv4 * rest of the universal gemm pipelines * fix test * remove boilerplate from tile engine * fix example * fix example * format * fix tests build for gemm * remove base pipeline codegen from gemm instance builder * unify v3 logic with the rest of universal gemm pipelines * fix build for multi abd test * fix test gemm multi d * fix build for weight preshuffle * fix grouped gemm test * fix grouped gemm multi d test * fix grouped gemm preshuffle * fix grouped gemm example except for quant * fix gemm preshuffle * fix splitk 2 stage example * fix batched gemm example * fix multid example * fix multiabd example * fix batched gemm test * fixup * fix examples build * fix grouped gemm test build * fix smoke builder * hacky poc * fix tile engine * kill the lambda * maybe fix test build * more fixes * clang-format * save temp * clang-format * mostly fix examples * clang-format * remove dead code * more cleanup * fix fmha bwd build (default epilogue set/add appears to be broken) * fix default epilogue tests but not correctness * clang-format * fix bquant * clang-format * cleanup dead code * rearrange make windows for readability * restore changes to IsSupportedArgument * fix smoke-builder * clang-format * fixup rename class * build fixes * clang-format * fix builder * fixup * remove set from builder tests * fix test * clang-format * re-refactor the kernels * clang-format * fix header license * remove memory operation from conv bwd test * clang-format * clang-format example,include * clang-format test * build fixes * clang-format * solve compilation error * fix the CI * solve compilation error * clang format * solve merge conflict * solve merge conflict * solve the gfx11 error * solve test error * moar build fixes * remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -30,7 +30,6 @@ template <typename AsDataType_,
|
||||
index_t NPerXdl_,
|
||||
index_t KPerXdl_,
|
||||
bool isCTransposed_,
|
||||
memory_operation_enum MemoryOperation_,
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1,
|
||||
@@ -39,31 +38,30 @@ template <typename AsDataType_,
|
||||
bool DoubleSmemBuffer_ = false>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using ELayout = remove_cvref_t<ELayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t MWave = MWave_;
|
||||
static constexpr index_t NWave = NWave_;
|
||||
static constexpr index_t MPerXdl = MPerXdl_;
|
||||
static constexpr index_t NPerXdl = NPerXdl_;
|
||||
static constexpr index_t KPerXdl = KPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using ELayout = remove_cvref_t<ELayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t MWave = MWave_;
|
||||
static constexpr index_t NWave = NWave_;
|
||||
static constexpr index_t MPerXdl = MPerXdl_;
|
||||
static constexpr index_t NPerXdl = NPerXdl_;
|
||||
static constexpr index_t KPerXdl = KPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
@@ -105,28 +103,27 @@ struct CShuffleEpilogue
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
using ELayout = remove_cvref_t<typename Problem::ELayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr index_t MWave = Problem::MWave;
|
||||
static constexpr index_t NWave = Problem::NWave;
|
||||
static constexpr index_t MPerXdl = Problem::MPerXdl;
|
||||
static constexpr index_t NPerXdl = Problem::NPerXdl;
|
||||
static constexpr index_t KPerXdl = Problem::KPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
|
||||
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
|
||||
using ELayout = remove_cvref_t<typename Problem::ELayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr index_t MWave = Problem::MWave;
|
||||
static constexpr index_t NWave = Problem::NWave;
|
||||
static constexpr index_t MPerXdl = Problem::MPerXdl;
|
||||
static constexpr index_t NPerXdl = Problem::NPerXdl;
|
||||
static constexpr index_t KPerXdl = Problem::KPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
|
||||
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
|
||||
|
||||
CDElementwise elfunc_;
|
||||
|
||||
@@ -142,8 +139,7 @@ struct CShuffleEpilogue
|
||||
concat('x', MWave, NWave),
|
||||
concat('x', MPerXdl, NPerXdl, KPerXdl),
|
||||
VectorSizeC,
|
||||
isCTransposed ? "CTransposed" : "CNotTransposed",
|
||||
mem_op_string<MemoryOperation>());
|
||||
isCTransposed ? "CTransposed" : "CNotTransposed");
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -445,7 +441,8 @@ struct CShuffleEpilogue
|
||||
CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
|
||||
const COutTensor& c_out_tensor)
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
|
||||
memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
@@ -617,7 +614,8 @@ struct CShuffleEpilogue
|
||||
});
|
||||
|
||||
// store/update
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
|
||||
memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
|
||||
@@ -15,17 +15,15 @@ template <typename AccDataType_,
|
||||
typename ODataType_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool UseRawStore_ = true,
|
||||
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
|
||||
bool UseRawStore_ = true>
|
||||
struct Default2DEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool UseRawStore = UseRawStore_;
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr index_t NumDTensor = 0;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool UseRawStore = UseRawStore_;
|
||||
static constexpr index_t NumDTensor = 0;
|
||||
};
|
||||
|
||||
template <typename AsDataType_,
|
||||
@@ -44,14 +42,9 @@ template <typename AsDataType_,
|
||||
index_t kNPerXdl_,
|
||||
index_t kKPerXdl_,
|
||||
bool isCTransposed_,
|
||||
bool UseRawStore_ = true,
|
||||
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
|
||||
struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataType_,
|
||||
ODataType_,
|
||||
kPadM_,
|
||||
kPadN_,
|
||||
UseRawStore_,
|
||||
MemoryOperation_>
|
||||
bool UseRawStore_ = true>
|
||||
struct DefaultGemm2DEpilogueProblem
|
||||
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
|
||||
{
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
@@ -81,7 +74,6 @@ struct Default2DEpilogue
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool UseRawStore = Problem::UseRawStore;
|
||||
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
@@ -102,7 +94,10 @@ struct Default2DEpilogue
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
// FIXME?
|
||||
// if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
|
||||
// memory_operation_enum::set)
|
||||
if constexpr(true)
|
||||
{
|
||||
if constexpr(is_partition_index)
|
||||
{
|
||||
@@ -123,7 +118,10 @@ struct Default2DEpilogue
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
// FIXME?
|
||||
// if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
|
||||
// memory_operation_enum::set)
|
||||
if constexpr(true)
|
||||
{
|
||||
if constexpr(is_partition_index)
|
||||
{
|
||||
|
||||
@@ -558,21 +558,19 @@ struct FlatmmKernel
|
||||
return DTesnorIsValid;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t block_idx_m)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.M, k_size),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
@@ -581,25 +579,81 @@ struct FlatmmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
||||
make_tuple(k_size, kargs.M),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK =
|
||||
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 3: Create tile window
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{block_idx_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, block_idx_m});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
index_t kFlatK =
|
||||
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
|
||||
// Step 2: No padding needed for b_flat
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
b_flat_tensor_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
}
|
||||
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor views
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
@@ -625,7 +679,56 @@ struct FlatmmKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
// Step 2: Create padded views
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 3: Create tile windows
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{block_idx_n, block_idx_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -647,98 +750,8 @@ struct FlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
|
||||
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
|
||||
|
||||
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
|
||||
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-token scale
|
||||
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-channel scale
|
||||
|
||||
static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
|
||||
"only support per-tensor or per-row scaling");
|
||||
static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
|
||||
"only support per-tensor or per-column scaling");
|
||||
|
||||
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_m_ptr.ptr,
|
||||
make_tuple(kargs.M / ScaleGranularityM,
|
||||
ScaleGranularityKA == 0
|
||||
? 1
|
||||
: splitk_batch_offset.splitted_k /
|
||||
(ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)),
|
||||
make_tuple(scale_stride_m, 0),
|
||||
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
|
||||
number<1>{});
|
||||
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_n_ptr.ptr,
|
||||
make_tuple(ScaleGranularityKB == 0
|
||||
? 1
|
||||
: (splitk_batch_offset.splitted_k /
|
||||
(ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)),
|
||||
kargs.N / ScaleGranularityN),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
|
||||
number<1>{});
|
||||
|
||||
return make_tuple(a_tensor_view,
|
||||
b_flat_tensor_view,
|
||||
ds_tensor_view,
|
||||
e_tensor_view,
|
||||
scale_m_view,
|
||||
scale_n_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_tensor_view = views.at(I1);
|
||||
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& d_tensor_view = views.at(I2);
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
// Step 2: Create padded view
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
@@ -755,93 +768,72 @@ struct FlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view,
|
||||
b_flat_tensor_view,
|
||||
ds_pad_view,
|
||||
e_pad_view,
|
||||
views.at(number<4>{}),
|
||||
views.at(number<5>{}));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_flat_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_n, i_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
|
||||
constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeScaleMWindow(const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m)
|
||||
{
|
||||
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
|
||||
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
|
||||
auto scale_m_window = make_tile_window(views.at(number<4>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number < ScaleGranularityKA == 0
|
||||
? TilePartitioner::NPerBlock
|
||||
: TilePartitioner::KPerBlock > {}),
|
||||
{i_m, 0});
|
||||
auto scale_n_window = make_tile_window(views.at(number<5>{}),
|
||||
make_tuple(number < ScaleGranularityKB == 0
|
||||
? TilePartitioner::MPerBlock
|
||||
: TilePartitioner::KPerBlock > {},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-token scale
|
||||
|
||||
return make_tuple(a_block_window,
|
||||
b_flat_block_window,
|
||||
ds_block_window,
|
||||
e_block_window,
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
// Step 1: Create tensor view
|
||||
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_m_ptr.ptr,
|
||||
make_tuple(kargs.M / ScaleGranularityM,
|
||||
ScaleGranularityKA == 0
|
||||
? 1
|
||||
: (splitk_batch_offset.splitted_k / ScaleGranularityKA)),
|
||||
make_tuple(scale_stride_m, 0),
|
||||
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
|
||||
number<1>{});
|
||||
|
||||
// Step 2: Create tile window
|
||||
return make_tile_window(scale_m_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number < ScaleGranularityKA == 0
|
||||
? TilePartitioner::NPerBlock
|
||||
: TilePartitioner::KPerBlock > {}),
|
||||
{block_idx_m, 0});
|
||||
}
|
||||
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeScaleNWindow(const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
|
||||
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
|
||||
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-channel scale
|
||||
|
||||
// Step 1: Create tensor view
|
||||
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_n_ptr.ptr,
|
||||
make_tuple(
|
||||
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
|
||||
kargs.N / ScaleGranularityN),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
|
||||
number<1>{});
|
||||
|
||||
// Step 2: Create tile window
|
||||
return make_tile_window(scale_n_view,
|
||||
make_tuple(number < ScaleGranularityKB == 0
|
||||
? TilePartitioner::MPerBlock
|
||||
: TilePartitioner::KPerBlock > {},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, block_idx_n});
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
||||
@@ -857,45 +849,74 @@ struct FlatmmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
|
||||
const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
|
||||
|
||||
auto scale_m_window = gemm_tile_windows.at(number<4>{});
|
||||
auto scale_n_window = gemm_tile_windows.at(number<5>{});
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(e_block_window),
|
||||
decltype(c_block_tile),
|
||||
decltype(ds_block_window)>(e_block_window,
|
||||
c_block_tile,
|
||||
ds_block_window,
|
||||
smem_ptr_ping,
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(e_block_window),
|
||||
decltype(c_block_tile),
|
||||
decltype(ds_block_window)>(e_block_window,
|
||||
c_block_tile,
|
||||
ds_block_window,
|
||||
smem_ptr_ping,
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
}
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(e_block_window),
|
||||
decltype(c_block_tile),
|
||||
decltype(ds_block_window)>(
|
||||
e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(e_block_window),
|
||||
decltype(c_block_tile),
|
||||
decltype(ds_block_window)>(
|
||||
e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -924,8 +945,7 @@ struct FlatmmKernel
|
||||
__shared__ char smem_ptr_ping[GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[GetSmemPongSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
|
||||
@@ -100,21 +100,19 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
|
||||
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t block_idx_m)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.M, k_size),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
@@ -123,25 +121,80 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
||||
make_tuple(k_size, kargs.M),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 3: Create tile window
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{block_idx_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, block_idx_m});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
const auto& b_flat_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
|
||||
// Step 2: No padding needed for b_flat
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
b_flat_tensor_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
}
|
||||
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor views
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
@@ -167,7 +220,56 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
// Step 2: Create padded views
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 3: Create tile windows
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{block_idx_n, block_idx_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -189,70 +291,8 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
}
|
||||
}();
|
||||
|
||||
auto scale_n = kargs.scale_n_ptr;
|
||||
|
||||
index_t FlatScaleK =
|
||||
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const e8m0_t*>(scale_n.ptr),
|
||||
make_tuple(FlatScaleN, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
return make_tuple(
|
||||
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_tensor_view = views.at(I1);
|
||||
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& d_tensor_view = views.at(I2);
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
// Step 2: Create padded view
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
@@ -269,77 +309,37 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_flat_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_n, i_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
|
||||
auto scale_block_window =
|
||||
make_tile_window(views.at(I4),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
|
||||
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
auto scale_n = kargs.scale_n_ptr;
|
||||
|
||||
return make_tuple(a_block_window,
|
||||
b_flat_block_window,
|
||||
ds_block_window,
|
||||
e_block_window,
|
||||
scale_block_window);
|
||||
// Step 1: Create tensor view
|
||||
index_t FlatScaleK =
|
||||
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const e8m0_t*>(scale_n.ptr),
|
||||
make_tuple(FlatScaleN, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
// Step 2: Create tile window
|
||||
return make_tile_window(
|
||||
scale_b_flat_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
|
||||
{block_idx_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
||||
@@ -355,21 +355,15 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& scale_block_window = MakeScaleBBlockWindow(kargs, block_idx_n);
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& scale_block_window = gemm_tile_windows.at(I4);
|
||||
|
||||
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|
||||
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|
||||
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
|
||||
@@ -378,6 +372,7 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
|
||||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
auto a_block_window_with_distr =
|
||||
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
|
||||
a_block_window.get_window_lengths(),
|
||||
@@ -390,22 +385,46 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if constexpr(DoEpiScale)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(e_block_window,
|
||||
c_block_tile,
|
||||
ds_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(e_block_window,
|
||||
c_block_tile,
|
||||
ds_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
}
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -434,8 +453,7 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
__shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
|
||||
@@ -1476,7 +1476,8 @@ struct MoeFlatmmKernel
|
||||
c_scatter_valids[mIter]);
|
||||
|
||||
if constexpr(!IsInputGemm ||
|
||||
EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)
|
||||
decltype(c_block_window.get_bottom_tensor_view())::DstInMemOp ==
|
||||
memory_operation_enum::atomic_add)
|
||||
c_scatter_tile_window.update(c_out_tensor);
|
||||
else
|
||||
c_scatter_tile_window.store(c_out_tensor);
|
||||
|
||||
@@ -113,32 +113,50 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
|
||||
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t block_idx_m)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& a_tensor_view = [&]() {
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
|
||||
"A tensor for mx must be RowMajor");
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.M, k_size),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<MXFlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, MXFlatmmPipeline::kPadK>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{block_idx_m, 0});
|
||||
}
|
||||
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view with special flat layout
|
||||
constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock;
|
||||
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
|
||||
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
|
||||
const index_t kFlatN = kargs.N / kNWarpTile;
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0,
|
||||
"wrong! vector size for B tensor");
|
||||
auto&& naive_desc = make_naive_tensor_descriptor_packed(
|
||||
@@ -153,6 +171,22 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
|
||||
}();
|
||||
|
||||
// Step 2: No padding for flat B
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
b_flat_tensor_view,
|
||||
make_tuple(number<MXFlatmmPipeline::flatNPerWarp>{},
|
||||
number<MXFlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
}
|
||||
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor views
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
@@ -178,7 +212,56 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
// Step 2: Create padded views
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, MXFlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, MXFlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 3: Create tile windows
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{block_idx_n, block_idx_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -200,92 +283,8 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
}
|
||||
}();
|
||||
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
|
||||
const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
|
||||
const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
|
||||
const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
|
||||
|
||||
// A scale tensor view
|
||||
const auto& scale_a_tensor_view = [&]() {
|
||||
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
|
||||
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl));
|
||||
const auto scale_a_desc = transform_tensor_descriptor(
|
||||
scale_a_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
|
||||
}();
|
||||
|
||||
// B scale tensor view
|
||||
const auto& scale_b_tensor_view = [&]() {
|
||||
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
|
||||
const auto scale_b_desc = transform_tensor_descriptor(
|
||||
scale_b_navie_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view,
|
||||
b_flat_tensor_view,
|
||||
ds_tensor_view,
|
||||
e_tensor_view,
|
||||
scale_a_tensor_view,
|
||||
scale_b_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
|
||||
"A tensor for mx must be RowMajor");
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, MXFlatmmPipeline::kPadK>{});
|
||||
}();
|
||||
|
||||
const auto& b_flat_tensor_view = views.at(I1);
|
||||
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& d_tensor_view = views.at(I2);
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, MXFlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, MXFlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
// Step 2: Create padded view
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
@@ -302,79 +301,71 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(
|
||||
a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4), views.at(I5));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_flat_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
|
||||
"A tensor for mx must be RowMajor");
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}();
|
||||
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<MXFlatmmPipeline::flatNPerWarp>{},
|
||||
number<MXFlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_n, i_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeScaleABlockWindow(const KernelArgs& kargs,
|
||||
const index_t block_idx_m)
|
||||
{
|
||||
static constexpr int BlockScaleSize = 32;
|
||||
|
||||
auto scale_a_block_window = make_tile_window(
|
||||
views.at(I4),
|
||||
const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
|
||||
const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
|
||||
|
||||
// Step 1: Create tensor view
|
||||
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl));
|
||||
const auto scale_a_desc = transform_tensor_descriptor(
|
||||
scale_a_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto& scale_a_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(kargs.scale_m_ptr.ptr), scale_a_desc);
|
||||
|
||||
// Step 2: Create tile window
|
||||
return make_tile_window(
|
||||
scale_a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock / MXdlPack>{},
|
||||
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
|
||||
{i_m / MXdlPack, 0});
|
||||
{block_idx_m / MXdlPack, 0});
|
||||
}
|
||||
|
||||
auto scale_b_block_window = make_tile_window(
|
||||
views.at(I5),
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
static constexpr int BlockScaleSize = 32;
|
||||
|
||||
const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
|
||||
const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
|
||||
|
||||
// Step 1: Create tensor view
|
||||
const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
|
||||
const auto scale_b_desc = transform_tensor_descriptor(
|
||||
scale_b_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto& scale_b_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(kargs.scale_n_ptr.ptr), scale_b_desc);
|
||||
|
||||
// Step 2: Create tile window
|
||||
return make_tile_window(
|
||||
scale_b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / NXdlPack>{},
|
||||
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
|
||||
{i_n / NXdlPack, 0});
|
||||
|
||||
return make_tuple(a_block_window,
|
||||
b_flat_block_window,
|
||||
ds_block_window,
|
||||
e_block_window,
|
||||
scale_a_block_window,
|
||||
scale_b_block_window);
|
||||
{block_idx_n / NXdlPack, 0});
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
||||
@@ -390,22 +381,16 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& scale_a_block_window = MakeScaleABlockWindow(kargs, block_idx_m);
|
||||
const auto& scale_b_block_window = MakeScaleBBlockWindow(kargs, block_idx_n);
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& scale_a_block_window = gemm_tile_windows.at(I4);
|
||||
const auto& scale_b_block_window = gemm_tile_windows.at(I5);
|
||||
|
||||
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|
||||
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|
||||
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
|
||||
@@ -422,22 +407,46 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
// Run Epilogue Pipeline with split_k dispatch
|
||||
if constexpr(DoEpiScale)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(e_block_window,
|
||||
c_block_tile,
|
||||
ds_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(e_block_window,
|
||||
c_block_tile,
|
||||
ds_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
}
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -466,27 +475,17 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
__shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false,
|
||||
"Unimplemented: atomic_add with odd vector size for fp16/bf16");
|
||||
}
|
||||
constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
partition_idx += gridDim.x;
|
||||
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
|
||||
}
|
||||
|
||||
@@ -361,6 +361,7 @@ struct GroupedGemmKernel
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
@@ -381,49 +382,54 @@ struct GroupedGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m)
|
||||
.at(Base::I0);
|
||||
const auto& b_block_window =
|
||||
Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n)
|
||||
.at(Base::I0);
|
||||
const auto& d_block_window =
|
||||
Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(Base::I2);
|
||||
|
||||
// Get hot-loop and tail configuration
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = Base::template MakeCBlockWindows<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window =
|
||||
Base::template MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note The GEMM pipeline is selected in-kernel based on the number of K-loops
|
||||
* and the tail-number. This is needed for the persistent tile-loop when
|
||||
* we didn't have access to the K dimension on the host.
|
||||
* @note RunGEMM2LDS with two shared memory buffers using the ping pong buffer mechanism.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param smem_ptr_1 The second start memory pointer of the shared memory block.
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k
|
||||
* batch.
|
||||
* @param splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
@@ -440,54 +446,39 @@ struct GroupedGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m)
|
||||
.at(Base::I0);
|
||||
const auto& b_block_window =
|
||||
Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n)
|
||||
.at(Base::I0);
|
||||
const auto& d_block_window =
|
||||
Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(Base::I2);
|
||||
|
||||
// Get hot-loop and tail configuration
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline with compile-time branching
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
// Preshuffle version - without has_hot_loop parameter
|
||||
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
|
||||
b_block_window[Base::I0],
|
||||
num_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Regular version - with has_hot_loop parameter
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
|
||||
b_block_window[Base::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
}
|
||||
}();
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = Base::template MakeCBlockWindows<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window =
|
||||
Base::template MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg<NumDTensor_>* gemm_desc_ptr,
|
||||
|
||||
@@ -222,19 +222,13 @@ struct StreamKKernel
|
||||
const index_t block_idx_n,
|
||||
const index_t k_size)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
UniversalGemmKernel::MakeABlockWindows(as_ptr, kargs, k_size, block_idx_m);
|
||||
const auto& bs_block_window =
|
||||
UniversalGemmKernel::MakeBBlockWindows(bs_ptr, kargs, k_size, block_idx_n);
|
||||
const auto& ds_block_window =
|
||||
UniversalGemmKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
|
||||
// has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
|
||||
@@ -243,6 +237,7 @@ struct StreamKKernel
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop,
|
||||
@@ -253,7 +248,9 @@ struct StreamKKernel
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
auto c_block_window =
|
||||
UniversalGemmKernel::template MakeCBlockWindows<TilePartitioner::MemoryOperation>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
@@ -525,21 +522,13 @@ struct StreamKKernel
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<
|
||||
EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views =
|
||||
UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
UniversalGemmKernel::MakeABlockWindows({a_ptr}, kargs, k_size, i_m);
|
||||
const auto& bs_block_window =
|
||||
UniversalGemmKernel::MakeBBlockWindows({b_ptr}, kargs, k_size, i_n);
|
||||
const auto& ds_block_window =
|
||||
UniversalGemmKernel::MakeDBlockWindows({/*ds_ptr*/}, kargs, i_m, i_n);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop,
|
||||
// we compute has_hot_loop and tail_num here. This is a similar pattern used by
|
||||
@@ -548,6 +537,7 @@ struct StreamKKernel
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop_sk,
|
||||
@@ -594,7 +584,8 @@ struct StreamKKernel
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
|
||||
TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
@@ -617,7 +608,8 @@ struct StreamKKernel
|
||||
// tensor.
|
||||
if(tile_started && !partner_in_tile)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
|
||||
TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
break;
|
||||
|
||||
@@ -27,6 +27,9 @@ struct StreamKTilePartitionerBase
|
||||
static constexpr index_t NPerBlock = BlockGemmShapeType::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShapeType::kK;
|
||||
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
|
||||
static constexpr auto MemoryOperation = (ReductionStrategy == StreamKReductionStrategy::Atomic)
|
||||
? memory_operation_enum::atomic_add
|
||||
: memory_operation_enum::set;
|
||||
|
||||
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
|
||||
|
||||
|
||||
@@ -254,6 +254,8 @@ struct UniversalGemmKernel
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
|
||||
using KernelArgs =
|
||||
UniversalGemmKernelArgs<AsLayout::size(), BsLayout::size(), DsLayout::size()>;
|
||||
|
||||
@@ -609,17 +611,13 @@ struct UniversalGemmKernel
|
||||
return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size)
|
||||
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_m)
|
||||
{
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
|
||||
// Step 1: Create tensor views for A tensors (from MakeGemmTensorViews)
|
||||
const auto& as_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
@@ -645,6 +643,58 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
// Step 2: Create padded views (from MakeGemmPadViews)
|
||||
const auto& as_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(as_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(as_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
// Step 3: Create tile windows (from MakeGemmTileWindows)
|
||||
const auto& as_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
return as_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor views for B tensors (from MakeGemmTensorViews)
|
||||
const auto& bs_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
@@ -733,96 +783,20 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm.
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_E),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& as_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
const auto& b_flat_pad_view = views.at(I1);
|
||||
|
||||
// Step 2: Create padded views (from MakeGemmPadViews)
|
||||
const auto& bs_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view[i],
|
||||
return pad_tensor_view(bs_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view[i],
|
||||
return pad_tensor_view(bs_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
@@ -830,86 +804,7 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& d_tensor_view = views.at(I2);
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
// For flatmm, we need to use the flat B tensor view
|
||||
return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& as_pad_view = views.at(I0);
|
||||
const auto& bs_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& as_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
// Step 3: Create tile windows (from MakeGemmTileWindows)
|
||||
const auto& bs_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
@@ -942,7 +837,63 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
return bs_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor views for D tensors (from MakeGemmTensorViews)
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 2: Create padded views (from MakeGemmPadViews)
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 3: Create tile windows (from MakeGemmTileWindows)
|
||||
const auto& ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -962,12 +913,62 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return ds_block_window;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews)
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_E),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 2: Create padded view (from MakeGemmPadViews)
|
||||
const auto& e_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 3: Create tile window (from MakeGemmTileWindows)
|
||||
auto e_block_window = make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
|
||||
return e_block_window;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -995,30 +996,32 @@ struct UniversalGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& bs_block_window =
|
||||
MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
// Run Epilogue Pipeline
|
||||
if(k_batch == 1)
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
@@ -1051,22 +1054,17 @@ struct UniversalGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& bs_block_window =
|
||||
MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
|
||||
AElementWise{},
|
||||
bs_block_window,
|
||||
@@ -1076,9 +1074,20 @@ struct UniversalGemmKernel
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
// Non-persistent kernel entry point
|
||||
@@ -1119,39 +1128,30 @@ struct UniversalGemmKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
|
||||
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1204,40 +1204,28 @@ struct UniversalGemmKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(as_ptr,
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
// Advance to the next work item
|
||||
block_id += grid_size;
|
||||
|
||||
@@ -401,6 +401,592 @@ struct QuantGemmKernel
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_m)
|
||||
{
|
||||
// Step 1: Create tensor view for A
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, k_size),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(k_size, kargs.M),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 3: Create tile window
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}();
|
||||
|
||||
return a_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for AQ
|
||||
const auto& aq_tensor_view = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
|
||||
make_tuple(aq_x, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
|
||||
const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_pad0_desc = transform_tensor_descriptor(
|
||||
aq_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(aq_y),
|
||||
make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
|
||||
const auto wave_tile_size =
|
||||
GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
|
||||
const auto wave_tile_count_x =
|
||||
ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
|
||||
|
||||
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
aq_pad0_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(aq_y),
|
||||
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto aq_pad1_desc = transform_tensor_descriptor(
|
||||
aq_unmerge_pad0_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(aq_y),
|
||||
make_pass_through_transform(wave_tile_count_x),
|
||||
make_right_pad_transform(
|
||||
wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
const auto pad_wave_size =
|
||||
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
||||
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
|
||||
aq_pad1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
|
||||
make_pass_through_transform(pad_wave_size)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
|
||||
}
|
||||
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped) &&
|
||||
!PreshuffleQuant)
|
||||
{
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK_A),
|
||||
make_tuple(kargs.stride_AQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else // Column major AQ
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.QK_A, kargs.M),
|
||||
make_tuple(kargs.stride_AQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, 0), // broadcasting over n
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 2: Create tile window (no padding for AQ)
|
||||
const auto& aq_block_window = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_m / warp_m;
|
||||
auto block_m_idx = i_m / block_m;
|
||||
return make_tile_window(
|
||||
aq_tensor_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_m_idx * tile_window_height, 0});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(aq_tensor_view,
|
||||
make_tuple(number<block_m>{}, number<aqk_per_block>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else // Column major AQ
|
||||
{
|
||||
return make_tile_window(aq_tensor_view,
|
||||
make_tuple(number<aqk_per_block>{}, number<block_m>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||
return make_tile_window(
|
||||
aq_tensor_view,
|
||||
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_tile_window(aq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}();
|
||||
|
||||
return aq_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto MakeBBlockWindow(const BDataType* b_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for B
|
||||
const auto& b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(k_size, kargs.N),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::flatKPerWarp *
|
||||
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, k_size / 2),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
else
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, k_size),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 2: Create padded view (or flat view for PreshuffleB)
|
||||
const auto& b_pad_view = [&]() {
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
return b_tensor_view; // no padding for preshuffle
|
||||
}
|
||||
else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / 2>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
else
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 3: Create tile window
|
||||
const auto& b_block_window = [&]() {
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<GemmPipeline::flatNPerWarp>{},
|
||||
number<GemmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / 2>{}),
|
||||
{i_n, 0});
|
||||
else
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
return b_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for BQ
|
||||
const auto& bq_tensor_view = [&]() {
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(0, 1), // broadcasting over m
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
|
||||
"PreshuffleQuant with BQuantGrouped currently only supports "
|
||||
"ColumnMajor BQ layout");
|
||||
|
||||
return MakePreshuffledQuantTensorView<
|
||||
GemmPipeline::KPerBlockBQ,
|
||||
GemmPipeline::NPerBlock,
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
|
||||
GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B);
|
||||
}
|
||||
else
|
||||
{
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
|
||||
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
|
||||
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
|
||||
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
|
||||
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
|
||||
make_tuple(kargs.stride_BQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 2: Create tile window (no padding for BQ)
|
||||
const auto& bq_block_window = [&]() {
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_tile_window(bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_n / warp_n;
|
||||
auto block_n_idx = i_n / block_n;
|
||||
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx * tile_window_height, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}();
|
||||
|
||||
return bq_block_window;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeCBlockWindow(CDataType* c_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for C
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& c_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 3: Create tile window
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return c_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
@@ -1143,9 +1729,7 @@ struct QuantGemmKernel
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
* @tparam DstInMemOp Destination memory operation (default: set).
|
||||
*/
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
@@ -1157,25 +1741,22 @@ struct QuantGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(I1);
|
||||
index_t m = 0;
|
||||
index_t m = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
m = kargs.M;
|
||||
@@ -1185,8 +1766,7 @@ struct QuantGemmKernel
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
index_t n = 0;
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
n = kargs.N;
|
||||
@@ -1196,10 +1776,8 @@ struct QuantGemmKernel
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
index_t m = 0;
|
||||
index_t n = 0;
|
||||
index_t m = 0;
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
m = kargs.M;
|
||||
@@ -1222,86 +1800,111 @@ struct QuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I4);
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
|
||||
if constexpr(kQuantType == QuantType::ABQuantGrouped ||
|
||||
kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
// Run Epilogue Pipeline with k_batch dispatch
|
||||
if(k_batch == 1)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
if constexpr(kQuantType == QuantType::ABQuantGrouped ||
|
||||
kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
else
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
// TODO: why doesn't readfirstlane work here?
|
||||
// const AccDataType aq_scale =
|
||||
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*aq_ptr));
|
||||
// const AccDataType bq_scale =
|
||||
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*bq_ptr));
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
if constexpr(kQuantType == QuantType::ABQuantGrouped ||
|
||||
kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param aq_ptr input AQ pointer
|
||||
* @param bq_ptr input BQ pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
* @tparam DstInMemOp Destination memory operation (default: set).
|
||||
*/
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
[[maybe_unused]] const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
void* smem_ptr_1,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
index_t n = 0;
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
n = kargs.N;
|
||||
@@ -1320,19 +1923,23 @@ struct QuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I4);
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
|
||||
// Run Epilogue Pipeline with k_batch dispatch
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return;
|
||||
// throw std::runtime_error("DoubleSmemBuffer Not implemented for AQuantGrouped or
|
||||
// RowColQuant"); static_assert(kQuantType == QuantType::BQuantGrouped,
|
||||
// "DoubleSmemBuffer Not implemented");
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1343,16 +1950,19 @@ struct QuantGemmKernel
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
|
||||
// Apply splitk offset to input pointers
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
|
||||
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
assert(kargs.k_batch == 1);
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
|
||||
@@ -374,7 +374,7 @@ struct QuantGroupedGemmKernel
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
[[maybe_unused]] const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
@@ -385,25 +385,21 @@ struct QuantGroupedGemmKernel
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped");
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& bq_block_window =
|
||||
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
|
||||
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
// Run GEMM cooperatively by whole workgroup
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
@@ -411,10 +407,20 @@ struct QuantGroupedGemmKernel
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
// Run Epilogue Pipeline with split_k dispatch
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = Base::template MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window =
|
||||
Base::template MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -449,16 +455,15 @@ struct QuantGroupedGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& aq_block_window =
|
||||
Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window =
|
||||
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
// Get hot-loop and tail configuration
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
@@ -466,51 +471,77 @@ struct QuantGroupedGemmKernel
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
// Run GEMM cooperatively by whole workgroup
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
aq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant ||
|
||||
kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
}
|
||||
}();
|
||||
|
||||
// Run Epilogue Pipeline with split_k dispatch
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
aq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
auto c_block_window = Base::template MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
auto c_block_window =
|
||||
Base::template MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
|
||||
@@ -617,6 +617,117 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeABlockWindow(const OutDataType* a_ptr,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id,
|
||||
const index_t i_m,
|
||||
const index_t i_k)
|
||||
{
|
||||
// Step 1: Create tensor view for A (Out tensor)
|
||||
const auto& a_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_descs_m_k[group_id]);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
auto a_block_window = make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, i_k});
|
||||
|
||||
return a_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindow(const InDataType* b_ptr,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id,
|
||||
const index_t i_n,
|
||||
const index_t i_k)
|
||||
{
|
||||
// Step 1: Create tensor view for B (Weight tensor)
|
||||
const auto& b_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_descs_n_k[group_id]);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
auto b_block_window = make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
|
||||
return b_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Create D tensor block windows
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
// Step 1: Create tensor view for D
|
||||
const auto& d_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& d_pad_view =
|
||||
pad_tensor_view(d_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(d_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return ds_block_window;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeCBlockWindow(WeiDataType* c_ptr,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for C (Input tensor)
|
||||
const auto& c_tensor_view = make_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr, kargs.c_grid_descs_m_n[group_id]);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return c_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
|
||||
{
|
||||
@@ -895,38 +1006,49 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const index_t block_idx_k,
|
||||
const index_t group_id)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k);
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k);
|
||||
const auto& d_block_window =
|
||||
MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
// Run Epilogue Pipeline with k_batch dispatch
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
|
||||
* @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
@@ -951,23 +1073,19 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const index_t block_idx_k,
|
||||
const index_t group_id)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k);
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k);
|
||||
const auto& d_block_window =
|
||||
MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
num_loop,
|
||||
@@ -976,11 +1094,27 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
// Run Epilogue Pipeline with k_batch dispatch
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
@@ -1066,8 +1200,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
@@ -1086,8 +1219,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
|
||||
@@ -518,25 +518,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined(__gfx11__)
|
||||
if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add)
|
||||
{
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!std::is_same_v<typename EpiloguePipeline::ODataType, float> &&
|
||||
!std::is_same_v<typename EpiloguePipeline::ODataType, double>)
|
||||
{
|
||||
@@ -704,29 +685,31 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const OutDataType* a_ptr,
|
||||
const InDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
WeiDataType* c_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
MakeCBlockWindow(WeiDataType* c_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr,
|
||||
kargs.a_grid_desc_k_m); // A: out
|
||||
}();
|
||||
const auto& c_tensor_view =
|
||||
make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr, kargs.c_grid_desc_m_n);
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr,
|
||||
kargs.b_grid_desc_k_n); // B: in
|
||||
}();
|
||||
const auto& c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr,
|
||||
kargs.c_grid_desc_m_n);
|
||||
}();
|
||||
return make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
|
||||
@@ -741,30 +724,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = views.at(I2);
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
@@ -773,67 +733,58 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create views to the data that each workgroup will process.
|
||||
*
|
||||
* @param views padded views of A, B, D and C tensors
|
||||
* @param i_m block m-index
|
||||
* @param i_n block n-index
|
||||
* @param i_k block k-index
|
||||
*
|
||||
* @return tuple of tile windows for A, B, D and C tensors
|
||||
*/
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
const index_t i_m,
|
||||
const index_t i_n,
|
||||
const index_t i_k)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_k, i_m});
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
}();
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
{block_idx_m, block_idx_n});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
}
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindow(const InDataType* b_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_n,
|
||||
const index_t block_idx_k)
|
||||
{
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& b_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_k_n);
|
||||
|
||||
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
|
||||
const auto& b_pad_view =
|
||||
pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{block_idx_k, block_idx_n});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeABlockWindow(const OutDataType* a_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_k)
|
||||
{
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_k_m);
|
||||
|
||||
const auto& a_pad_view =
|
||||
pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
return make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
|
||||
{block_idx_k, block_idx_m});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -859,28 +810,30 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const index_t block_idx_n,
|
||||
const index_t block_idx_k)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
// Create block windows using helper methods
|
||||
const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k);
|
||||
const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k);
|
||||
const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -910,27 +863,33 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const index_t block_idx_n,
|
||||
const index_t block_idx_k)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
// Create block windows using helper methods
|
||||
const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k);
|
||||
const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k);
|
||||
const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
return;
|
||||
#endif
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
|
||||
@@ -960,12 +919,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set)
|
||||
{
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
|
||||
{
|
||||
CallExplicitGemm(kargs);
|
||||
@@ -1001,9 +954,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
@@ -1021,9 +972,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
|
||||
@@ -794,34 +794,53 @@ struct GroupedConvolutionForwardKernel
|
||||
return true;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
typename ADescType,
|
||||
typename BDescType,
|
||||
typename CDescType>
|
||||
template <typename ADescType>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
const ADescType& a_desc,
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc)
|
||||
MakeABlockWindow(const InDataType* a_ptr, const ADescType& a_desc, const index_t block_idx_m)
|
||||
{
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
|
||||
}();
|
||||
// Step 1: Create tensor view
|
||||
const auto& a_tensor_view = make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
|
||||
}();
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(c_ptr, c_desc);
|
||||
}();
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{block_idx_m, 0});
|
||||
}
|
||||
|
||||
template <typename BDescType>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindow(const WeiDataType* b_ptr, const BDescType& b_desc, const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& b_tensor_view = make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{block_idx_n, 0});
|
||||
}
|
||||
|
||||
template <typename CDescType>
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const CDescType& c_desc,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor views
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
|
||||
@@ -836,30 +855,8 @@ struct GroupedConvolutionForwardKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = views.at(I2);
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
// Step 2: Create padded views
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
@@ -868,55 +865,38 @@ struct GroupedConvolutionForwardKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
}();
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
// Step 3: Create tile windows
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
{block_idx_m, block_idx_n});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
}
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename CDescType>
|
||||
CK_TILE_DEVICE static auto MakeCBlockWindow(OutDataType* c_ptr,
|
||||
const CDescType& c_desc,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& c_tensor_view =
|
||||
make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr, c_desc);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -931,6 +911,7 @@ struct GroupedConvolutionForwardKernel
|
||||
* @param b_desc Weight tensor B descriptor
|
||||
* @param c_desc Output tensor C descriptor
|
||||
* @param gemm_k The GEMM K dimension
|
||||
* @param k_batch The K batch parameter for split-K
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
@@ -945,34 +926,41 @@ struct GroupedConvolutionForwardKernel
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t k_batch,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const CDElementwise& elfunc)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m);
|
||||
const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -990,6 +978,7 @@ struct GroupedConvolutionForwardKernel
|
||||
* @param b_desc Weight tensor B descriptor
|
||||
* @param c_desc Output tensor C descriptor
|
||||
* @param gemm_k The GEMM K dimension
|
||||
* @param k_batch The K batch parameter for split-K
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
@@ -1005,33 +994,41 @@ struct GroupedConvolutionForwardKernel
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t k_batch,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const CDElementwise& elfunc)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m);
|
||||
const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const
|
||||
@@ -1185,9 +1182,7 @@ struct GroupedConvolutionForwardKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
@@ -1200,6 +1195,7 @@ struct GroupedConvolutionForwardKernel
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
kargs.k_batch,
|
||||
i_m,
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
@@ -1207,9 +1203,7 @@ struct GroupedConvolutionForwardKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
@@ -1221,6 +1215,7 @@ struct GroupedConvolutionForwardKernel
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
kargs.k_batch,
|
||||
i_m,
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
|
||||
Reference in New Issue
Block a user