mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
result right
This commit is contained in:
@@ -167,9 +167,9 @@ class BlockQuantizer
|
||||
size_t seq_len = in.get_length(i_perm ? 2 : 1);
|
||||
size_t hdim = in.get_length(3);
|
||||
size_t num_blocks_ = (seq_len + block_size_ - 1) / block_size_;
|
||||
std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len
|
||||
<< " hdim: " << hdim << " dtype_max: " << dtype_max
|
||||
<< " num_blocks_: " << num_blocks_ << std::endl;
|
||||
// std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len
|
||||
// << " hdim: " << hdim << " dtype_max: " << dtype_max
|
||||
// << " num_blocks_: " << num_blocks_ << std::endl;
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<float> dis(0.5f, 2.0f);
|
||||
@@ -215,9 +215,9 @@ class BlockQuantizer
|
||||
}
|
||||
// save scale to tensor
|
||||
block_scale(b, h, block) = 1.0f / scale;
|
||||
std::cout << "block: " << block << " scale: " << scale
|
||||
<< " max_value: " << max_value << " block_scale: " << block_scale
|
||||
<< std::endl;
|
||||
// std::cout << "block: " << block << " scale: " << scale
|
||||
// << " max_value: " << max_value << " block_scale: " << block_scale
|
||||
// << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -806,32 +806,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
float scale_o = 1.f;
|
||||
if(quant == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<float>{1.f, 10.f, next_seed()}(q_scale);
|
||||
ck_tile::FillUniformDistributionIntegerValue<float>{1.f, 10.f, next_seed()}(k_scale);
|
||||
ck_tile::FillUniformDistributionIntegerValue<float>{1.f, 10.f, next_seed()}(v_scale);
|
||||
|
||||
{ //debug info
|
||||
std::cout << "q_scale: " << q_scale << " k_scale: " << k_scale
|
||||
<< " v_scale: " << v_scale << std::endl;
|
||||
|
||||
ck_tile::HostTensor<float> q_host_deq(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<float> k_host_deq(
|
||||
0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
ck_tile::HostTensor<float> v_host_deq(
|
||||
0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
BlockQuantizer quantizer(i_perm);
|
||||
quantizer.dequantize(q_host, q_host_deq, q_scale, block_scale_m_);
|
||||
quantizer.dequantize(k_host, k_host_deq, k_scale, block_scale_n_);
|
||||
quantizer.dequantize(v_host, v_host_deq, v_scale, block_scale_n_);
|
||||
q_host_deq.savetxt("./q_deq.txt");
|
||||
k_host_deq.savetxt("./k_deq.txt");
|
||||
v_host_deq.savetxt("./v_deq.txt");
|
||||
}
|
||||
BlockQuantizer quantizer(i_perm);
|
||||
quantizer.quantize(q_host, q_host, q_scale, block_scale_m_);
|
||||
quantizer.quantize(k_host, k_host, k_scale, block_scale_n_);
|
||||
quantizer.quantize(v_host, v_host, v_scale, block_scale_n_);
|
||||
q_host.savetxt("./q_quant.txt");
|
||||
k_host.savetxt("./k_quant.txt");
|
||||
v_host.savetxt("./v_quant.txt");
|
||||
}
|
||||
else if(quant == 1)
|
||||
{
|
||||
@@ -1737,10 +1718,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
s_host_ref,
|
||||
ck_tile::idx_identity{},
|
||||
ck_tile::idx_identity{},
|
||||
[&q_scale, &k_scale, scale_s, wb](auto idx, auto value) {
|
||||
[&](auto idx, auto value) {
|
||||
return value * scale_s *
|
||||
q_scale(wb, std::get<0>(idx), std::get<1>(idx) / 128) *
|
||||
k_scale(wb, std::get<0>(idx), std::get<2>(idx) / 128);
|
||||
k_scale(wb, std::get<0>(idx) / nr, std::get<2>(idx) / 128);
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -1919,10 +1900,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::idx_identity{},
|
||||
[&v_scale, wb](auto idx, auto value) {
|
||||
// idx: b, m, n, k --> h, sq, d, sk
|
||||
[&v_scale, wb, nr](auto idx, auto value) {
|
||||
return ck_tile::type_convert<float>(value) *
|
||||
v_scale(wb, std::get<0>(idx), std::get<2>(idx) / 128);
|
||||
v_scale(wb, std::get<0>(idx) / nr, std::get<2>(idx) / 128);
|
||||
},
|
||||
ck_tile::idx_identity{});
|
||||
}
|
||||
|
||||
@@ -1728,12 +1728,13 @@ struct FmhaFwdKernel
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s * q_scale,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
q_scale,
|
||||
k_scale_ptr,
|
||||
v_scale_ptr,
|
||||
kargs.block_scale_m,
|
||||
|
||||
@@ -176,6 +176,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float q_scale,
|
||||
const float* k_scale_ptr,
|
||||
const float* v_scale_ptr,
|
||||
index_t,
|
||||
@@ -407,7 +408,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
if(k_scale_ptr)
|
||||
{
|
||||
tile_elementwise_inout([k_scale](auto& x) { x = x * k_scale; }, s_acc);
|
||||
tile_elementwise_inout([q_scale, k_scale](auto& x) { x = x * q_scale * k_scale; }, s_acc);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -640,18 +641,39 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
|
||||
auto wrapper_gemm1 = [&](auto& acc, auto a, auto b) {
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
auto acc0 = gemm_1(a, b);
|
||||
tile_elementwise_inout(
|
||||
[&v_scale](auto& o, auto o0) {
|
||||
// asm volatile(";wrapper_gemm1\n\tv_mul_f32_e32 %0, %1, %2"
|
||||
// : "=v"(o)
|
||||
// : "s"(v_scale), "v"(o0)
|
||||
// : "memory");
|
||||
o += o0 * v_scale;
|
||||
},
|
||||
acc,
|
||||
acc0);
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_1(acc, a, b);
|
||||
};
|
||||
};
|
||||
// STAGE 3, KV gemm
|
||||
auto o_acc_tmp = decltype(o_acc){};
|
||||
clear_tile(o_acc_tmp);
|
||||
// auto o_acc_tmp = decltype(o_acc){};
|
||||
// clear_tile(o_acc_tmp);
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc_tmp,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
wrapper_gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -675,20 +697,21 @@ struct BlockFmhaPipelineQRKSVS
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc_tmp,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_window);
|
||||
wrapper_gemm1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
// o_acc += o_acc_tmp;
|
||||
// o_acc += tile_elementwise_in(scale(1.0f / v_scale), o_acc_tmp);
|
||||
// ck_tile::ignore = v_scale;
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) += o_acc_tmp(i_j_idx) * v_scale;
|
||||
});
|
||||
});
|
||||
// sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// o_acc(i_j_idx) += o_acc_tmp(i_j_idx) * v_scale;
|
||||
// });
|
||||
// });
|
||||
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
@@ -796,6 +819,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
1.0f,
|
||||
nullptr,
|
||||
nullptr,
|
||||
128,
|
||||
|
||||
@@ -186,6 +186,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float,
|
||||
const float*,
|
||||
const float*,
|
||||
index_t,
|
||||
@@ -850,6 +851,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
1.0f,
|
||||
nullptr,
|
||||
nullptr,
|
||||
128,
|
||||
|
||||
Reference in New Issue
Block a user