Fix conflicts

This commit is contained in:
Rostyslav Geyyer
2025-04-30 20:03:27 +00:00
parent 045a71bc14
commit 0fc2f528e0
2 changed files with 3 additions and 2 deletions

View File

@@ -653,6 +653,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
#endif
}
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a,
const int32_t& scale_a,
const bf8x32_t& reg_b,

View File

@@ -981,10 +981,10 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
auto fragXb = BScaleFragT{};
// Load the inputs.
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, ScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
a, xa, fragXa);
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, ScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
b, xb, fragXb);
// Scaled Matrix multiply-accumulate using MFMA units