diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 141f598737..0d23e24e14 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -62,6 +62,85 @@ auto shuffle_b(const ck_tile::HostTensor& t) return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, + // 2-k))); + + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K_pk; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} + template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, @@ -371,9 +450,9 @@ auto create_args(int argc, char* argv[]) .insert("split_k", "1", "splitK value") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") .insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8") - .insert("warp_tile", - "0", - "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)"); + .insert( + "warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)") + .insert("verbosity", "0", "0: no log print, 1: print log"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 422eb93cc5..07e8ab4546 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -136,6 +136,7 @@ struct GemmBasicTypeConfig using BDataType = ck_tile::pk_fp4_t; using AccDataType = float; using XDataType = ck_tile::e8m0_bexp_t; // scale data type + using XPackedData = int32_t; // packed scale data type using CDataType = ck_tile::half_t; } diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index ef94590a06..37aceca248 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -71,6 +71,10 @@ int run_flatmm_example_with_layouts(int argc, using AscaleDataType = typename GemmBasicTypeConfig::XDataType; using BscaleDataType = typename GemmBasicTypeConfig::XDataType; + // B shuffled tensor + ck_tile::HostTensor b_shuffle_host( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + // A, B scale tensors ck_tile::HostTensor a_m_k_scale(ck_tile::host_tensor_descriptor( Scale_Padded_M, K / ScaleGranularityK, Scale_stride_AM, is_row_major(a_layout))); @@ -90,8 +94,62 @@ int run_flatmm_example_with_layouts(int argc, ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_scale); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_scale); } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(a_host); + ck_tile::FillMonotonicSeq{}(b_origin_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k_scale); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n_scale); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k_scale); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n_scale); + } + else + { + a_host.SetZero(); + b_origin_host.SetZero(); + a_m_k_scale.SetZero(); + b_k_n_scale.SetZero(); + } + + if(arg_parser.get_int("v") == 1) + std::cout << "do shuffle for B, and A/B scale..." << std::endl; + // do B pre-shuffle + preShuffleBuffer(b_origin_host.mData.data(), + b_shuffle_host.mData.data(), + N, + K, + FlatmmConfig::N_Warp_Tile); + + // do A, B scale pre-shuffle + preShuffleScaleBuffer>( + a_m_k_scale.data(), a_m_k_scale_shuffle.data(), Scale_Padded_M, K / ScaleGranularityK); + preShuffleScaleBuffer< + ck_tile::is_same_v>( + b_k_n_scale.data(), b_k_n_scale_shuffle.data(), N, K / ScaleGranularityK); + + if(arg_parser.get_int("v") == 1) + std::cout << "Device memory allocation..." << std::endl; + ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem a_scale_dev_buf(a_m_k_scale.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_scale_dev_buf(b_k_n_scale.get_element_space_size_in_bytes()); + + if(arg_parser.get_int("v") == 1) + std::cout << "Uploading tensors to device..." << std::endl; + a_dev_buf.ToDevice(a_host.data()); + a_scale_dev_buf.ToDevice(a_m_k_scale_shuffle.data()); + b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); + b_scale_dev_buf.ToDevice(b_k_n_scale_shuffle.data()); + if(arg_parser.get_int("v") == 1) + std::cout << "Upload tensors done." << std::endl; } - else // PTPC unversal flat gemm + else // PTPC universal flat gemm { ck_tile::HostTensor per_token_scale(ck_tile::HostTensorDescriptor({M}, {1})); ck_tile::HostTensor per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1})); diff --git a/include/ck_tile/core/numeric/e8m0.hpp b/include/ck_tile/core/numeric/e8m0.hpp index 64a69dd761..c14c4867ac 100644 --- a/include/ck_tile/core/numeric/e8m0.hpp +++ b/include/ck_tile/core/numeric/e8m0.hpp @@ -64,6 +64,39 @@ struct e8m0_bexp_t __host__ __device__ constexpr bool is_nan() const { return data == nan_mask; } }; +// limits +template +struct numeric; + +template <> +struct numeric +{ + static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 + static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 + static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 + static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 + static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 + static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 + static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 + static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 + + CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } + CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } + CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } + CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } + CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } + CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } + CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_135() + { + return e8m0_bexp_t(binary_135); + } + CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_142() + { + return e8m0_bexp_t(binary_142); + } +}; +} + template <> struct numeric_traits { diff --git a/include/ck_tile/ops/common/tensor_layout.hpp b/include/ck_tile/ops/common/tensor_layout.hpp index bb905e6ab9..f828119b1a 100644 --- a/include/ck_tile/ops/common/tensor_layout.hpp +++ b/include/ck_tile/ops/common/tensor_layout.hpp @@ -22,6 +22,12 @@ struct ColumnMajor : public BaseTensorLayout { static constexpr const char* name = "ColumnMajor"; }; + +struct MFMA : public BaseTensorLayout +{ + static constexpr const char* name = "MFMA"; +}; + } // namespace gemm namespace convolution {