mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK-Tile] Remove usage of tile partitioner's full gemm shape (#3204)
gemm shape should be used from the pipeline instead (where it gets from a problem description struct)
This commit is contained in:
@@ -363,8 +363,8 @@ struct FlatmmKernel
|
||||
template <class KernelArgs>
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto N1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<1>{});
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
constexpr auto N1 = BlockGemmShape::WarpTile::at(number<1>{});
|
||||
constexpr auto K1 = BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
|
||||
|
||||
@@ -369,7 +369,7 @@ struct MoeFlatmmKernel
|
||||
template <class KernelArgs>
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
constexpr auto K1 = BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
|
||||
|
||||
@@ -386,11 +386,9 @@ template <typename BlockGemmShapeType,
|
||||
uint32_t TileSwizzleSubM = 8>
|
||||
struct StreamKTilePartitioner
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShapeType;
|
||||
|
||||
static constexpr uint32_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr uint32_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr uint32_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr uint32_t MPerBlock = BlockGemmShapeType::kM;
|
||||
static constexpr uint32_t NPerBlock = BlockGemmShapeType::kN;
|
||||
static constexpr uint32_t KPerBlock = BlockGemmShapeType::kK;
|
||||
|
||||
CK_TILE_HOST_DEVICE StreamKTilePartitioner() noexcept = delete;
|
||||
|
||||
|
||||
@@ -22,11 +22,10 @@ template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType = StreamKReductionStrategy::Atomic>
|
||||
struct StreamKTilePartitionerBase
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShapeType;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t MPerBlock = BlockGemmShapeType::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShapeType::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShapeType::kK;
|
||||
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
|
||||
|
||||
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
|
||||
|
||||
@@ -325,7 +325,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
|
||||
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
|
||||
|
||||
@@ -584,7 +584,7 @@ struct UniversalGemmKernel
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
|
||||
const auto& as_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -617,7 +617,7 @@ struct UniversalGemmKernel
|
||||
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
@@ -649,7 +649,7 @@ struct UniversalGemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
@@ -675,8 +675,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::BlockGemmShape::flatKPerWarp *
|
||||
(k_size /
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
|
||||
@@ -276,7 +276,7 @@ struct QuantGemmKernel
|
||||
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
|
||||
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
|
||||
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
|
||||
|
||||
@@ -487,7 +487,7 @@ struct QuantGemmKernel
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -537,7 +537,7 @@ struct QuantGemmKernel
|
||||
|
||||
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
|
||||
const auto wave_tile_size =
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
|
||||
GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
|
||||
const auto wave_tile_count_x =
|
||||
ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
|
||||
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
@@ -597,7 +597,7 @@ struct QuantGemmKernel
|
||||
const auto& b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
@@ -627,7 +627,7 @@ struct QuantGemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
@@ -649,10 +649,9 @@ struct QuantGemmKernel
|
||||
{
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::flatKPerWarp *
|
||||
(splitk_batch_offset.splitted_k /
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatK = GemmPipeline::flatKPerWarp *
|
||||
(splitk_batch_offset.splitted_k /
|
||||
GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -837,7 +836,7 @@ struct QuantGemmKernel
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
|
||||
@@ -880,7 +879,7 @@ struct QuantGemmKernel
|
||||
b_pad_view,
|
||||
make_tuple(number<GemmPipeline::flatNPerWarp>{},
|
||||
number<GemmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / TilePartitioner::BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
{static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -724,8 +724,8 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
|
||||
@@ -464,7 +464,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
|
||||
const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);
|
||||
|
||||
@@ -646,8 +646,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
WeiDataType* c_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr,
|
||||
kargs.a_grid_desc_k_m); // A: out
|
||||
|
||||
@@ -745,8 +745,8 @@ struct GroupedConvolutionForwardKernel
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
|
||||
}();
|
||||
|
||||
Reference in New Issue
Block a user