mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
updates, build pass
This commit is contained in:
@@ -284,12 +284,14 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
|
||||
const std::size_t ScaleBlockSize = K / scale_a.get_length(1);
|
||||
|
||||
HostTensor<AccDataType> a_m_k_scaled({M, K}, {K, 1});
|
||||
HostTensor<AccDataType> b_k_n_scaled({K, N}, {1, N});
|
||||
HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
|
||||
{std::size_t(K), std::size_t(1)});
|
||||
HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
|
||||
{std::size_t(1), std::size_t(K)});
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
for(std::size_t m = 0; m < M; ++m)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
@@ -297,7 +299,7 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
continue; // skip odd k
|
||||
|
||||
auto a_f4x2 = a_m_k(m, k);
|
||||
auto a_scale = scale_a(m, k / ScaleBlockSize);
|
||||
auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
|
||||
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
|
||||
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
|
||||
auto a_f4_lo =
|
||||
@@ -311,9 +313,9 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
}
|
||||
}
|
||||
|
||||
for(int n = 0; n < N; n++)
|
||||
for(std::size_t n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
for(std::size_t k = 0; k < K; k++)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
@@ -321,7 +323,7 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
continue; // skip odd k
|
||||
|
||||
auto b_f4x2 = b_k_n(k, n);
|
||||
auto b_scale = scale_b(k / ScaleBlockSize, n);
|
||||
auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
|
||||
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
|
||||
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
|
||||
auto b_f4_lo =
|
||||
|
||||
@@ -76,7 +76,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = FlatmmPipeline::BlockSize().x;
|
||||
constexpr int block_size = MXFlatmmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
@@ -86,7 +86,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size,
|
||||
FlatmmPipeline,
|
||||
MXFlatmmKernel,
|
||||
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
@@ -118,8 +118,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::packed_size;
|
||||
static constexpr index_t BPackedSize = numeric_traits<BDataType>::packed_size;
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static constexpr index_t MXdlPack = Problem::MXdlPack;
|
||||
static constexpr index_t NXdlPack = Problem::NXdlPack;
|
||||
@@ -629,25 +629,27 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
|
||||
using V4UInt_Buffer = thread_buffer<uint32_t, 4>;
|
||||
union UnionB
|
||||
{
|
||||
V4UInt_Buffer u = 0;
|
||||
MXFP4_Buffer mxfp4;
|
||||
} ub;
|
||||
// using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
// // use v4i32 as the data type between basicblock to avoid unpack and repack operation.
|
||||
// using V4UInt_Buffer = thread_buffer<uint32_t, 4>;
|
||||
// union UnionB
|
||||
// {
|
||||
// V4UInt_Buffer u = 0;
|
||||
// MXFP4_Buffer mxfp4;
|
||||
// } ub;
|
||||
|
||||
// pingpong buffer for B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
// pingpong buffer for Scale A and Scale B
|
||||
@@ -708,8 +710,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_ping(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
@@ -785,8 +786,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -850,8 +850,10 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack),
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack));
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0],
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -914,8 +916,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_ping(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -973,15 +974,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1047,8 +1049,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1101,15 +1102,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1176,15 +1178,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1245,15 +1248,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
|
||||
Reference in New Issue
Block a user