mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
More clean up
This commit is contained in:
@@ -103,105 +103,6 @@ static constexpr int32_t vectorSize(const VecT&)
|
||||
return scalar_type<VecT>::vector_size;
|
||||
}
|
||||
|
||||
// Define a load function for input A blocks:
|
||||
// Size: (BLOCK_M x BLOCK_K)
|
||||
// - Data is in column major format
|
||||
// - Rows are loaded in contiguous chunks that map to corresponding microscales
|
||||
// - Each row is loaded in chunks of size 16 and each thread loads 32 elements
|
||||
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
|
||||
__device__ AFragT load_A_col_major(AType const* input_ptr)
|
||||
{
|
||||
// clang-format off
|
||||
// Register Mapping for 16x128 for FP8: || Register Mapping for 32x64 for FP8:
|
||||
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
|
||||
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
|
||||
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
|
||||
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
|
||||
// Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] |
|
||||
// Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] |
|
||||
// Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] |
|
||||
// Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] |
|
||||
// Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] |
|
||||
// Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] |
|
||||
// Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] |
|
||||
// Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] |
|
||||
// Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] |
|
||||
// Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] |
|
||||
// Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] |
|
||||
// Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] |
|
||||
// Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] |
|
||||
// Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] |
|
||||
// Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] |
|
||||
// Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] |
|
||||
// Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] |
|
||||
// Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] |
|
||||
// Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] |
|
||||
// Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] |
|
||||
// Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] |
|
||||
// Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] |
|
||||
// Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] |
|
||||
// Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] |
|
||||
// Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] |
|
||||
// Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] |
|
||||
// Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] |
|
||||
// Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] |
|
||||
// Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] |
|
||||
// Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] |
|
||||
// Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] |
|
||||
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
|
||||
// clang-format on
|
||||
|
||||
static constexpr int32_t WAVE_SIZE = 64;
|
||||
|
||||
// Here we want to load from rows of A in chunks of 16 elements each.
|
||||
static constexpr uint32_t chunk_size = 16;
|
||||
|
||||
// each chunk is separated by offset
|
||||
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M;
|
||||
|
||||
// To start the loading process, let's visualize in 2D coords.
|
||||
// Each thread will load 32 elements.
|
||||
// We need to know where they start, and where the next elements are.
|
||||
auto startCoord2D =
|
||||
std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15}
|
||||
(threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48}
|
||||
|
||||
auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows
|
||||
auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row
|
||||
|
||||
// Flatten to 1D col_major offsets.
|
||||
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
||||
|
||||
// BLOCK_M is a stride in A matrix
|
||||
auto startOffset = col_major(startCoord2D, BLOCK_M);
|
||||
auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M);
|
||||
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M);
|
||||
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarFragT =
|
||||
vector_type<ARawT,
|
||||
BLOCK_M * BLOCK_K / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
|
||||
AScalarFragT fragA{};
|
||||
|
||||
constexpr index_t num_chunks =
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 1 : 2);
|
||||
|
||||
#pragma unroll
|
||||
for(int chunk = 0; chunk < num_chunks; chunk++)
|
||||
{
|
||||
#pragma unroll
|
||||
for(uint32_t i = 0; i < chunk_size; i++)
|
||||
{
|
||||
fragA[chunk * chunk_size + i] =
|
||||
bit_cast<ARawT>(input_ptr[startOffset + chunk * kMajorOffset + i * kMinorOffset]);
|
||||
}
|
||||
}
|
||||
|
||||
return fragA;
|
||||
}
|
||||
|
||||
// Define a load function for input A blocks:
|
||||
// Size: (BLOCK_M x BLOCK_K)
|
||||
// - Data is in row major format
|
||||
@@ -1144,16 +1045,17 @@ struct TestMXMFMA
|
||||
case 1:
|
||||
// results in C = {K}
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
||||
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
|
||||
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{512.0f}});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 512}});
|
||||
break;
|
||||
case 2:
|
||||
// expect small round-off errors
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{512.0f}});
|
||||
a_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 512}});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129});
|
||||
break;
|
||||
case 3:
|
||||
// expect small round-off errors
|
||||
|
||||
Reference in New Issue
Block a user