mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
fix settings for example, fix some things in pipeline
This commit is contained in:
@@ -162,7 +162,7 @@ execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMI
|
||||
configure_file(include/ck/version.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/version.h)
|
||||
|
||||
set(ROCM_SYMLINK_LIBS OFF)
|
||||
find_package(ROCM REQUIRED PATHS /opt/rocm)
|
||||
find_package(ROCM REQUIRED PATHS /opt/venv/lib/python3.12/site-packages/_rocm_sdk_devel)
|
||||
|
||||
include(ROCMInstallTargets)
|
||||
include(ROCMPackageConfigHelpers)
|
||||
|
||||
@@ -31,7 +31,7 @@ template <typename GemmConfig,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool UsePersistentKernel = false>
|
||||
bool UsePersistentKernel = true>
|
||||
float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::DeviceMem& b_dev_buf,
|
||||
ck_tile::DeviceMem& c_dev_buf,
|
||||
@@ -83,7 +83,7 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
UsePersistentKernel,
|
||||
GemmConfig::NumWaveGroups,
|
||||
true>;
|
||||
false>;
|
||||
|
||||
using MXPipelineProblem = MXGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -152,9 +152,9 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "32", "m dimension")
|
||||
.insert("n", "512", "n dimension")
|
||||
.insert("k", "256", "k dimension")
|
||||
arg_parser.insert("m", "4096", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "4096", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
@@ -169,7 +169,6 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
|
||||
@@ -39,7 +39,8 @@ struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0>
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp4_GemmConfig16
|
||||
|
||||
struct MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
@@ -70,3 +71,17 @@ struct MXfp4_GemmConfig16
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
struct MXfp4_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp8_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
@@ -49,25 +49,25 @@ int run_mx_gemm_with_layouts(int argc,
|
||||
|
||||
// Scale tensors
|
||||
// Assuming block scale 32
|
||||
ck_tile::index_t scale_n_size = N / 32;
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
ck_tile::index_t scale_k_size = K / 32;
|
||||
ck_tile::HostTensor<ck_tile::e8m0_t> scale_a_host(
|
||||
ck_tile::HostTensor<ScaleType> scale_a_host(
|
||||
ck_tile::HostTensorDescriptor({M, scale_k_size}, {scale_k_size, 1}));
|
||||
ck_tile::HostTensor<ck_tile::e8m0_t> scale_b_host(
|
||||
ck_tile::HostTensorDescriptor({scale_k_size, scale_n_size}, {scale_n_size, 1}));
|
||||
ck_tile::HostTensor<ScaleType> scale_b_host(
|
||||
ck_tile::HostTensorDescriptor({scale_k_size, N}, {1, scale_k_size}));
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_host);
|
||||
ck_tile::FillUniformDistribution<ck_tile::e8m0_t>{-1.f, 1.f}(scale_a_host);
|
||||
ck_tile::FillUniformDistribution<ck_tile::e8m0_t>{-1.f, 1.f}(scale_b_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-1.f, 1.f}(scale_a_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-1.f, 1.f}(scale_b_host);
|
||||
break;
|
||||
case 1:
|
||||
ck_tile::FillConstant<ADataType>{ADataType(1.f)}(a_host);
|
||||
ck_tile::FillConstant<BDataType>{BDataType(1.f)}(b_host);
|
||||
ck_tile::FillConstant<ck_tile::e8m0_t>{ck_tile::e8m0_t(1.f)}(scale_a_host);
|
||||
ck_tile::FillConstant<ck_tile::e8m0_t>{ck_tile::e8m0_t(1.f)}(scale_b_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_a_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -83,8 +83,8 @@ int run_mx_gemm_with_layouts(int argc,
|
||||
scale_b_dev_buf.ToDevice(scale_b_host.data());
|
||||
|
||||
// Scale pointers
|
||||
using ScaleM = ck_tile::MXScalePointer<1, 32>; // per-token
|
||||
using ScaleN = ck_tile::MXScalePointer<32, 32>; // per-block
|
||||
using ScaleM = ck_tile::MXScalePointer<1, 32>; // in blocks of 32 in K
|
||||
using ScaleN = ck_tile::MXScalePointer<1, 32>;
|
||||
|
||||
ScaleM scale_m(reinterpret_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()));
|
||||
ScaleN scale_n(reinterpret_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()));
|
||||
@@ -104,14 +104,31 @@ int run_mx_gemm_with_layouts(int argc,
|
||||
|
||||
(void)ave_time;
|
||||
|
||||
bool pass = true;
|
||||
if(validation > 0)
|
||||
{
|
||||
// get output data from device
|
||||
c_dev_buf.FromDevice(c_host.data());
|
||||
// TODO: Implement validation logic (reference GEMM with scales)
|
||||
// For now just print success if it runs
|
||||
std::cout << "Validation not implemented yet." << std::endl;
|
||||
|
||||
// compute reference
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_host);
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
return 0;
|
||||
return pass ? 0 : -1;
|
||||
}
|
||||
|
||||
int run_mx_gemm_example(int argc, char* argv[])
|
||||
@@ -126,24 +143,28 @@ int run_mx_gemm_example(int argc, char* argv[])
|
||||
std::string mx_prec = arg_parser.get_str("mx_prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
int persistent_opt = arg_parser.get_int("persistent");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_gemm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
float,
|
||||
MXfp4_GemmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only non-persistent kernels are supported currently!");
|
||||
return run_mx_gemm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
float,
|
||||
MXfp4_GemmConfig16,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
|
||||
{
|
||||
return run_mx_gemm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
MXfp8_GemmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Only fp4xfp4 is supported currently!");
|
||||
throw std::runtime_error("Only fp4 and fp8 is supported currently!");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -39,12 +39,8 @@
|
||||
#define CK_TILE_DEVICE inline __device__
|
||||
#define CK_TILE_HOST_DEVICE inline __host__ __device__
|
||||
#define CK_TILE_DEVICE_EXTERN __device__
|
||||
#if __clang_major__ < 22
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
|
||||
#else
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN
|
||||
#endif
|
||||
#else
|
||||
#define CK_TILE_HOST inline
|
||||
#define CK_TILE_DEVICE inline
|
||||
#define CK_TILE_HOST_DEVICE inline
|
||||
|
||||
@@ -119,6 +119,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
|
||||
{
|
||||
// TODO: these could be replaced by the standard UniversalGEMM tile distributions??
|
||||
constexpr index_t K2 = AK1; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
|
||||
@@ -95,6 +95,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
static constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
/// @brief The e8m0 scales are packed into int32/float32 such that
|
||||
/// in one element contains a 2x2 block of scales (two rows, two lements in K dim)
|
||||
static constexpr auto MXdlPack = MXGemmPipeline::MXdlPack;
|
||||
static constexpr auto NXdlPack = MXGemmPipeline::NXdlPack;
|
||||
static constexpr auto KXdlPack = MXGemmPipeline::KXdlPack;
|
||||
@@ -195,7 +197,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
|
||||
static_assert(ScaleM::GranularityK == ScaleN::GranularityK, "M and N scales must have same K granularity!");
|
||||
static constexpr int BlockScaleSize = ScaleM::GranularityK;
|
||||
const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
|
||||
const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
|
||||
const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
|
||||
@@ -218,10 +221,10 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|
||||
// B scale tensor view
|
||||
const auto& scale_b_tensor_view = [&]() {
|
||||
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
|
||||
const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
|
||||
const auto scale_b_desc = transform_tensor_descriptor(
|
||||
scale_b_navie_desc,
|
||||
scale_b_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
@@ -251,12 +254,14 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|
||||
static constexpr int BlockScaleSize = 32;
|
||||
|
||||
// We are packing 2x2 (MXdlPack x KXdlPack) scales (e8m0) into one int32 element
|
||||
auto scale_a_block_window = make_tile_window(
|
||||
views.at(I4),
|
||||
make_tuple(number<TilePartitioner::MPerBlock / MXdlPack>{},
|
||||
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
|
||||
{i_m / MXdlPack, 0});
|
||||
|
||||
// We are packing 2x2 (NXdlPack x KXdlPack) scales (e8m0) into one int32 element
|
||||
auto scale_b_block_window = make_tile_window(
|
||||
views.at(I5),
|
||||
make_tuple(number<TilePartitioner::NPerBlock / NXdlPack>{},
|
||||
@@ -295,7 +300,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& scale_a_block_window = gemm_tile_windows.at(I4);
|
||||
const auto& scale_b_block_window = gemm_tile_windows.at(I5);
|
||||
@@ -304,12 +309,9 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|
||||
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
|
||||
"ScaleM and ScaleN should have the same GranularityK");
|
||||
constexpr bool DoEpiScale =
|
||||
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
|
||||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
|
||||
|
||||
const auto& c_block_tile = MXGemmPipeline{}(a_block_window,
|
||||
b_flat_block_window,
|
||||
b_block_window,
|
||||
scale_a_block_window,
|
||||
scale_b_block_window,
|
||||
num_loop,
|
||||
@@ -317,54 +319,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(DoEpiScale)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
auto scale_m_ptr_offset = kargs.scale_m_ptr + block_idx_m;
|
||||
auto scale_n_ptr_offset = kargs.scale_n_ptr + block_idx_n;
|
||||
|
||||
auto scale_m_view = [&]() {
|
||||
if constexpr (ScaleM::GranularityMN != -1) {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
scale_m_ptr_offset.ptr,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(number<1>{}, number<0>{}),
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
} else {
|
||||
return typename EpiloguePipeline::EmptyScale{};
|
||||
}
|
||||
}();
|
||||
|
||||
auto scale_n_view = [&]() {
|
||||
if constexpr (ScaleN::GranularityMN != -1) {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
scale_n_ptr_offset.ptr,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(number<0>{}, number<1>{}),
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
} else {
|
||||
return typename EpiloguePipeline::EmptyScale{};
|
||||
}
|
||||
}();
|
||||
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
scale_m_view,
|
||||
scale_n_view);
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
|
||||
|
||||
@@ -242,7 +242,7 @@ struct MXGemmPipelineAgBgCrV1
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
};
|
||||
|
||||
// Helper for Math Loop
|
||||
// Helper for Main Loop
|
||||
auto warp_gemm_loop = [&](auto& a_warp_window, auto& b_warp_window, auto& scale_a, auto& scale_b) {
|
||||
// Define register tiles types for double buffering
|
||||
using AValType = decltype(load_tile_with_offset(a_warp_window, tuple<number<0>, number<0>>{}));
|
||||
|
||||
@@ -227,31 +227,10 @@ struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<K_Thread / AK1, K_Lane, AK1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<2, 2>, // K_Thread/AK1, AK1
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_BDramTileDistribution()
|
||||
{
|
||||
constexpr index_t K2 = BK1; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * BPackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
|
||||
constexpr index_t N2 = WaveSize / K1; // 8
|
||||
constexpr index_t N1 = BlockSize / WaveSize; // 4
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
|
||||
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>, // N0,K0,K2
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
@@ -294,6 +273,29 @@ struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
TensorView::DstInMemOp>{naive_view.buf_, desc};
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_BDramTileDistribution()
|
||||
{
|
||||
// TODO: these could be replaced by the standard UniversalGEMM tile distributions??
|
||||
constexpr index_t K2 = BK1; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * BPackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
|
||||
constexpr index_t N2 = WaveSize / K1; // 8
|
||||
constexpr index_t N1 = BlockSize / WaveSize; // 4
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
|
||||
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>, // N0,K0,K2
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t K2 = BK1; // f4=32; f8=16
|
||||
|
||||
Reference in New Issue
Block a user