mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add support for full types (not just aliases) in type-print
- Added support for static_distributed_tensor<...> - Added support for tile_distribution<...> - Added support for tensor_view<...> - Added support for tensor_descriptor<...> Now type-print handles both: 1. Type aliases (::BottomTensorView, ::TensorDesc, etc.) 2. Full types with no runtime storage (static_distributed_tensor, etc.) Shows [from type] indicator for all type-only extractions. Example: type-print dst_tensor Works even when 'p dst_tensor' shows 'Cannot access memory' Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -29,7 +29,7 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
auto moe_shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
@@ -318,93 +318,94 @@ int run_moe_flatmm_example(int argc, char* argv[])
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
const std::string gemm_kind = arg_parser.get_str("gemm_kind");
|
||||
if(gemm_kind == "gemm1_gate_up")
|
||||
// if(gemm_kind == "gemm1_gate_up")
|
||||
// {
|
||||
// if(prec_type == "fp8")
|
||||
// {
|
||||
// return run_moe_gemm_example_with_layouts<
|
||||
// ck_tile::fp8_t,
|
||||
// FlatmmConfig<ck_tile::fp8_t>,
|
||||
// ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(prec_type == "bf8")
|
||||
// {
|
||||
// return run_moe_gemm_example_with_layouts<
|
||||
// ck_tile::bf8_t,
|
||||
// FlatmmConfig<ck_tile::bf8_t>,
|
||||
// ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(prec_type == "bf16")
|
||||
// {
|
||||
// return run_moe_gemm_example_with_layouts<
|
||||
// ck_tile::bfloat16_t,
|
||||
// FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
// ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(prec_type == "fp16")
|
||||
// {
|
||||
// return run_moe_gemm_example_with_layouts<
|
||||
// ck_tile::half_t,
|
||||
// FlatmmConfig<ck_tile::half_t>,
|
||||
// ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
// }
|
||||
// }
|
||||
// else if(gemm_kind == "gemm1_gate_only")
|
||||
// {
|
||||
// if(prec_type == "fp8")
|
||||
// {
|
||||
// return run_moe_gemm_example_with_layouts<
|
||||
// ck_tile::fp8_t,
|
||||
// FlatmmConfig<ck_tile::fp8_t>,
|
||||
// ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(prec_type == "bf8")
|
||||
// {
|
||||
// return run_moe_gemm_example_with_layouts<
|
||||
// ck_tile::bf8_t,
|
||||
// FlatmmConfig<ck_tile::bf8_t>,
|
||||
// ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(prec_type == "bf16")
|
||||
// {
|
||||
// return run_moe_gemm_example_with_layouts<
|
||||
// ck_tile::bfloat16_t,
|
||||
// FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
// ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(prec_type == "fp16")
|
||||
// {
|
||||
// return run_moe_gemm_example_with_layouts<
|
||||
// ck_tile::half_t,
|
||||
// FlatmmConfig<ck_tile::half_t>,
|
||||
// ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
// }
|
||||
// }
|
||||
// else if(gemm_kind == "gemm2")
|
||||
if(gemm_kind == "gemm2")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "fp16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm1_gate_only")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "fp16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm2")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
// if(prec_type == "fp8")
|
||||
// {
|
||||
// return run_moe_gemm_example_with_layouts<ck_tile::fp8_t,
|
||||
// FlatmmConfig<ck_tile::fp8_t>,
|
||||
// ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(prec_type == "bf8")
|
||||
// {
|
||||
// return run_moe_gemm_example_with_layouts<ck_tile::bf8_t,
|
||||
// FlatmmConfig<ck_tile::bf8_t>,
|
||||
// ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
|
||||
@@ -103,7 +103,7 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
return shuffle_b<FlatmmConfig>(b_origin_host);
|
||||
return moe_shuffle_b<FlatmmConfig>(b_origin_host);
|
||||
}
|
||||
}();
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -304,7 +304,7 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
[[maybe_unused]] const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1 /*kbatch*/, max_accumulated_value);
|
||||
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user