fix core dump issue, function is not correct.

This commit is contained in:
mtgu0705
2025-09-15 04:02:05 -05:00
parent 9ceb3fd508
commit 8052bea019
2 changed files with 61 additions and 2 deletions

View File

@@ -362,6 +362,55 @@ void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
}
}
#if 1
template <class FlatmmConfig, bool KLast, class IterSrc, class IterDst>
void preShuffleScale(const IterSrc src, IterDst dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = FlatmmConfig::N_Warp_Tile; // 16
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f,
// 2-k)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
#else
template <class FlatmmConfig, class T>
auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
{
@@ -390,6 +439,7 @@ auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
return ck_tile::reference_permute(shfl_scale, {3, 0, 2, 5, 1, 4});
}
#endif
#include "run_mx_flatmm.inc"

View File

@@ -74,6 +74,11 @@ int run_mx_flatmm_with_layouts(int argc,
ck_tile::HostTensor<ScaleDataType> scale_b(ck_tile::host_tensor_descriptor(
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<ScaleDataType> scale_a_shuffled(ck_tile::host_tensor_descriptor(
M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<ScaleDataType> scale_b_shuffled(ck_tile::host_tensor_descriptor(
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
@@ -187,8 +192,12 @@ int run_mx_flatmm_with_layouts(int argc,
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffled_host.begin(), N, K);
ck_tile::HostTensor<ScaleDataType> scale_a_shuffled = preShuffleScale<FlatmmConfig>(scale_a);
ck_tile::HostTensor<ScaleDataType> scale_b_shuffled = preShuffleScale<FlatmmConfig>(scale_b);
preShuffleScale<FlatmmConfig, is_row_major(a_layout)>(
scale_a.begin(), scale_a_shuffled.begin(), M, K / ScaleGranularityK);
preShuffleScale<FlatmmConfig, !is_row_major(b_layout)>(
scale_b.begin(), scale_b_shuffled.begin(), N, K / ScaleGranularityK);
// ck_tile::HostTensor<ScaleDataType> scale_a_shuffled = preShuffleScale<FlatmmConfig>(scale_a);
// ck_tile::HostTensor<ScaleDataType> scale_b_shuffled = preShuffleScale<FlatmmConfig>(scale_b);
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes());