mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[Regression] Fix CK_TILE build error in grouped_convolution, copy_basic and fused_moegemm_kernel (#2728)
* fix copy basic build error * fix other ck tile test build error
This commit is contained in:
@@ -27,8 +27,9 @@ struct TileCopyShape
|
||||
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
|
||||
|
||||
// Wave tile dimensions
|
||||
static constexpr index_t Wave_Tile_M = WaveTile::at(number<0>{});
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
static constexpr index_t Wave_Tile_N = WaveTile::at(number<1>{});
|
||||
static constexpr index_t Wave_Tile_M = ThreadTile_M * ThreadTile_N * WaveSize / Wave_Tile_N;
|
||||
|
||||
// Block tile dimensions
|
||||
static constexpr index_t Block_Tile_M = BlockTile::at(number<0>{});
|
||||
@@ -45,7 +46,6 @@ struct TileCopyShape
|
||||
Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
|
||||
|
||||
// Hardware configuration
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
static constexpr index_t BlockSize = Waves_Per_Block_M * Waves_Per_Block_N * WaveSize;
|
||||
|
||||
// Configuration validation
|
||||
@@ -60,8 +60,10 @@ struct TileCopyShape
|
||||
"Invalid wave configuration for N dimension");
|
||||
|
||||
// Ensure wave tile dimensions align with wave size
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize,
|
||||
"(Wave_Tile_M/ThreadTile_M) * (Wave_Tile_N/ThreadTile_N) != WaveSize");
|
||||
#endif
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -200,6 +202,19 @@ struct ElementWiseTileCopyKernel
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
CK_TILE_HOST static auto BlockSize()
|
||||
{
|
||||
if(ck_tile::is_wave32())
|
||||
{
|
||||
return kBlockSize / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return kBlockSize;
|
||||
}
|
||||
}
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
Reference in New Issue
Block a user