mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-22 08:07:38 +00:00
update f16xMXF4
This commit is contained in:
@@ -97,6 +97,8 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -129,9 +131,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN>>;
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using Kernel =
|
||||
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
|
||||
@@ -211,10 +214,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
}
|
||||
else
|
||||
{
|
||||
// Run(has_hot_loop_,
|
||||
// tail_number_,
|
||||
// ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// ck_tile::memory_operation_enum::atomic_add>{});
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
@@ -412,17 +415,17 @@ int run_mixed_prec_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
// run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
|
||||
// ck_tile::pk_fp4_t,
|
||||
// FlatmmConfig,
|
||||
// false>(argc, argv, Row{}, Col{}, Row{});
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
|
||||
// ck_tile::pk_fp4_t,
|
||||
// FlatmmConfig,
|
||||
// true>(argc, argv, Row{}, Col{}, Row{});
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else if(mixed_prec == "fp16xfp4")
|
||||
@@ -434,13 +437,13 @@ int run_mixed_prec_flatmm_example(int argc, char* argv[])
|
||||
FlatmmConfig,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// else
|
||||
// {
|
||||
// run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
|
||||
// ck_tile::pk_fp4_t,
|
||||
// FlatmmConfig,
|
||||
// true>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
else
|
||||
{
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -466,10 +469,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
// else if(warp_tile == 1)
|
||||
// {
|
||||
// return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
|
||||
// }
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported warp_tile!");
|
||||
|
||||
@@ -58,8 +58,8 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
@@ -165,8 +165,8 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
|
||||
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
|
||||
|
||||
const float rtol = 5e-3;
|
||||
const float atol = 1e-3;
|
||||
const float rtol = 1e-2;
|
||||
const float atol = 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_rslt_host, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
@@ -31,10 +31,9 @@ struct e8m0_bexp_t
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale)
|
||||
: e8m0_bexp_t(static_cast<type>(numeric_utils<float>::get_exponent(scale)))
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale) : data(0)
|
||||
{
|
||||
data = numeric_utils<float>::get_exponent(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
|
||||
@@ -256,7 +256,7 @@ struct buffer_view<address_space_enum::global,
|
||||
T* p_data_ = nullptr;
|
||||
BufferSizeType buffer_size_;
|
||||
int32x4_t cached_buf_res_;
|
||||
remove_cvref_t<T> invalid_element_value_ = T{0.f};
|
||||
remove_cvref_t<T> invalid_element_value_ = T{0};
|
||||
|
||||
static constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
|
||||
|
||||
@@ -269,7 +269,7 @@ struct buffer_view<address_space_enum::global,
|
||||
: p_data_{p_data},
|
||||
buffer_size_{buffer_size / PackedSize},
|
||||
cached_buf_res_{0},
|
||||
invalid_element_value_{0.f}
|
||||
invalid_element_value_{0}
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@@ -27,10 +27,11 @@ template <typename ADataType_,
|
||||
index_t KPerXdl_,
|
||||
bool isCTransposed_,
|
||||
memory_operation_enum MemoryOperation_,
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1,
|
||||
bool TiledMMAPermuteN_ = false>
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1,
|
||||
bool TiledMMAPermuteN_ = false,
|
||||
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
@@ -53,6 +54,7 @@ struct CShuffleEpilogueProblem
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
@@ -89,6 +91,7 @@ struct CShuffleEpilogue
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
@@ -193,7 +196,10 @@ struct CShuffleEpilogue
|
||||
}
|
||||
}();
|
||||
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
|
||||
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
|
||||
static constexpr index_t NumNXdlPerWavePerShuffle =
|
||||
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
|
||||
|
||||
static_assert(NumNXdlPerWavePerShuffle % BlockedXDLN_PerWarp == 0);
|
||||
|
||||
static constexpr auto MNPerIterationShuffle = [] {
|
||||
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
|
||||
@@ -242,14 +248,31 @@ struct CShuffleEpilogue
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
|
||||
{
|
||||
constexpr auto block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<NumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto block_outer_dstr_encoding = [] {
|
||||
if constexpr(BlockedXDLN_PerWarp == 1)
|
||||
{
|
||||
return tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<NumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
|
||||
// BlockedLayout
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{};
|
||||
}
|
||||
}();
|
||||
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
|
||||
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
|
||||
|
||||
|
||||
@@ -58,6 +58,46 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
|
||||
{
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size,
|
||||
F16xMXF4FlatmmKernel,
|
||||
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
// << ", persistent_block_size: " << persistent_block_size
|
||||
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
|
||||
|
||||
assert(kargs.k_batch == 1);
|
||||
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
|
||||
|
||||
Reference in New Issue
Block a user