mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] Update flatmm related kernels (#3022)
--------- Co-authored-by: Ding, Yi <yi.ding@amd.com> Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -238,22 +239,47 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t scale = 4;
|
||||
#else
|
||||
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
|
||||
#endif
|
||||
if constexpr(TileShape::WarpTile::at(I1) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(I2) * scale / 2;
|
||||
return TileShape::WarpTile::at(I2) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16);
|
||||
return TileShape::WarpTile::at(I2) * scale / 4;
|
||||
return TileShape::WarpTile::at(I2) / 4;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALDS_WarpTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
|
||||
|
||||
constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr int KLane = get_warp_size() / MPerXdl;
|
||||
constexpr int KPerThread = KPerXdl / KLane;
|
||||
|
||||
constexpr int MaxVecSize = 16 / sizeof(ADataType);
|
||||
constexpr int KItemsPerLoad = min(MaxVecSize, KPerThread);
|
||||
constexpr int KFragment = KPerThread / KItemsPerLoad;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<Repeat>,
|
||||
tuple<sequence<MPerXdl>, sequence<KFragment, KLane, KItemsPerLoad>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
@@ -307,10 +333,10 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
if constexpr(get_warp_size() % (M2 * K0) == 0)
|
||||
if constexpr(get_warp_size() % K0 == 0)
|
||||
{
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
@@ -329,24 +355,54 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t M0 = BlockSize / get_warp_size();
|
||||
constexpr index_t M1 = MPerBlock / (M2 * M0);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M1, M2 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
constexpr index_t KWave = K0 / get_warp_size();
|
||||
constexpr index_t M0 = BlockSize / get_warp_size() / KWave;
|
||||
constexpr index_t M1 = MPerBlock / M0;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<KWave, get_warp_size(), K1>>,
|
||||
tuple<sequence<1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0>, sequence<1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
@@ -355,15 +411,16 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t KRepeatInWave = 2;
|
||||
#else
|
||||
constexpr index_t KRepeatInWave = 1;
|
||||
#endif
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
|
||||
constexpr index_t MaxVecSize = 16 / sizeof(typename Problem::BDataType);
|
||||
constexpr index_t KItemsPerLoad = min(KBPerLoad, MaxVecSize);
|
||||
constexpr index_t KFragment = KBPerLoad / KItemsPerLoad;
|
||||
static_assert(KFragment * KItemsPerLoad == KBPerLoad);
|
||||
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim./
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
static_assert(TileShape::BlockWarps::at(number<2>{}) == 1, "Requires K_Warp == 1");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
@@ -371,15 +428,17 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat, KRepeatInWave>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KFragment, KWavePerBlk, KThdPerWave, KItemsPerLoad>>, // first
|
||||
// direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
|
||||
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,239 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE 0
|
||||
|
||||
#if defined(__gfx950__)
|
||||
#define CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4 1
|
||||
#else
|
||||
#define CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4 0
|
||||
#endif
|
||||
|
||||
#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS \
|
||||
(CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && \
|
||||
CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4)
|
||||
|
||||
struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
static constexpr index_t KBPerLoad = 32;
|
||||
static constexpr index_t N_Pack = 2; // it's fixed for fp4
|
||||
static constexpr index_t K_Pack = 2; // it's fixed for fp4
|
||||
|
||||
template <typename Problem, typename NativeADramTensorView>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
TransformF16xF4_ATensorView(const NativeADramTensorView& a_dram_view)
|
||||
{
|
||||
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
|
||||
constexpr int DynamicTileOffsetFlag = 0;
|
||||
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
|
||||
|
||||
// implement swizzle pattern on global side
|
||||
// because we can't adjust the ds_write pattern of BUFFER_LOAD_LDS.
|
||||
auto swizzle_a_dram_view_1 = transform_tensor_view(
|
||||
a_dram_view,
|
||||
make_tuple(
|
||||
// M-dim is not affected by swizzle pattern
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<DynamicTileOffsetFlag>{}, number<MPerBlock>{})),
|
||||
// K-dim is the swizzle dimension
|
||||
make_unmerge_transform(make_tuple(number<DynamicTileOffsetFlag>{},
|
||||
number<KPerBlock / KPack>{},
|
||||
number<KPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}));
|
||||
|
||||
auto swizzle_a_dram_view_2 = transform_tensor_view(
|
||||
swizzle_a_dram_view_1,
|
||||
make_tuple(make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
|
||||
make_xor_transform(make_tuple(number<MPerBlock>{},
|
||||
number<ContiguousThreadsCntInDS_READ_16B>{})),
|
||||
make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
swizzle_a_dram_view_2,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<DynamicTileOffsetFlag>{}, number<MPerBlock>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(number<DynamicTileOffsetFlag>{},
|
||||
number<KPerBlock / KPack>{},
|
||||
number<KPack>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
#else
|
||||
return a_dram_view;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ReadALdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
|
||||
number<ContiguousThreadsCntInDS_READ_16B>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_WriteALdsBlockDescriptor()
|
||||
{
|
||||
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
return make_naive_tensor_descriptor(make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
make_tuple(number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
#else
|
||||
return MakeF16xF4_ReadALdsBlockDescriptor<Problem>();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ALDS_TileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
|
||||
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
|
||||
constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M0 = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
|
||||
|
||||
constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 8
|
||||
constexpr int XDL_PerThreadK = KBPerLoad / K2; // 4
|
||||
constexpr int K0 = K_Lane; // 4
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Repeat>,
|
||||
tuple<sequence<M0>, sequence<K0, XDL_PerThreadK, K2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Pack>, // second
|
||||
// direction
|
||||
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4ScaleBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
[[maybe_unused]] constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
|
||||
|
||||
[[maybe_unused]] constexpr index_t XDLPerBlock =
|
||||
TileShape::kK / TileShape::WarpTile::at(I2);
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
|
||||
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
|
||||
|
||||
constexpr index_t NWavePerBlk = N_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk>, // second direction
|
||||
sequence<K_Lane, N_Lane, N_Pack * K_Pack>>, // first
|
||||
// direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<1>, sequence<2, 2>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user