mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)
* initial poc * factor out common parts in operator() * cv4 * rest of the universal gemm pipelines * fix test * remove boilerplate from tile engine * fix example * fix example * format * fix tests build for gemm * remove base pipeline codegen from gemm instance builder * unify v3 logic with the rest of universal gemm pipelines * fix build for multi abd test * fix test gemm multi d * fix build for weight preshuffle * fix grouped gemm test * fix grouped gemm multi d test * fix grouped gemm preshuffle * fix grouped gemm example except for quant * fix gemm preshuffle * fix splitk 2 stage example * fix batched gemm example * fix multid example * fix multiabd example * fix batched gemm test * fixup * fix examples build * fix grouped gemm test build * fix smoke builder * hacky poc * fix tile engine * kill the lambda * maybe fix test build * more fixes * clang-format * save temp * clang-format * mostly fix examples * clang-format * remove dead code * more cleanup * fix fmha bwd build (default epilogue set/add appears to be broken) * fix default epilogue tests but not correctness * clang-format * fix bquant * clang-format * cleanup dead code * rearrange make windows for readability * restore changes to IsSupportedArgument * fix smoke-builder * clang-format * fixup rename class * build fixes * clang-format * fix builder * fixup * remove set from builder tests * fix test * clang-format * re-refactor the kernels * clang-format * fix header license * remove memory operation from conv bwd test * clang-format * clang-format example,include * clang-format test * build fixes * clang-format * solve compilation error * fix the CI * solve compilation error * clang format * solve merge conflict * solve merge conflict * solve the gfx11 error * solve test error * moar build fixes * remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -558,21 +558,19 @@ struct FlatmmKernel
|
||||
return DTesnorIsValid;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t block_idx_m)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.M, k_size),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
@@ -581,25 +579,81 @@ struct FlatmmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
||||
make_tuple(k_size, kargs.M),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK =
|
||||
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 3: Create tile window
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{block_idx_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, block_idx_m});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
index_t kFlatK =
|
||||
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
|
||||
// Step 2: No padding needed for b_flat
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
b_flat_tensor_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
}
|
||||
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor views
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
@@ -625,7 +679,56 @@ struct FlatmmKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
// Step 2: Create padded views
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 3: Create tile windows
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{block_idx_n, block_idx_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -647,98 +750,8 @@ struct FlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
|
||||
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
|
||||
|
||||
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
|
||||
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-token scale
|
||||
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-channel scale
|
||||
|
||||
static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
|
||||
"only support per-tensor or per-row scaling");
|
||||
static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
|
||||
"only support per-tensor or per-column scaling");
|
||||
|
||||
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_m_ptr.ptr,
|
||||
make_tuple(kargs.M / ScaleGranularityM,
|
||||
ScaleGranularityKA == 0
|
||||
? 1
|
||||
: splitk_batch_offset.splitted_k /
|
||||
(ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)),
|
||||
make_tuple(scale_stride_m, 0),
|
||||
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
|
||||
number<1>{});
|
||||
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_n_ptr.ptr,
|
||||
make_tuple(ScaleGranularityKB == 0
|
||||
? 1
|
||||
: (splitk_batch_offset.splitted_k /
|
||||
(ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)),
|
||||
kargs.N / ScaleGranularityN),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
|
||||
number<1>{});
|
||||
|
||||
return make_tuple(a_tensor_view,
|
||||
b_flat_tensor_view,
|
||||
ds_tensor_view,
|
||||
e_tensor_view,
|
||||
scale_m_view,
|
||||
scale_n_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_tensor_view = views.at(I1);
|
||||
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& d_tensor_view = views.at(I2);
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
// Step 2: Create padded view
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
@@ -755,93 +768,72 @@ struct FlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view,
|
||||
b_flat_tensor_view,
|
||||
ds_pad_view,
|
||||
e_pad_view,
|
||||
views.at(number<4>{}),
|
||||
views.at(number<5>{}));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_flat_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_n, i_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
|
||||
constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeScaleMWindow(const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m)
|
||||
{
|
||||
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
|
||||
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
|
||||
auto scale_m_window = make_tile_window(views.at(number<4>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number < ScaleGranularityKA == 0
|
||||
? TilePartitioner::NPerBlock
|
||||
: TilePartitioner::KPerBlock > {}),
|
||||
{i_m, 0});
|
||||
auto scale_n_window = make_tile_window(views.at(number<5>{}),
|
||||
make_tuple(number < ScaleGranularityKB == 0
|
||||
? TilePartitioner::MPerBlock
|
||||
: TilePartitioner::KPerBlock > {},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-token scale
|
||||
|
||||
return make_tuple(a_block_window,
|
||||
b_flat_block_window,
|
||||
ds_block_window,
|
||||
e_block_window,
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
// Step 1: Create tensor view
|
||||
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_m_ptr.ptr,
|
||||
make_tuple(kargs.M / ScaleGranularityM,
|
||||
ScaleGranularityKA == 0
|
||||
? 1
|
||||
: (splitk_batch_offset.splitted_k / ScaleGranularityKA)),
|
||||
make_tuple(scale_stride_m, 0),
|
||||
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
|
||||
number<1>{});
|
||||
|
||||
// Step 2: Create tile window
|
||||
return make_tile_window(scale_m_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number < ScaleGranularityKA == 0
|
||||
? TilePartitioner::NPerBlock
|
||||
: TilePartitioner::KPerBlock > {}),
|
||||
{block_idx_m, 0});
|
||||
}
|
||||
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto MakeScaleNWindow(const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
|
||||
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
|
||||
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-channel scale
|
||||
|
||||
// Step 1: Create tensor view
|
||||
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_n_ptr.ptr,
|
||||
make_tuple(
|
||||
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
|
||||
kargs.N / ScaleGranularityN),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
|
||||
number<1>{});
|
||||
|
||||
// Step 2: Create tile window
|
||||
return make_tile_window(scale_n_view,
|
||||
make_tuple(number < ScaleGranularityKB == 0
|
||||
? TilePartitioner::MPerBlock
|
||||
: TilePartitioner::KPerBlock > {},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, block_idx_n});
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
||||
@@ -857,45 +849,74 @@ struct FlatmmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
|
||||
const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
|
||||
|
||||
auto scale_m_window = gemm_tile_windows.at(number<4>{});
|
||||
auto scale_n_window = gemm_tile_windows.at(number<5>{});
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(e_block_window),
|
||||
decltype(c_block_tile),
|
||||
decltype(ds_block_window)>(e_block_window,
|
||||
c_block_tile,
|
||||
ds_block_window,
|
||||
smem_ptr_ping,
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(e_block_window),
|
||||
decltype(c_block_tile),
|
||||
decltype(ds_block_window)>(e_block_window,
|
||||
c_block_tile,
|
||||
ds_block_window,
|
||||
smem_ptr_ping,
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
}
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(e_block_window),
|
||||
decltype(c_block_tile),
|
||||
decltype(ds_block_window)>(
|
||||
e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(e_block_window),
|
||||
decltype(c_block_tile),
|
||||
decltype(ds_block_window)>(
|
||||
e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -924,8 +945,7 @@ struct FlatmmKernel
|
||||
__shared__ char smem_ptr_ping[GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[GetSmemPongSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
|
||||
Reference in New Issue
Block a user