draft for 16*16*128 fp8

This commit is contained in:
solin
2025-05-14 09:56:13 +00:00
parent 41c17d0a95
commit c2d17ec24f
7 changed files with 69 additions and 39 deletions

View File

@@ -3,6 +3,7 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef)
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef)
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -DENABLE_FP8=1 -Wno-unused-local-typedef)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x128=1 -DENABLE_FP8=1 -Wno-unused-local-typedef)
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

View File

@@ -27,7 +27,19 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 2;
#if defined(USING_MFMA_16x16x128)
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 128;
#endif
// This part comes from the Codegen
#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16)
constexpr ck_tile::index_t M_Tile = 128;
@@ -151,11 +163,11 @@ int run_flatmm_example(int argc, char* argv[])
{
if(data_type == "fp16")
{
run_flatmm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
//run_flatmm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
run_flatmm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
//run_flatmm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
@@ -163,7 +175,7 @@ int run_flatmm_example(int argc, char* argv[])
}
else if(data_type == "bf8")
{
run_flatmm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
//run_flatmm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{

View File

@@ -31,6 +31,21 @@
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
template <typename DataType>
struct GemmConfig
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmBasicTypeConfig;
@@ -122,7 +137,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -33,37 +33,28 @@ static constexpr inline auto is_row_major(Layout layout_)
// mfma_type, 0:32x32, 1:16x16
template <typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type)
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
if constexpr(GemmConfig<T>::N_Warp_Tile == 32)
{
ck_tile::HostTensor<T> t_view({n_ / 32, 32, k_ / 16, 2, 8});
ck_tile::HostTensor<T> t_view(
{n_ / 32, 32, k_ / GemmConfig<T>::K_Warp_Tile, 2, GemmConfig<T>::K_Warp_Tile / 2});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
else
{
ck_tile::HostTensor<T> t_view({n_ / 16, 16, k_ / 32, 4, 8});
static_assert(GemmConfig<T>::N_Warp_Tile == 16);
ck_tile::HostTensor<T> t_view(
{n_ / 16, 16, k_ / GemmConfig<T>::K_Warp_Tile, 4, GemmConfig<T>::K_Warp_Tile / 4});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({n_ / 32, 32, k_ / 32, 2, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({n_ / 16, 16, k_ / 64, 4, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
return t;
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
@@ -189,12 +180,8 @@ int run_flatmm_example_with_layouts(int argc,
// do pre-shuffle
std::string mfma = arg_parser.get_str("prec");
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
ck_tile::index_t mfma_type = 1;
#else
ck_tile::index_t mfma_type = 0;
#endif
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b(b_origin_host, mfma, mfma_type);
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<BDataType>(b_origin_host);
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());

View File

@@ -38,10 +38,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
r.y = __builtin_amdgcn_readfirstlane(r.y);
r.z = __builtin_amdgcn_readfirstlane(r.z);
r.w = __builtin_amdgcn_readfirstlane(r.w);
// r.x = __builtin_amdgcn_readfirstlane(r.x);
// r.y = __builtin_amdgcn_readfirstlane(r.y);
// r.z = __builtin_amdgcn_readfirstlane(r.z);
// r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
}

View File

@@ -75,6 +75,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
#if 0
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) || defined(USING_MFMA_32x32x16)
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
@@ -148,6 +149,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
#endif
#endif
}

View File

@@ -19,11 +19,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
#if (defined(USING_MFMA_16x16x32) || defined(USING_MFMA_16x16x128)) && defined(ENABLE_FP8)
using ADataType = remove_cvref_t<typename Problem::ADataType>;
/*reduce transform layers,compare with old ck*/
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr index_t KPack = Problem::VectorLoadSize / sizeof(ADataType);
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
@@ -138,6 +139,21 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetK1()
{
using TileShape = typename Problem::BlockGemmShape;
if constexpr(TileShape::WarpTile::at(TileShape::idxN) == 32)
{
return TileShape::WarpTile::at(TileShape::idxK) / 2;
}
else
{
static_assert(TileShape::WarpTile::at(TileShape::idxN) == 16);
return TileShape::WarpTile::at(TileShape::idxK) / 4;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
@@ -232,16 +248,13 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad =
Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt
constexpr index_t KBPerLoad = GetK1<Problem>(); // dwordx4 load B elem cnt
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;