mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
scale bf16
This commit is contained in:
@@ -638,7 +638,8 @@ struct MoeFlatmmKernel
|
||||
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
|
||||
// using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
|
||||
using ScaleType = std::conditional_t<MXFP4_Pipeline, bfloat16_t, float>;
|
||||
|
||||
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
|
||||
|
||||
@@ -189,12 +189,12 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
struct DequantizeMxFP4 {
|
||||
|
||||
CK_TILE_DEVICE auto operator()(statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp>& dequant_B_n,
|
||||
const auto& quant_weight_tensor,
|
||||
const auto& scale_tensor,
|
||||
auto xdl_nIter,
|
||||
auto xdl_kIter) {
|
||||
|
||||
CK_TILE_DEVICE auto operator()([[maybe_unused]] statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp>& dequant_B_n,
|
||||
[[maybe_unused]] const auto& quant_weight_tensor,
|
||||
[[maybe_unused]] const auto& scale_tensor,
|
||||
[[maybe_unused]] auto xdl_nIter,
|
||||
[[maybe_unused]] auto xdl_kIter) {
|
||||
#if 0
|
||||
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
|
||||
|
||||
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
|
||||
@@ -258,6 +258,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
.at(i),
|
||||
bit_cast<float>(uscale)));
|
||||
});
|
||||
#endif
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
@@ -281,9 +282,13 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
|
||||
/*
|
||||
constexpr int float_mantissa = 23;
|
||||
|
||||
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
|
||||
*/
|
||||
|
||||
float scale_f32 = type_cast<float>(scale.data);
|
||||
|
||||
using ComputeV2Type =
|
||||
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
|
||||
@@ -308,7 +313,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
pk_int4_to_compute_v2(
|
||||
bit_cast<thread_buffer<pk_int4_t, 4>>(quant_weight_tensor[quant_idx_k])
|
||||
.at(i),
|
||||
bit_cast<float>(uscale)));
|
||||
scale_f32));
|
||||
});
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user