update codes

This commit is contained in:
mtgu0705
2025-07-29 03:43:52 -05:00
parent 4ce42011b2
commit 35ac1894ba
5 changed files with 181 additions and 4 deletions

View File

@@ -62,6 +62,85 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f,
// 2-k)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
// Copies packed-FP4 elements from a row-major N x (K / 2) source into the tiled
// destination order
//   N0 -> K0 -> KLane -> NLane -> KPack
// where N is split as N0 x NLane and the packed K extent as K0 x KLane x KPack.
//
// Parameters:
//   src  - input buffer, read as src[n * (K / 2) + k].
//   dst  - output buffer receiving the reordered elements.
//   N    - number of rows.
//   K    - K extent; the loops run over K / 2 elements per row
//          (presumably because each f4x2_pk_t holds two values - confirm).
//   NXdl - number of lanes along N; the number of lanes along K is 64 / NXdl.
void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl)
{
    const int packWidth   = 16;                      // KPack
    const int lanesAlongN = NXdl;                    // NLane
    const int lanesAlongK = 64 / lanesAlongN;        // KLane
    const int packedK     = K / 2;                   // K_pk
    const int kTileExtent = lanesAlongK * packWidth; // packed elements per K0 step
    const int kTileCount  = packedK / kTileExtent;   // K0

    for(int row = 0; row < N; ++row)
    {
        const int rowTile = row / lanesAlongN; // n0
        const int rowLane = row % lanesAlongN; // n1
        const int rowBase = row * packedK;

        for(int col = 0; col < packedK; ++col)
        {
            const int kTile  = col / kTileExtent; // k0
            const int inTile = col % kTileExtent;
            const int kLane  = inTile / packWidth; // k1
            const int kPack  = inTile % packWidth; // k2

            // Nested form of
            //   n0*KPack*NLane*KLane*K0 + k0*KPack*NLane*KLane + k1*KPack*NLane + n1*KPack + k2
            const int dstIndex =
                (((rowTile * kTileCount + kTile) * lanesAlongK + kLane) * lanesAlongN + rowLane) *
                    packWidth +
                kPack;

            dst[dstIndex] = src[rowBase + col];
        }
    }
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
@@ -371,9 +450,9 @@ auto create_args(int argc, char* argv[])
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
.insert(
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
.insert("verbosity", "0", "0: no log print, 1: print log");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}

View File

@@ -136,6 +136,7 @@ struct GemmBasicTypeConfig<ck_tile::pk_fp4_t>
using BDataType = ck_tile::pk_fp4_t;
using AccDataType = float;
using XDataType = ck_tile::e8m0_bexp_t; // scale data type
using XPackedData = int32_t; // packed scale data type
using CDataType = ck_tile::half_t;
}

View File

