result right

This commit is contained in:
ltqin
2025-11-12 05:42:59 +00:00
parent 100dcc9ea2
commit df55e264ad
4 changed files with 61 additions and 54 deletions

View File

@@ -167,9 +167,9 @@ class BlockQuantizer
size_t seq_len = in.get_length(i_perm ? 2 : 1);
size_t hdim = in.get_length(3);
size_t num_blocks_ = (seq_len + block_size_ - 1) / block_size_;
std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len
<< " hdim: " << hdim << " dtype_max: " << dtype_max
<< " num_blocks_: " << num_blocks_ << std::endl;
// std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len
// << " hdim: " << hdim << " dtype_max: " << dtype_max
// << " num_blocks_: " << num_blocks_ << std::endl;
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(0.5f, 2.0f);
@@ -215,9 +215,9 @@ class BlockQuantizer
}
// save scale to tensor
block_scale(b, h, block) = 1.0f / scale;
std::cout << "block: " << block << " scale: " << scale
<< " max_value: " << max_value << " block_scale: " << block_scale
<< std::endl;
// std::cout << "block: " << block << " scale: " << scale
// << " max_value: " << max_value << " block_scale: " << block_scale
// << std::endl;
}
}
}
@@ -806,32 +806,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
float scale_o = 1.f;
if(quant == 2)
{
ck_tile::FillUniformDistributionIntegerValue<float>{1.f, 10.f, next_seed()}(q_scale);
ck_tile::FillUniformDistributionIntegerValue<float>{1.f, 10.f, next_seed()}(k_scale);
ck_tile::FillUniformDistributionIntegerValue<float>{1.f, 10.f, next_seed()}(v_scale);
{ //debug info
std::cout << "q_scale: " << q_scale << " k_scale: " << k_scale
<< " v_scale: " << v_scale << std::endl;
ck_tile::HostTensor<float> q_host_deq(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<float> k_host_deq(
0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
ck_tile::HostTensor<float> v_host_deq(
0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
BlockQuantizer quantizer(i_perm);
quantizer.dequantize(q_host, q_host_deq, q_scale, block_scale_m_);
quantizer.dequantize(k_host, k_host_deq, k_scale, block_scale_n_);
quantizer.dequantize(v_host, v_host_deq, v_scale, block_scale_n_);
q_host_deq.savetxt("./q_deq.txt");
k_host_deq.savetxt("./k_deq.txt");
v_host_deq.savetxt("./v_deq.txt");
}
BlockQuantizer quantizer(i_perm);
quantizer.quantize(q_host, q_host, q_scale, block_scale_m_);
quantizer.quantize(k_host, k_host, k_scale, block_scale_n_);
quantizer.quantize(v_host, v_host, v_scale, block_scale_n_);
q_host.savetxt("./q_quant.txt");
k_host.savetxt("./k_quant.txt");
v_host.savetxt("./v_quant.txt");
}
else if(quant == 1)
{
@@ -1737,10 +1718,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
s_host_ref,
ck_tile::idx_identity{},
ck_tile::idx_identity{},
[&q_scale, &k_scale, scale_s, wb](auto idx, auto value) {
[&](auto idx, auto value) {
return value * scale_s *
q_scale(wb, std::get<0>(idx), std::get<1>(idx) / 128) *
k_scale(wb, std::get<0>(idx), std::get<2>(idx) / 128);
k_scale(wb, std::get<0>(idx) / nr, std::get<2>(idx) / 128);
});
}
else
@@ -1919,10 +1900,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
v_host_ref,
o_host_ref,
ck_tile::idx_identity{},
[&v_scale, wb](auto idx, auto value) {
// idx: b, m, n, k --> h, sq, d, sk
[&v_scale, wb, nr](auto idx, auto value) {
return ck_tile::type_convert<float>(value) *
v_scale(wb, std::get<0>(idx), std::get<2>(idx) / 128);
v_scale(wb, std::get<0>(idx) / nr, std::get<2>(idx) / 128);
},
ck_tile::idx_identity{});
}

View File

@@ -1728,12 +1728,13 @@ struct FmhaFwdKernel
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,
kargs.scale_s * q_scale,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout,
q_scale,
k_scale_ptr,
v_scale_ptr,
kargs.block_scale_m,

View File

@@ -176,6 +176,7 @@ struct BlockFmhaPipelineQRKSVS
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout,
const float q_scale,
const float* k_scale_ptr,
const float* v_scale_ptr,
index_t,
@@ -407,7 +408,7 @@ struct BlockFmhaPipelineQRKSVS
{
if(k_scale_ptr)
{
tile_elementwise_inout([k_scale](auto& x) { x = x * k_scale; }, s_acc);
tile_elementwise_inout([q_scale, k_scale](auto& x) { x = x * q_scale * k_scale; }, s_acc);
}
}
@@ -640,18 +641,39 @@ struct BlockFmhaPipelineQRKSVS
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
auto wrapper_gemm1 = [&](auto& acc, auto a, auto b) {
if constexpr(kDoFp8StaticQuant)
{
auto acc0 = gemm_1(a, b);
tile_elementwise_inout(
[&v_scale](auto& o, auto o0) {
// asm volatile(";wrapper_gemm1\n\tv_mul_f32_e32 %0, %1, %2"
// : "=v"(o)
// : "s"(v_scale), "v"(o0)
// : "memory");
o += o0 * v_scale;
},
acc,
acc0);
}
else
{
gemm_1(acc, a, b);
};
};
// STAGE 3, KV gemm
auto o_acc_tmp = decltype(o_acc){};
clear_tile(o_acc_tmp);
// auto o_acc_tmp = decltype(o_acc){};
// clear_tile(o_acc_tmp);
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
block_sync_lds();
gemm_1(o_acc_tmp,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
wrapper_gemm1(o_acc,
get_slice_tile(p,
sequence<0, i_k1 * kK1>{},
sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
@@ -675,20 +697,21 @@ struct BlockFmhaPipelineQRKSVS
// tail
{
block_sync_lds();
gemm_1(o_acc_tmp,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
v_lds_window);
wrapper_gemm1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
v_lds_window);
block_sync_lds();
}
// o_acc += o_acc_tmp;
// o_acc += tile_elementwise_in(scale(1.0f / v_scale), o_acc_tmp);
// ck_tile::ignore = v_scale;
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) += o_acc_tmp(i_j_idx) * v_scale;
});
});
// sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
// sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// o_acc(i_j_idx) += o_acc_tmp(i_j_idx) * v_scale;
// });
// });
} while(++i_total_loops < num_total_loop);
@@ -796,6 +819,7 @@ struct BlockFmhaPipelineQRKSVS
block_indices,
smem_ptr,
dropout,
1.0f,
nullptr,
nullptr,
128,

View File

@@ -186,6 +186,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout,
const float,
const float*,
const float*,
index_t,
@@ -850,6 +851,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_indices,
smem_ptr,
dropout,
1.0f,
nullptr,
nullptr,
128,