mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
fix core dump issue, function is not correct.
This commit is contained in:
@@ -362,6 +362,55 @@ void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
|
||||
}
|
||||
}
|
||||
|
||||
#if 1
|
||||
/// Copies an MN x K matrix of scale values from `src` into `dst`, reordering the
/// elements into a tiled layout.
///
/// Each source coordinate (n, k) is split into
///   n -> n0 (MNRepeat), n1 (index within XdlMNThread), n2 (index within MNXdlPack)
///   k -> k0 (KRepeat),  k1 (index within XdlKThread),  k2 (index within KXdlPack)
/// and written to the flat destination offset whose digits, slowest to fastest, are
///   n0, k0, k1, n1, k2, n2.
///
/// Original author notes on the intent of the layout:
///   - The 4 16x128 building blocks will be packed into 1 32x256 for F4.
///   - The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4.
///   - Unfold the MN32 x K(256/32) scale buffer
///     to XdlKThread(4) -> XdlMNThread(16) -> KXdlPack(2) -> MNXdlPack(2),
///     then MNRepeat -> KRepeat.
///
/// @tparam FlatmmConfig  Must expose an integral `N_Warp_Tile` (the original comment
///                       gives 16 as the expected value).
/// @tparam KLast         true:  `src` is read as src[n * K + k];
///                       false: `src` is read as src[k * MN + n].
/// @param src  Random-access iterator to the first of MN * K source elements.
/// @param dst  Random-access iterator to the destination buffer.
/// @param MN   Extent of the M/N dimension of the scale matrix.
/// @param K    Extent of the K dimension of the scale matrix.
///
/// @note No bounds or divisibility checks are performed. The mapping is a permutation
///       of [0, MN * K) only when MN is a multiple of XdlMNThread * MNXdlPack and K is
///       a multiple of XdlKThread * KXdlPack; for other sizes the computed destination
///       offsets can collide or fall outside that range.
template <class FlatmmConfig, bool KLast, class IterSrc, class IterDst>
void preShuffleScale(const IterSrc src, IterDst dst, int MN, int K)
{
    const int mnPack    = 2;                         // MNXdlPack
    const int kPack     = 2;                         // KXdlPack
    const int mnThreads = FlatmmConfig::N_Warp_Tile; // XdlMNThread
    const int kThreads  = 64 / mnThreads;            // XdlKThread

    const int kRepeat = K / kPack / kThreads; // number of K tiles (KRepeat)

    // Extent of one tile along each axis.
    const int mnTile = mnThreads * mnPack;
    const int kTile  = kThreads * kPack;

    // Destination strides of the individual index digits (n2 has stride 1).
    const int k2Stride = mnPack;
    const int n1Stride = mnPack * kPack;
    const int k1Stride = mnPack * kPack * mnThreads;
    const int k0Stride = mnPack * kPack * mnThreads * kThreads;

    for(int row = 0; row < MN; ++row)
    {
        // Decompose the MN coordinate; these digits do not depend on the K coordinate.
        const int rowRepeat = row / mnTile;          // n0
        const int rowInTile = row % mnTile;
        const int rowThread = rowInTile % mnThreads; // n1
        const int rowPack   = rowInTile / mnThreads; // n2

        const int rowOffset = rowRepeat * k0Stride * kRepeat + rowThread * n1Stride + rowPack;

        for(int col = 0; col < K; ++col)
        {
            // Decompose the K coordinate.
            const int colRepeat = col / kTile;          // k0
            const int colInTile = col % kTile;
            const int colThread = colInTile % kThreads; // k1
            const int colPack   = colInTile / kThreads; // k2

            const int outputIndex =
                rowOffset + colRepeat * k0Stride + colThread * k1Stride + colPack * k2Stride;

            // Pick the source element according to which dimension is contiguous.
            const int inputIndex = KLast ? row * K + col : col * MN + row;

            dst[outputIndex] = src[inputIndex];
        }
    }
}
|
||||
#else
|
||||
template <class FlatmmConfig, class T>
|
||||
auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
|
||||
{
|
||||
@@ -390,6 +439,7 @@ auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
return ck_tile::reference_permute(shfl_scale, {3, 0, 2, 5, 1, 4});
|
||||
}
|
||||
#endif
|
||||
|
||||
#include "run_mx_flatmm.inc"
|
||||
|
||||
|
||||
@@ -74,6 +74,11 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
ck_tile::HostTensor<ScaleDataType> scale_b(ck_tile::host_tensor_descriptor(
|
||||
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));
|
||||
|
||||
ck_tile::HostTensor<ScaleDataType> scale_a_shuffled(ck_tile::host_tensor_descriptor(
|
||||
M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<ScaleDataType> scale_b_shuffled(ck_tile::host_tensor_descriptor(
|
||||
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
@@ -187,8 +192,12 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffled_host.begin(), N, K);
|
||||
|
||||
ck_tile::HostTensor<ScaleDataType> scale_a_shuffled = preShuffleScale<FlatmmConfig>(scale_a);
|
||||
ck_tile::HostTensor<ScaleDataType> scale_b_shuffled = preShuffleScale<FlatmmConfig>(scale_b);
|
||||
preShuffleScale<FlatmmConfig, is_row_major(a_layout)>(
|
||||
scale_a.begin(), scale_a_shuffled.begin(), M, K / ScaleGranularityK);
|
||||
preShuffleScale<FlatmmConfig, !is_row_major(b_layout)>(
|
||||
scale_b.begin(), scale_b_shuffled.begin(), N, K / ScaleGranularityK);
|
||||
// ck_tile::HostTensor<ScaleDataType> scale_a_shuffled = preShuffleScale<FlatmmConfig>(scale_a);
|
||||
// ck_tile::HostTensor<ScaleDataType> scale_b_shuffled = preShuffleScale<FlatmmConfig>(scale_b);
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes());
|
||||
|
||||
Reference in New Issue
Block a user