@@ -71,6 +71,10 @@ int run_flatmm_example_with_layouts(int argc,
using AscaleDataType = typename GemmBasicTypeConfig<PrecType>::XDataType;
using BscaleDataType = typename GemmBasicTypeConfig<PrecType>::XDataType;
// B shuffled tensor
ck_tile::HostTensor<BDataType> b_shuffle_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
// A, B scale tensors
ck_tile::HostTensor<AscaleDataType> a_m_k_scale(ck_tile::host_tensor_descriptor(
Scale_Padded_M, K / ScaleGranularityK, Scale_stride_AM, is_row_major(a_layout)));
@@ -90,8 +94,62 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<AscaleDataType>{-1.f, 1.f}(a_m_k_scale);
ck_tile::FillUniformDistribution<BscaleDataType>{-1.f, 1.f}(b_k_n_scale);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
ck_tile::FillUniformDistribution<AscaleDataType>{1.f, 1.f}(a_m_k_scale);
ck_tile::FillUniformDistribution<BscaleDataType>{1.f, 1.f}(b_k_n_scale);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<AscaleDataType>{1.f, 1.f}(a_m_k_scale);
ck_tile::FillUniformDistribution<BscaleDataType>{1.f, 1.f}(b_k_n_scale);
}
else
{
a_host.SetZero();
b_origin_host.SetZero();
a_m_k_scale.SetZero();
b_k_n_scale.SetZero();
}
if(arg_parser.get_int("v") == 1)
std::cout << "do shuffle for B, and A/B scale..." << std::endl;
// do B pre-shuffle
preShuffleBuffer(b_origin_host.mData.data(),
b_shuffle_host.mData.data(),
N,
K,
FlatmmConfig::N_Warp_Tile);
// do A, B scale pre-shuffle
preShuffleScaleBuffer<ck_tile::is_same_v<a_layout, ck_tile::tensor_layout::gemm::RowMajor>>(
a_m_k_scale.data(), a_m_k_scale_shuffle.data(), Scale_Padded_M, K / ScaleGranularityK);
preShuffleScaleBuffer<
ck_tile::is_same_v<b_layout, ck_tile::tensor_layout::gemm::ColumnMajor>>(
b_k_n_scale.data(), b_k_n_scale_shuffle.data(), N, K / ScaleGranularityK);
if(arg_parser.get_int("v") == 1)
std::cout << "Device memory allocation..." << std::endl;
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem a_scale_dev_buf(a_m_k_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_scale_dev_buf(b_k_n_scale.get_element_space_size_in_bytes());
if(arg_parser.get_int("v") == 1)
std::cout << "Uploading tensors to device..." << std::endl;
a_dev_buf.ToDevice(a_host.data());
a_scale_dev_buf.ToDevice(a_m_k_scale_shuffle.data());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
b_scale_dev_buf.ToDevice(b_k_n_scale_shuffle.data());
if(arg_parser.get_int("v") == 1)
std::cout << "Upload tensors done." << std::endl;
}
else // PTPC unversal flat gemm
else // PTPC universal flat gemm
{
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));

View File

@@ -64,6 +64,39 @@ struct e8m0_bexp_t
__host__ __device__ constexpr bool is_nan() const { return data == nan_mask; }
};
// limits
// Only declared here; the specialization below supplies the e8m0_bexp_t members.
template <class T>
struct numeric;

// Named raw bit patterns for e8m0_bexp_t, each paired with a constexpr accessor
// that returns the pattern by value.
// NOTE(review): the value comments give the raw 8-bit encoding only. Under the
// usual E8M0 convention (bias 127) a raw value r would decode to 2^(r - 127) -
// confirm against the e8m0_bexp_t definition. The constant names mix decoded
// values (binary_1, binary_2) with raw encodings (binary_135, binary_142).
template <>
struct numeric<e8m0_bexp_t>
{
    // Smallest encoding: raw 0 (0b00000000).
    static constexpr e8m0_bexp_t binary_min = 0x00;
    CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Min() { return binary_min; }

    // Largest encoding below the all-ones pattern: raw 254 (0b11111110).
    static constexpr e8m0_bexp_t binary_max = 0xFE;
    CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Max() { return binary_max; }

    // All-ones pattern: raw 255 (0b11111111).
    static constexpr e8m0_bexp_t binary_qnan = 0xFF;
    CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t QuietNaN() { return binary_qnan; }

    // Raw 127 (0b01111111).
    static constexpr e8m0_bexp_t binary_1 = 0x7F;
    CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_1() { return binary_1; }

    // Raw 128 (0b10000000).
    static constexpr e8m0_bexp_t binary_2 = 0x80;
    CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_2() { return binary_2; }

    // Raw 130 (0b10000010).
    static constexpr e8m0_bexp_t binary_3 = 0x82;
    CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_3() { return binary_3; }

    // Raw 135 (0b10000111).
    static constexpr e8m0_bexp_t binary_135 = 0x87;
    CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_135() { return binary_135; }

    // Raw 142 (0b10001110).
    static constexpr e8m0_bexp_t binary_142 = 0x8E;
    CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_142() { return binary_142; }
};
}
template <>
struct numeric_traits<e8m0_bexp_t>
{

View File

@@ -22,6 +22,12 @@ struct ColumnMajor : public BaseTensorLayout
{
static constexpr const char* name = "ColumnMajor";
};
// Layout tag type derived from BaseTensorLayout, declared alongside ColumnMajor
// and identified by the name string "MFMA".
// NOTE(review): presumably marks tensors already arranged in the order the MFMA
// instructions consume - confirm against the code that selects on this tag.
struct MFMA : public BaseTensorLayout
{
// Human-readable identifier of this layout tag.
static constexpr const char* name = "MFMA";
};
} // namespace gemm
namespace convolution {