[CK-TILE] Default2DEpilogue, example and adding nullptr_t type for D (#2752)

* Init commit

* Quick fix, CI fails

* Remove CDElementWise

* Add CDElementwise

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>

[ROCm/composable_kernel commit: 0758883fa4]
This commit is contained in:
Mateusz Ozga
2025-08-28 21:45:50 +02:00
committed by GitHub
parent 3db57fc348
commit e4010d5ea1
20 changed files with 636 additions and 441 deletions

View File

@@ -25,13 +25,19 @@ struct Default2DEpilogueProblem
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr index_t NumDTensor = 0;
};
template <typename ADataType_,
typename BDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
typename DsLayout_,
typename CLayout_,
typename CDElementwise_,
index_t kM_,
index_t kN_,
bool kPadM_,
bool kPadN_,
index_t kMPerXdl_,
@@ -50,10 +56,20 @@ struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataTyp
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CLayout = remove_cvref_t<CLayout_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
using DsLayout = remove_cvref_t<DsLayout_>;
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
};
template <typename Problem_, typename Policy_ = void>
@@ -71,43 +87,70 @@ struct Default2DEpilogue
// Writes an accumulator tile to the output window after casting it to ODataType.
//
// o_dram_window_tmp: destination tile window passed to the store/update helpers.
// o_acc_tile:        source tile; converted with cast_tile<ODataType> before writing.
// (unnamed void*):   accepted but never read by this overload.
//
// TODO: this function assumes the store-out vector size is the same as the size of
// OAccTile's last dimension; how do we fix this?
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const
{
// TODO: this is ugly
// Raw path: chosen at compile time only when UseRawStore is set and at least one of
// kPadM / kPadN is enabled.
if constexpr(UseRawStore && (kPadM || kPadN))
{
// MemoryOperation == set selects a store; any other value selects an update.
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
// The fence is issued after either raw operation, and only on this raw path.
buffer_store_fence();
}
else
{
// Non-raw path: same store-vs-update selection, with no fence afterwards.
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
}
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
const DsDramWindows& /* unused */,
void* = nullptr) const
const DsDramWindows& ds_dram_windows,
void* = nullptr)
{
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
const auto storeOrUpdateTile = [&](const auto& o_tile) {
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
}
else
{
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
}
buffer_store_fence();
}
else
{
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
}
else
{
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
}
}
};
if constexpr(!std::is_same_v<DsDramWindows, std::nullptr_t> && Problem::NumDTensor >= 1)
{
using elementwise_result_t = decltype(load_tile(
make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(),
make_tuple(Problem::kMPerBlock, Problem::kNPerBlock),
ds_dram_windows[number<0>{}].get_window_origin(),
o_acc_tile.get_tile_distribution())));
elementwise_result_t elementwise_result;
const auto d_tensor_tuple = generate_tuple(
[&](auto idx) {
const auto d_tile_window =
make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution());
return load_tile(d_tile_window);
},
number<Problem::NumDTensor>{});
const auto c_d_tuple = concat_tuple_of_reference(
tie(elementwise_result, o_acc_tile),
generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; },
number<Problem::NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple);
storeOrUpdateTile(elementwise_result);
}
else
{
storeOrUpdateTile(o_acc_tile);
}
}
};
@@ -122,8 +165,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using DsDataType = ck_tile::tuple<>;
using DsLayout = ck_tile::tuple<>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
@@ -192,7 +236,11 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
// Returns the vector size to use for the D tensor selected by `index`.
// The index is not consulted: every D tensor gets the same value as GetVectorSizeC().
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number<I> index)
{
return GetVectorSizeC();
}
};
} // namespace ck_tile

View File

@@ -1134,8 +1134,8 @@ struct FmhaBwdDQDKDVKernel
scale_rp_undrop,
dropout);
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile, nullptr);
VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile, nullptr);
}
else
{

View File

@@ -1509,7 +1509,7 @@ struct FmhaFwdKernel
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile);
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
else
{
@@ -2180,7 +2180,7 @@ struct FmhaFwdKernel
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile);
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
}
};

View File

@@ -1358,7 +1358,6 @@ struct FmhaFwdPagedKVKernel
make_tuple(kargs.stride_o, 1),
number<FmhaPipeline::kAlignmentO>{},
number<1>{});
return pad_tensor_view(
o_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
@@ -1370,7 +1369,7 @@ struct FmhaFwdPagedKVKernel
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile);
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
};

View File

@@ -484,7 +484,7 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile);
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
};

View File

@@ -1134,7 +1134,7 @@ struct FmhaFwdSplitKVKernel
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_acc_dram_window, o_acc_tile);
EpiloguePipeline{}(o_acc_dram_window, o_acc_tile, nullptr);
}
};

View File

@@ -175,6 +175,12 @@ struct GemmKernelMultiD
// Host-side check of whether `kargs` can be run by this kernel.
// Returns false when kargs.k_batch > 1; otherwise returns whatever
// UniversalGemmKernel::IsSupportedArgument reports for the same arguments.
CK_TILE_HOST static auto
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
{
// Currently MultiD kernel doesn't support k_batch > 1
if(kargs.k_batch > 1)
{
return false;
}
// Everything else is delegated to the underlying universal GEMM kernel's check.
return UniversalGemmKernel::IsSupportedArgument(kargs);
}

View File

@@ -193,7 +193,7 @@ struct Layernorm2dFwdPipelineOnePass
Epilogue{}(y_window_, sm_scale_window_, y_scale_window, ln, smem);
}
else
Epilogue{}(y_window_, ln);
Epilogue{}(y_window_, ln, nullptr);
}
};
} // namespace ck_tile

View File

@@ -255,7 +255,7 @@ struct Layernorm2dFwdPipelineTwoPass
});
static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT);
Epilogue{}(y_window, ln);
Epilogue{}(y_window, ln, nullptr);
move_tile_window(gamma_window, {-Block_N});
move_tile_window(beta_window, {-Block_N});

View File

@@ -221,7 +221,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
}
else
{
Epilogue{}(y_window_, rmsn);
Epilogue{}(y_window_, rmsn, nullptr);
}
}
};

View File

@@ -160,7 +160,7 @@ struct Rmsnorm2dFwdPipelineOnePass
}
else
{
Epilogue{}(y_window_, rmsn);
Epilogue{}(y_window_, rmsn, nullptr);
}
}
};

View File

@@ -195,7 +195,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
});
static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
Epilogue{}(y_window, rmsn);
Epilogue{}(y_window, rmsn, nullptr);
move_tile_window(gamma_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N});