Do not use warpSize as compile time constant as it is removed (#2320)

* Do not use warpSize as compile time constant as it is removed

* Update tile_image_to_column_shape.hpp

update warpSize usage.

* clean-up all use of warpSize, make sure code builds

* fix

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>

[ROCm/composable_kernel commit: 4c57157d50]
This commit is contained in:
Satyanvesh Dittakavi
2025-06-18 00:24:30 +05:30
committed by GitHub
parent af88547b60
commit a4517b0a9d
31 changed files with 213 additions and 206 deletions

View File

@@ -75,22 +75,22 @@ struct layernorm2d_fwd_traits_
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return total_warps * (WarpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / WarpSize);
}
}();
@@ -98,13 +98,13 @@ struct layernorm2d_fwd_traits_
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
return ThreadPerBlock_N_ / WarpSize;
}
}();

View File

@@ -35,7 +35,7 @@ struct Reduce2dShape
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t BlockSize =
warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
WarpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
template <typename XDataType_,

View File

@@ -74,22 +74,22 @@ struct rmsnorm2d_fwd_traits_
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return total_warps * (WarpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / WarpSize);
}
}();
@@ -97,13 +97,13 @@ struct rmsnorm2d_fwd_traits_
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
return ThreadPerBlock_N_ / WarpSize;
}
}();

View File

@@ -80,22 +80,22 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
using InputDataType = ck_tile::remove_cvref_t<InputDataType_>;
using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return total_warps * (WarpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / WarpSize);
}
}();
@@ -103,13 +103,13 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
return ThreadPerBlock_N_ / WarpSize;
}
}();

View File

@@ -49,22 +49,22 @@ struct smoothquant_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return total_warps * (WarpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / WarpSize);
}
}();
@@ -72,13 +72,13 @@ struct smoothquant_traits_
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
return ThreadPerBlock_N_ / WarpSize;
}
}();

View File

@@ -38,22 +38,22 @@ struct moe_smoothquant_traits_
using InputType = ck_tile::remove_cvref_t<InputType_>;
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return total_warps * (WarpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / WarpSize);
}
}();
@@ -61,13 +61,13 @@ struct moe_smoothquant_traits_
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
return ThreadPerBlock_N_ / WarpSize;
}
}();