mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
still debugging: speculating something with cshuffle epilogue
This commit is contained in:
@@ -18,9 +18,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
# gemm_bquant_quantgrouped_fp8i4.cpp
|
||||
# gemm_bquant_quantgrouped_bf16mxfp4.cpp
|
||||
# gemm_bquant_quantgrouped_bf8.cpp
|
||||
gemm_bquant_quantgrouped_fp8.cpp
|
||||
# gemm_bquant_quantgrouped_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant.cpp
|
||||
# gemm_bquant_quantgrouped_preshufflequant.cpp
|
||||
# gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp
|
||||
# gemm_quant_rowcol.cpp
|
||||
# gemm_quant_tensor.cpp
|
||||
|
||||
@@ -89,8 +89,8 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_fp8_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_bf8_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
@@ -101,8 +101,8 @@ void bquant_quantgrouped_fp8_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void quant_rowcol_instance_factory(
|
||||
@@ -126,13 +126,13 @@ int main(int argc, char* argv[])
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
|
||||
// aquant_quantgrouped_instance_factory(lut);
|
||||
// aquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8_instance_factory(lut);
|
||||
// bquant_quantgrouped_fp8_instance_factory(lut);
|
||||
// bquant_quantgrouped_bf8_instance_factory(lut);
|
||||
// bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
// bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
// bquant_quantgrouped_bf16fp4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
// bquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
// bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut);
|
||||
// quant_rowcol_instance_factory(lut);
|
||||
// quant_tensor_instance_factory(lut);
|
||||
|
||||
@@ -521,12 +521,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f,
|
||||
3.0f /*, fill_seed(gen)*/}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{1.0f, 1.0f /*, fill_seed(gen)*/}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f /*, fill_seed(gen)*/}(a_m_k);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
@@ -572,7 +573,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
{
|
||||
for(int j = 0; j < BQN; j++)
|
||||
{
|
||||
(*bq_tensor_ptr)(i, j) = value;
|
||||
(*bq_tensor_ptr)(i, j) = 1.0;
|
||||
value += static_cast<BQDataType>(1.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,18 +240,19 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
|
||||
|
||||
printf("get_block_id(): %d, get_warp_id(): %d, get_thread_id(): %d, nIter: "
|
||||
"%d, KPerBlockBQ: %d, "
|
||||
"kQScale: %d, scale_reg: %f, "
|
||||
"scale_reg_f: %f\n",
|
||||
get_block_id(),
|
||||
get_warp_id(),
|
||||
get_thread_id(),
|
||||
static_cast<int>(nIter),
|
||||
KPerBlockBQ,
|
||||
static_cast<int>(kQScale),
|
||||
scale_reg,
|
||||
scale_reg_f);
|
||||
// printf("get_block_id(): %d, get_warp_id(): %d, get_thread_id(): %d,
|
||||
// nIter: "
|
||||
// "%d, KPerBlockBQ: %d, "
|
||||
// "kQScale: %d, scale_reg: %f, "
|
||||
// "scale_reg_f: %f\n",
|
||||
// get_block_id(),
|
||||
// get_warp_id(),
|
||||
// get_thread_id(),
|
||||
// static_cast<int>(nIter),
|
||||
// KPerBlockBQ,
|
||||
// static_cast<int>(kQScale),
|
||||
// scale_reg,
|
||||
// scale_reg_f);
|
||||
|
||||
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
|
||||
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
|
||||
@@ -263,14 +264,14 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
});
|
||||
});
|
||||
});
|
||||
// auto c_thread_buffer = c_block_tensor.get_thread_buffer();
|
||||
auto c_thread_buffer = c_block_tensor.get_thread_buffer();
|
||||
// printf("C Data:\n");
|
||||
// for(index_t i = 0; i < c_thread_buffer.size(); ++i)
|
||||
// {
|
||||
// auto value = c_thread_buffer.get(i);
|
||||
// auto float_value = type_convert<float>(value);
|
||||
// printf(" [%d] = %f\n", i, float_value);
|
||||
// }
|
||||
for(index_t i = 0; i < c_thread_buffer.size(); ++i)
|
||||
{
|
||||
auto value = c_thread_buffer.get(i);
|
||||
auto float_value = type_convert<float>(value);
|
||||
printf("[%d] = %f\n", i, float_value);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -390,10 +390,6 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = (num_loop - 1) / loop_count;
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("iCounter:%d \n\n ", iCounter);
|
||||
}
|
||||
while(iCounter > 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
Reference in New Issue
Block a user