[CK TILE] Support fp8/fp16 with pk_int4_t as data types for tensors A and B (#2805)

- Add support for tensor A/B in both fp16+pk_int4_t and fp8+pk_int4_t formats
- Implement A(bf8) B(i4) support in universal GEMM
- Use new implementation for i4 to fp8 conversion in Block Scale
This commit is contained in:
Cong Ma
2025-09-09 17:40:52 -06:00
committed by GitHub
parent 91178b4011
commit 82890192dd
15 changed files with 320 additions and 135 deletions

View File

@@ -181,9 +181,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr auto Scheduler = Traits::Scheduler;
static constexpr uint8_t kA_cvt_scale = std::is_same_v<ADataType, pk_int4_t> ? 16 : 1;
static constexpr uint8_t kB_cvt_scale = std::is_same_v<BDataType, pk_int4_t> ? 16 : 1;
static constexpr auto Scheduler = Traits::Scheduler;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
@@ -451,7 +449,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
scale_reg_f);
});
}
}
@@ -471,7 +469,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
scale_reg_f);
});
}
else
@@ -556,7 +554,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
reg_offset_for_row_data] +=
(c_warp_tensor
.get_thread_buffer()[reg_offset_for_row_data] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
scale_reg_f);
});
}
}

View File

@@ -179,9 +179,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr auto Scheduler = Traits::Scheduler;
static constexpr uint8_t kA_cvt_scale = std::is_same_v<ADataType, pk_int4_t> ? 16 : 1;
static constexpr uint8_t kB_cvt_scale = std::is_same_v<BDataType, pk_int4_t> ? 16 : 1;
static constexpr auto Scheduler = Traits::Scheduler;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
@@ -384,8 +382,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f *
kA_cvt_scale * kB_cvt_scale);
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
});
});