mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
update codes
This commit is contained in:
@@ -62,6 +62,85 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f,
|
||||
// 2-k)));
|
||||
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Copies an N x (K / 2) buffer of packed elements from `src` into `dst`,
/// re-ordering it from [N][K_pk] into [N0][K0][KLane][NLane][KPack], where
///   K_pk -> K0 x KLane x KPack   (KPack = 16, KLane = 64 / NXdl)
///   N    -> N0 x NLane           (NLane = NXdl)
///
/// @param src   source buffer, read as src[n * (K / 2) + k].
/// @param dst   destination buffer receiving the shuffled elements.
/// @param N     number of rows.
/// @param K     K extent; each row holds K / 2 elements
///              (presumably two 4-bit values per f4x2_pk_t - TODO confirm).
/// @param NXdl  number of N lanes per tile; KLane is derived as 64 / NXdl.
void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl)
{
    constexpr int kPack = 16;

    const int nLane   = NXdl;
    const int kLane   = 64 / nLane;
    const int packedK = K / 2;

    const int kChunk  = kLane * kPack;   // packed elements covered by one K0 step
    const int kRepeat = packedK / kChunk; // K0

    // Destination strides of the sub-indices (k2 has stride 1).
    const int strideN1 = kPack;
    const int strideK1 = kPack * nLane;
    const int strideK0 = kPack * nLane * kLane;
    const int strideN0 = strideK0 * kRepeat;

    for(int row = 0; row < N; ++row)
    {
        const int n0 = row / nLane;
        const int n1 = row % nLane;

        // Part of the destination index that depends on the row only.
        const int rowOffset = n0 * strideN0 + n1 * strideN1;
        const int srcRow    = row * packedK;

        for(int col = 0; col < packedK; ++col)
        {
            const int k0        = col / kChunk;
            const int remainder = col % kChunk;
            const int k1        = remainder / kPack;
            const int k2        = remainder % kPack;

            dst[rowOffset + k0 * strideK0 + k1 * strideK1 + k2] = src[srcRow + col];
        }
    }
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
@@ -371,9 +450,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
.insert(
|
||||
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
|
||||
.insert("verbosity", "0", "0: no log print, 1: print log");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
@@ -136,6 +136,7 @@ struct GemmBasicTypeConfig<ck_tile::pk_fp4_t>
|
||||
using BDataType = ck_tile::pk_fp4_t;
|
||||
using AccDataType = float;
|
||||
using XDataType = ck_tile::e8m0_bexp_t; // scale data type
|
||||
using XPackedData = int32_t; // packed scale data type
|
||||
using CDataType = ck_tile::half_t;
|
||||
}
|
||||
|
||||
|
||||
@@ -71,6 +71,10 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
using AscaleDataType = typename GemmBasicTypeConfig<PrecType>::XDataType;
|
||||
using BscaleDataType = typename GemmBasicTypeConfig<PrecType>::XDataType;
|
||||
|
||||
// B shuffled tensor
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
|
||||
// A, B scale tensors
|
||||
ck_tile::HostTensor<AscaleDataType> a_m_k_scale(ck_tile::host_tensor_descriptor(
|
||||
Scale_Padded_M, K / ScaleGranularityK, Scale_stride_AM, is_row_major(a_layout)));
|
||||
@@ -90,8 +94,62 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
ck_tile::FillUniformDistribution<AscaleDataType>{-1.f, 1.f}(a_m_k_scale);
|
||||
ck_tile::FillUniformDistribution<BscaleDataType>{-1.f, 1.f}(b_k_n_scale);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<AscaleDataType>{1.f, 1.f}(a_m_k_scale);
|
||||
ck_tile::FillUniformDistribution<BscaleDataType>{1.f, 1.f}(b_k_n_scale);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<AscaleDataType>{1.f, 1.f}(a_m_k_scale);
|
||||
ck_tile::FillUniformDistribution<BscaleDataType>{1.f, 1.f}(b_k_n_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_host.SetZero();
|
||||
b_origin_host.SetZero();
|
||||
a_m_k_scale.SetZero();
|
||||
b_k_n_scale.SetZero();
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
std::cout << "do shuffle for B, and A/B scale..." << std::endl;
|
||||
// do B pre-shuffle
|
||||
preShuffleBuffer(b_origin_host.mData.data(),
|
||||
b_shuffle_host.mData.data(),
|
||||
N,
|
||||
K,
|
||||
FlatmmConfig::N_Warp_Tile);
|
||||
|
||||
// do A, B scale pre-shuffle
|
||||
preShuffleScaleBuffer<ck_tile::is_same_v<a_layout, ck_tile::tensor_layout::gemm::RowMajor>>(
|
||||
a_m_k_scale.data(), a_m_k_scale_shuffle.data(), Scale_Padded_M, K / ScaleGranularityK);
|
||||
preShuffleScaleBuffer<
|
||||
ck_tile::is_same_v<b_layout, ck_tile::tensor_layout::gemm::ColumnMajor>>(
|
||||
b_k_n_scale.data(), b_k_n_scale_shuffle.data(), N, K / ScaleGranularityK);
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
std::cout << "Device memory allocation..." << std::endl;
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem a_scale_dev_buf(a_m_k_scale.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_scale_dev_buf(b_k_n_scale.get_element_space_size_in_bytes());
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
std::cout << "Uploading tensors to device..." << std::endl;
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
a_scale_dev_buf.ToDevice(a_m_k_scale_shuffle.data());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
b_scale_dev_buf.ToDevice(b_k_n_scale_shuffle.data());
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
std::cout << "Upload tensors done." << std::endl;
|
||||
}
|
||||
else // PTPC unversal flat gemm
|
||||
else // PTPC universal flat gemm
|
||||
{
|
||||
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));
|
||||
|
||||
@@ -64,6 +64,39 @@ struct e8m0_bexp_t
|
||||
__host__ __device__ constexpr bool is_nan() const { return data == nan_mask; }
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<e8m0_bexp_t>
|
||||
{
|
||||
static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000
|
||||
static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110
|
||||
static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111
|
||||
static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111
|
||||
static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000
|
||||
static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010
|
||||
static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111
|
||||
static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_135()
|
||||
{
|
||||
return e8m0_bexp_t(binary_135);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_bexp_t Binary_142()
|
||||
{
|
||||
return e8m0_bexp_t(binary_142);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
template <>
|
||||
struct numeric_traits<e8m0_bexp_t>
|
||||
{
|
||||
|
||||
@@ -22,6 +22,12 @@ struct ColumnMajor : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "ColumnMajor";
|
||||
};
|
||||
|
||||
struct MFMA : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "MFMA";
|
||||
};
|
||||
|
||||
} // namespace gemm
|
||||
|
||||
namespace convolution {
|
||||
|
||||
Reference in New Issue
Block a user