mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
update scale for mxfp4
This commit is contained in:
@@ -6,40 +6,6 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
struct A16W4_FlatmmConfig32
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
|
||||
};
|
||||
|
||||
struct A16W4_FlatmmConfig32_950 : A16W4_FlatmmConfig32
|
||||
{
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct A16W4_FlatmmConfig16
|
||||
{
|
||||
@@ -76,7 +42,6 @@ struct A16W4_FlatmmConfig16
|
||||
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr int N_Repeat =
|
||||
|
||||
@@ -211,10 +211,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
// Run(has_hot_loop_,
|
||||
// tail_number_,
|
||||
// ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
@@ -327,11 +327,11 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <class IterSrc, class IterDst>
|
||||
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K, int NXdl)
|
||||
template <class FlatmmConfig, class IterSrc, class IterDst>
|
||||
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = NXdl;
|
||||
int NLane = FlatmmConfig::N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
@@ -359,12 +359,33 @@ void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K, int NXdl)
|
||||
}
|
||||
}
|
||||
|
||||
template <class IterSrc, class IterDst>
|
||||
void preShuffleScale(const IterSrc src, IterDst dst, int N, int K, int NXdl);
|
||||
template <class FlatmmConfig, class T>
|
||||
auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
|
||||
{
|
||||
assert(scale.get_lengths().size() == 2);
|
||||
int n_ = scale.get_lengths()[1];
|
||||
int k_ = scale.get_lengths()[0];
|
||||
|
||||
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
|
||||
constexpr int K_Pack = FlatmmConfig::K_Tile / FlatmmConfig::K_Warp_Tile / K_Lane;
|
||||
|
||||
static_assert(sizeof(T) * K_Pack * FlatmmConfig::N_Repeat <= 16, "inefficient pack policy");
|
||||
|
||||
ck_tile::HostTensor<T> shfl_scale({
|
||||
n_ / FlatmmConfig::N_Repeat / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Repeat,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / K_Pack / K_Lane,
|
||||
K_Pack,
|
||||
K_Lane,
|
||||
});
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
return ck_tile::reference_permute(shfl_scale, {0, 3, 5, 2, 4, 1});
|
||||
}
|
||||
|
||||
#include "run_mixed_prec_flatmm.inc"
|
||||
|
||||
template <template <typename PrecType> typename FlatmmConfig>
|
||||
template <typename FlatmmConfig>
|
||||
int run_mixed_prec_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -385,33 +406,33 @@ int run_mixed_prec_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig<ck_tile::bf16_t>,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
// run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
|
||||
// ck_tile::pk_fp4_t,
|
||||
// FlatmmConfig,
|
||||
// false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
|
||||
// ck_tile::pk_fp4_t,
|
||||
// FlatmmConfig<ck_tile::bf16_t>,
|
||||
// FlatmmConfig,
|
||||
// true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
// if(persistent_opt == 0)
|
||||
// {
|
||||
// run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
|
||||
// ck_tile::pk_fp4_t,
|
||||
// FlatmmConfig<ck_tile::fp16_t>,
|
||||
// false>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// else
|
||||
// {
|
||||
// run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
|
||||
// ck_tile::pk_fp4_t,
|
||||
// FlatmmConfig<ck_tile::fp16_t>,
|
||||
// FlatmmConfig,
|
||||
// true>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
}
|
||||
@@ -437,11 +458,11 @@ int main(int argc, char* argv[])
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
|
||||
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
// else if(warp_tile == 1)
|
||||
// {
|
||||
// return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig32_950>(argc, argv);
|
||||
// return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
|
||||
// }
|
||||
else
|
||||
{
|
||||
|
||||
@@ -23,6 +23,8 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
using CDataType = PrecActType;
|
||||
using AccDataType = float;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
constexpr int DequantGranularityN = 1;
|
||||
constexpr int DequantGranularityK = 32;
|
||||
|
||||
@@ -50,42 +52,42 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
ck_tile::HostTensor<CDataType> c_rslt_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::HostTensor<AccDataType> weight_dequant_scale(ck_tile::HostTensorDescriptor(
|
||||
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
|
||||
{K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(weight_dequant_scale);
|
||||
// ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(weight_dequant_scale);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
preShuffleWeight(
|
||||
b_origin_host.begin(), b_shuffle_host.begin(), N, K, FlatmmConfig::N_Warp_Tile);
|
||||
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffle_host.begin(), N, K);
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b_shuffle = preShuffleScale<FlatmmConfig>(scale_b);
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weight_dequant_scale_dev_buf(
|
||||
weight_dequant_scale.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
c_rslt_host.SetZero();
|
||||
weight_dequant_scale_dev_buf.ToDevice(weight_dequant_scale.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffle.data());
|
||||
|
||||
auto weight_dequant_scale_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
|
||||
static_cast<float*>(weight_dequant_scale_dev_buf.GetDeviceBuffer()),
|
||||
N / DequantGranularityN};
|
||||
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
|
||||
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
|
||||
|
||||
invoke_mixed_prec_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
@@ -97,7 +99,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
decltype(weight_dequant_scale_dev_ptr),
|
||||
decltype(scale_b_dev_ptr),
|
||||
UsePersistentKernel>(a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
@@ -108,7 +110,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
weight_dequant_scale_dev_ptr,
|
||||
scale_b_dev_ptr,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
@@ -126,12 +128,19 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
|
||||
ck_tile::HostTensor<AccDataType> scale_A(
|
||||
ck_tile::HostTensorDescriptor({1, K / DequantGranularityK}, {1, 1}));
|
||||
|
||||
// scaleA = 1 has no effect on the result
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
|
||||
|
||||
ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
|
||||
scale_A_dev_buf.ToDevice(scale_A.data());
|
||||
|
||||
// convert scale_b from e8m0 to float
|
||||
ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
|
||||
{K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
|
||||
std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
|
||||
ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
|
||||
scale_b_float_dev_buf.ToDevice(scale_b_float.data());
|
||||
|
||||
c_gpu_ref_dev_buf.SetZero();
|
||||
ck_tile::reference_blockwise_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
@@ -153,7 +162,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
DequantGranularityN,
|
||||
DequantGranularityK,
|
||||
static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(weight_dequant_scale_dev_buf.GetDeviceBuffer()));
|
||||
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
|
||||
|
||||
|
||||
@@ -227,11 +227,11 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale)
|
||||
return convert_to_type<pk_fp4_t>(x, scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale = 1.0f)
|
||||
{
|
||||
return float_to_e2m1(x, scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -239,7 +239,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float sca
|
||||
return float_to_e2m1(type_convert<float>(x), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -247,7 +247,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float sca
|
||||
return float_to_e2m1(type_convert<float>(x), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -256,7 +256,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -265,7 +265,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -274,27 +274,27 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_fp32x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_fp16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_bf16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_float(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_fp16(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_bf16(scale);
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
static constexpr int QuantPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
static constexpr int N_Pack = 2;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -47,6 +48,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
static constexpr auto I4 = number<4>();
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
@@ -149,7 +151,21 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
|
||||
auto scale_n = kargs.scale_n_ptr;
|
||||
|
||||
index_t FlatScaleK =
|
||||
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
const auto scale_b_flat_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(scale_n.ptr,
|
||||
make_tuple(FlatScaleN, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
return make_tuple(
|
||||
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -215,7 +231,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
@@ -275,6 +291,12 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
auto scale_block_window =
|
||||
make_tile_window(views.at(I4),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
|
||||
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
|
||||
}
|
||||
|
||||
@@ -304,8 +326,13 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
|
||||
const auto& scale_block_window = gemm_tile_windows.at(I3);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
|
||||
b_flat_block_window,
|
||||
scale_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(false && (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
|
||||
|
||||
@@ -371,37 +371,6 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ADramDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<4>,
|
||||
tuple<sequence<16>, sequence<4, 4, 8>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
@@ -438,42 +407,6 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad = 32;
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
|
||||
{
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -37,7 +38,7 @@ struct MixedPrecFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
static constexpr index_t flatKPerWarp = 128;
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
|
||||
template <typename Problem, typename PipelinePolicy = MixedPrecFlatmmPipelineAgBgCrPolicy>
|
||||
struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
: FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
|
||||
{
|
||||
@@ -456,10 +457,14 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename DequantBFlatWindow>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const DequantBFlatWindow& scale_b_flat_window,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
@@ -565,35 +570,61 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_flatmm.MakeCBlockTile();
|
||||
|
||||
constexpr int XDLPerLoadK = 4;
|
||||
constexpr int XDLPerLoadK = 4;
|
||||
constexpr int NRepeatPerScaleLoad = 2;
|
||||
|
||||
constexpr int QuantKPerWarp = KIterPerWarp / XDLPerLoadK;
|
||||
constexpr int QuantNPerWarp = NIterPerWarp / NRepeatPerScaleLoad;
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeFp4BFlatDramTileDistribution<Problem>();
|
||||
auto b_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
auto scale_b_flat_distribution =
|
||||
PipelinePolicy::template MakeFp4ScaleBFlatDramTileDistribution<Problem>();
|
||||
|
||||
auto b_flat_dram_window = make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
constexpr int ScaleB_BlockK =
|
||||
flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
|
||||
|
||||
auto scale_b_flat_dram_window = make_tile_window(
|
||||
scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<ScaleB_BlockK>{}),
|
||||
scale_b_flat_window.get_window_origin(),
|
||||
scale_b_flat_distribution);
|
||||
|
||||
// pingpong buffer for B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), QuantKPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(scale_b_flat_dram_window), QuantKPerWarp>,
|
||||
QuantNPerWarp>
|
||||
scale_b_flat_dram_windows;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), QuantKPerWarp>,
|
||||
QuantNPerWarp>
|
||||
scale_b_warp_tensor_ping;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), QuantKPerWarp>,
|
||||
QuantNPerWarp>
|
||||
scale_b_warp_tensor_pong;
|
||||
|
||||
// HEAD
|
||||
// Prefetch A0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
@@ -603,6 +634,19 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
if constexpr(nIter % NRepeatPerScaleLoad == 0)
|
||||
{
|
||||
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
}
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
@@ -613,6 +657,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
|
||||
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
@@ -643,12 +688,56 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
auto dequant_B = typename WG::BWarpTensor{};
|
||||
|
||||
auto deq_fn = [&](auto& quant_weight_tensor, auto sub_idx) {
|
||||
auto perm_scale = [&](auto lane_scale, auto xdl_k_idx) {
|
||||
#if defined(__gfx942__)
|
||||
return lane_scale;
|
||||
#endif
|
||||
auto v2scale = __builtin_amdgcn_permlane32_swap(lane_scale, lane_scale, 0, 0);
|
||||
if constexpr(xdl_k_idx < 2)
|
||||
{
|
||||
lane_scale = v2scale[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
lane_scale = v2scale[1];
|
||||
}
|
||||
|
||||
v2scale = __builtin_amdgcn_permlane16_swap(lane_scale, lane_scale, 0, 0);
|
||||
if constexpr(xdl_k_idx % 2 == 0)
|
||||
{
|
||||
return v2scale[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return v2scale[1];
|
||||
}
|
||||
};
|
||||
|
||||
auto deq_fn = [&](const auto& quant_weight_tensor,
|
||||
const auto& scale_tensor,
|
||||
auto xdl_nIter,
|
||||
auto xdl_kIter) {
|
||||
auto b_idx_k = xdl_kIter % number<XDLPerLoadK>{};
|
||||
|
||||
auto scale_idx_n = xdl_nIter % number<NRepeatPerScaleLoad>{};
|
||||
|
||||
uint32_t packed_scale = scale_tensor.get_thread_buffer().template get_as<uint32_t>(I0);
|
||||
packed_scale = perm_scale(packed_scale, b_idx_k);
|
||||
|
||||
e8m0_t* scale_ptr = reinterpret_cast<e8m0_t*>(&packed_scale);
|
||||
|
||||
if constexpr(xdl_nIter % 2 != 0)
|
||||
{
|
||||
scale_ptr++;
|
||||
}
|
||||
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {
|
||||
dequant_B.get_thread_buffer().template set_as<fp16x2_t>(
|
||||
number<i>{},
|
||||
fp16x2_t(quant_weight_tensor.get_thread_buffer()[sub_idx * ScalarCnt / 2 + i]));
|
||||
pk_fp4_to_fp16x2(
|
||||
quant_weight_tensor.get_thread_buffer()[b_idx_k * ScalarCnt / 2 + i],
|
||||
*scale_ptr));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -690,7 +779,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
kIter % number<XDLPerLoadK>{});
|
||||
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
|
||||
kIter / number<XDLPerLoadK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
@@ -721,6 +813,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
@@ -765,7 +858,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
kIter % number<XDLPerLoadK>{});
|
||||
scale_b_warp_tensor_pong(nIter / number<NRepeatPerScaleLoad>{})(
|
||||
kIter / number<XDLPerLoadK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
@@ -795,6 +891,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
@@ -839,7 +936,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
kIter % number<XDLPerLoadK>{});
|
||||
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
|
||||
kIter / number<XDLPerLoadK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
@@ -889,7 +989,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
kIter % number<XDLPerLoadK>{});
|
||||
scale_b_warp_tensor_pong(nIter / number<NRepeatPerScaleLoad>{})(
|
||||
kIter / number<XDLPerLoadK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
@@ -931,7 +1034,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
kIter % number<XDLPerLoadK>{});
|
||||
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
|
||||
kIter / number<XDLPerLoadK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
@@ -964,9 +1070,12 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename DequantBFlatWindow>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const DequantBFlatWindow& scale_b_flat_window,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
@@ -975,6 +1084,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType & a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
scale_b_flat_window,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct MixedPrecFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
|
||||
{
|
||||
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
if constexpr(TileShape::WarpTile::at(I1) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(I2) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16);
|
||||
return TileShape::WarpTile::at(I2) / 4;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ADramDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<4>,
|
||||
tuple<sequence<16>, sequence<4, 4, 8>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad = 32;
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4ScaleBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t N_Repeat = TileShape::kN / TileShape::WarpTile::at(I1) / N_Warp;
|
||||
constexpr index_t N_Pack = N_Repeat;
|
||||
|
||||
constexpr index_t XDLPerBlock = TileShape::kK / TileShape::WarpTile::at(I2);
|
||||
constexpr index_t KBPerLoad = XDLPerBlock * N_Pack;
|
||||
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
|
||||
constexpr index_t K_Pack = XDLPerBlock / K_Lane;
|
||||
|
||||
// constexpr index_t RepeatScale = TileShape::WarpTile::at(I2) / ;
|
||||
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = N_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk>, // second direction
|
||||
sequence<K_Lane, 16, N_Pack * K_Pack>>, // first
|
||||
// direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<1>, sequence<2, 2>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = kMPerBlock / M1;
|
||||
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
if constexpr(warp_size >= (K2 * M0))
|
||||
{
|
||||
constexpr index_t K1 = warp_size / (K2 * M0);
|
||||
constexpr index_t K0 = kBlockSize / warp_size;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
|
||||
{
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
|
||||
typename Problem::ADataType,
|
||||
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename
|
||||
// Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user