From 39bc8453c6a8fde703a9c506a7262edd5b71be30 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" <210906412+assistant-librarian[bot]@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:17:20 +0100 Subject: [PATCH] [CK_TILE] add tf32 support (#4302) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes TF32 is added in CK on gfx942 and gfx950. This PR is to initiate tf32 in CK_TILE on gfx942 and gfx950. ## Checklist Please put an into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run on all changed files - [ ] Any dependent changes have been merged ## Discussion --- 🔁 Imported from [ROCm/composable_kernel#3538](https://github.com/ROCm/composable_kernel/pull/3538) 🧑‍💻 Originally authored by @yingluAMD --------- Co-authored-by: yingluAMD Co-authored-by: assistant-librarian[bot] Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- example/ck_tile/03_gemm/gemm_basic.cpp | 11 + .../ck_tile/03_gemm/gemm_basic_invoker.hpp | 67 +++-- example/ck_tile/03_gemm/gemm_utils.hpp | 23 +- example/ck_tile/03_gemm/run_gemm_example.inc | 92 ++++--- include/ck_tile/core/numeric/bfloat16.hpp | 59 +++++ .../ck_tile/core/numeric/ext_vector_base.hpp | 80 ++++++ 
include/ck_tile/core/numeric/numeric.hpp | 24 ++ include/ck_tile/core/numeric/type_convert.hpp | 38 +++ include/ck_tile/core/numeric/vector_type.hpp | 70 +---- include/ck_tile/core/utility/type_traits.hpp | 5 + include/ck_tile/host/check_err.hpp | 22 +- .../ck_tile/host/reference/reference_gemm.hpp | 218 +++++++++++---- .../ops/epilogue/cshuffle_epilogue.hpp | 41 ++- .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 11 +- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 5 + ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 30 ++- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 38 +-- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 22 ++ .../warp/warp_gemm_attribute_mfma_impl.hpp | 135 ++++++++++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 16 ++ test/ck_tile/data_type/CMakeLists.txt | 2 + .../data_type/test_bf16_f32_convert.cpp | 248 ++++++++++++++++++ test/ck_tile/data_type/test_tf32.cpp | 86 ++++++ .../epilogue/test_cshuffle_epilogue_util.hpp | 4 +- test/ck_tile/gemm/CMakeLists.txt | 4 + .../gemm/test_gemm_pipeline_kernel_types.hpp | 10 + .../gemm/test_gemm_pipeline_prec_types.hpp | 2 + .../gemm/test_gemm_pipeline_tf32_mem.cpp | 22 ++ test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 33 ++- .../ops/gemm/gemm_universal/gemm_common.hpp | 6 + 30 files changed, 1164 insertions(+), 260 deletions(-) create mode 100644 include/ck_tile/core/numeric/ext_vector_base.hpp create mode 100644 test/ck_tile/data_type/test_bf16_f32_convert.cpp create mode 100644 test/ck_tile/data_type/test_tf32.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_tf32_mem.cpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index e30ae8319f..7d6a2adc38 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -41,6 +41,17 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) return run_gemm_example_prec_type( a_layout, b_layout, arg_parser); } +#ifdef CK_GFX950_SUPPORT + else if(data_type == "tf32") + { + // Pass 
tf32_t as A/B types - epilogue auto-detects and maps to float for data operations + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); + } +#endif else if(data_type == "fp8") { return run_gemm_example_prec_type static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { + // ADataTypeCompute: compute type (tf32_t for TF32 mode, used for warp gemm selection) + // ADataTypeBuf: buffer/storage type (fp32 when tf32) + using ADataTypeCompute = ADataType_; + using BDataTypeCompute = BDataType_; + using ADataTypeBuf = ck_tile::if_select_t; + using BDataTypeBuf = ck_tile::if_select_t; + + if constexpr(std::is_same_v) + { + static_assert(std::is_same_v, + "ADataTypeCompute and BDataTypeCompute must be the same"); + } + if constexpr(Persistent) { std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl; } + constexpr bool is_fp32_input = std::is_same_v; + constexpr bool is_tf32_compute = std::is_same_v; + // This part comes from the Codegen - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t M_Tile = is_fp32_input ? 128 : 256; + constexpr ck_tile::index_t N_Tile = is_fp32_input ? 128 : 256; constexpr ck_tile::index_t K_Tile = 64; #if CK_TILE_USE_WMMA @@ -38,12 +54,14 @@ struct BasicInvoker constexpr ck_tile::index_t N_Warp_Tile = 16; constexpr ck_tile::index_t K_Warp_Tile = 16; #else - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; + // gfx950: fp32 uses 16x16x16 tile (native MFMA) + // tf32 uses 32x32x16 tile (3x bf16 32x32x16 MFMA emulation) + constexpr ck_tile::index_t M_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2; + constexpr ck_tile::index_t N_Warp = (is_fp32_input && !is_tf32_compute) ? 
4 : 2; constexpr ck_tile::index_t K_Warp = 1; - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t M_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32; + constexpr ck_tile::index_t N_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32; constexpr ck_tile::index_t K_Warp_Tile = 16; #endif @@ -61,17 +79,21 @@ struct BasicInvoker BLayout, CLayout>; - using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + using CodegenPipelineProblem = + ck_tile::GemmPipelineProblem; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, @@ -112,7 +134,7 @@ struct BasicInvoker } // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; + std::unique_ptr> rotating_mem_ptr; std::function preprocess; auto clear_gemm_output = [&]() { @@ -125,16 +147,21 @@ struct BasicInvoker { std::cout << "Flushing cache..." 
<< std::endl; - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( args.K, args.N, args.stride_B, is_row_major(BLayout{}))); auto size_a_buffer = a_m.get_element_space_size_in_bytes(); auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - rotating_mem_ptr = std::make_unique>( - kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem_ptr = + std::make_unique>( + kargs.as_ptr[0], + kargs.bs_ptr[0], + s.rotating_count_, + size_a_buffer, + size_b_buffer); rotating_mem_ptr->Print(); preprocess = [&]() { diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 3e3b2055ed..bc0853ec18 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -35,6 +35,10 @@ struct GemmConfigBase static constexpr bool TiledMMAPermuteN = false; }; +// Type trait for tf32 storage type (tf32 uses float for memory layout calculations) +template +using prec_storage_type = ck_tile::if_select_t; + template struct GemmConfigMemoryInterwave : public GemmConfigBase { @@ -81,7 +85,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase // Compute V3 only support Intrawave scheduler static constexpr ck_tile::index_t M_Tile = 16; static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(prec_storage_type); static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; @@ -121,7 +125,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128 / 
sizeof(PrecType); + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(prec_storage_type); static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; @@ -293,7 +297,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(prec_storage_type); static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; @@ -302,7 +306,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile, M_Warp_Tile, true>(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -324,6 +328,15 @@ struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill
 struct GemmTypeConfig;
 
+template <>
+struct GemmTypeConfig<float>
+{
+    using ADataType   = float;
+    using BDataType   = float;
+    using AccDataType = float;
+    using CDataType   = float;
+};
+
 template <>
 struct GemmTypeConfig
 {
@@ -486,7 +499,7 @@ inline auto create_args()
         .insert("stride_b", "0", "Tensor B stride")
         .insert("stride_c", "0", "Tensor C stride")
         .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
-        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
+        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t/tf32 (tf32 only on gfx950)")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc
index c4f9c2afda..39dd6357e5 100644
--- a/example/ck_tile/03_gemm/run_gemm_example.inc
+++ b/example/ck_tile/03_gemm/run_gemm_example.inc
@@ -30,6 +30,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
         ck_tile::get_relative_threshold(kbatch);
     const auto atol_split_k = ck_tile::get_absolute_threshold(
         max_accumulated_value, kbatch);
+    // Use the larger (more permissive) of the base and split-K thresholds
     return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
 }
 
@@ -205,11 +206,13 @@ std::tuple inline parse_ge
     return std::make_tuple(M, N, K);
 }
 
+// ADataType_ and BDataType_ are original types (e.g., tf32_t for TF32 mode)
+// They are passed through invoke_gemm to invoker for tf32 auto-detection
 template 
@@ -218,7 +221,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
                                   const BLayout b_layout                  = BLayout{},
                                   [[maybe_unused]] const CLayout c_layout = CLayout{})
 {
-    using AccDataType = typename GemmTypeConfig::AccDataType;
+    // ADataTypeCompute: compute type (tf32_t for TF32 mode, used for warp gemm selection)
+    // ADataTypeBuf: buffer/storage type (fp32 when tf32, from TypeConfig)
+    using ADataTypeCompute = ADataType_;
+    using BDataTypeCompute = BDataType_;
+
+    // Use GemmTypeConfig to get actual data types for tensor operations
+    // This handles tf32 -> float mapping for host tensors and device buffers
+    using TypeConfig   = GemmTypeConfig;
+    using ADataTypeBuf = typename TypeConfig::ADataType;
+    using BDataTypeBuf = typename TypeConfig::BDataType;
+    using CDataType    = typename TypeConfig::CDataType;
+    using AccDataType  = typename TypeConfig::AccDataType;
 
     ck_tile::index_t M = arg_parser.get_int("m");
     ck_tile::index_t N = arg_parser.get_int("n");
@@ -242,27 +256,27 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
     stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
     stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
 
-    ck_tile::HostTensor a_m_k(
+    ck_tile::HostTensor a_m_k(
         ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
-    ck_tile::HostTensor b_k_n(
+    ck_tile::HostTensor b_k_n(
         ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
     ck_tile::HostTensor c_m_n_dev_result(
         ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
 
     if(init_method == 0)
     {
-        ck_tile::FillUniformDistribution{-2.f, 2.f}(a_m_k);
-        ck_tile::FillUniformDistribution{-2.f, 2.f}(b_k_n);
+        ck_tile::FillUniformDistribution{-2.f, 2.f}(a_m_k);
+        ck_tile::FillUniformDistribution{-2.f, 2.f}(b_k_n);
     }
     else if(init_method == 1)
     {
-        ck_tile::FillMonotonicSeq{}(a_m_k);
-        ck_tile::FillMonotonicSeq{}(b_k_n);
+        ck_tile::FillMonotonicSeq{}(a_m_k);
+        ck_tile::FillMonotonicSeq{}(b_k_n);
     }
     else if(init_method == 2)
     {
-        ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k);
-        ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n);
+        ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k);
+        ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n);
     }
     else
     {
@@ -274,7 +288,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
     {
         if constexpr(GemmConfig::UseStructuredSparsity)
         {
-            ck_tile::AdjustToStructuredSparsity{}(a_m_k);
+            ck_tile::AdjustToStructuredSparsity{}(a_m_k);
         }
     }
 
@@ -286,7 +300,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
 
     if constexpr(preshuffle)
     {
-        ck_tile::HostTensor b_shuffle_host = [&]() {
+        ck_tile::HostTensor b_shuffle_host = [&]() {
             if constexpr(GemmConfig::TiledMMAPermuteN)
             {
                 std::cout << "Run with PermuteN" << std::endl;
@@ -299,7 +313,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
             }
         }();
         // shuffled buffer B for device implementation
-        if constexpr(std::is_same_v)
+        if constexpr(std::is_same_v)
         {
             ck_tile::permute_vectors_i4x4_b(b_shuffle_host);
         }
@@ -307,16 +321,16 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
     }
     else
     {
-        if constexpr(std::is_same_v)
+        if constexpr(std::is_same_v)
         {
             // Permute vector pk_i4x4 data for device implementation
-            ck_tile::HostTensor b_k_n_dev = b_k_n;
+            ck_tile::HostTensor b_k_n_dev = b_k_n;
             if constexpr(GemmConfig::PermuteB)
             {
                 permute_tensor_b,
                                  AccDataType,
                                  CDataType,
@@ -371,8 +385,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
 
     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_byte =
-        sizeof(ADataType) * M * K / ck_tile::numeric_traits::PackedSize +
-        sizeof(BDataType) * N * K / ck_tile::numeric_traits::PackedSize +
+        sizeof(ADataTypeBuf) * M * K / ck_tile::numeric_traits::PackedSize +
+        sizeof(BDataTypeBuf) * N * K / ck_tile::numeric_traits::PackedSize +
         sizeof(CDataType) * M * N;
     float tflops     = static_cast(flop) / 1.E9 / ave_time;
     float gb_per_sec = num_byte / 1.E6 / ave_time;
@@ -381,8 +395,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
               << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
               << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
               << " C_Layout=" << CLayout::name
-              << " A_Type=" << ck_tile::DataTypeTraits::name
-              << " B_Type=" << ck_tile::DataTypeTraits::name
+              << " A_Type=" << ck_tile::DataTypeTraits::name
+              << " B_Type=" << ck_tile::DataTypeTraits::name
               << " C_Type=" << ck_tile::DataTypeTraits::name
               << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
               << " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
@@ -397,17 +411,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
 
     if(arg_parser.get_int("v") == 1)
     {
-        ck_tile::reference_gemm(
+        ck_tile::reference_gemm(
             a_m_k, b_k_n, c_m_n_ref);
         const float max_accumulated_value =
             *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
-        const auto rtol_atol = calculate_rtol_atol(
-            K, kbatch, max_accumulated_value);
+        const auto rtol_atol =
+            calculate_rtol_atol(
+                K, kbatch, max_accumulated_value);
         pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
     }
     else if(arg_parser.get_int("v") == 2)
     {
-        if constexpr(std::is_same_v)
+        if constexpr(std::is_same_v)
         {
             // Restore input for B for gpu reference
             b_k_n_dev_buf.ToDevice(b_k_n.data());
@@ -421,12 +436,12 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
         ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
         c_m_n_gpu_buf_ref.SetZero();
 
-        ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer());
-        BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer());
-        CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer());
+        ADataTypeBuf* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer());
+        BDataTypeBuf* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer());
+        CDataType* d_C    = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer());
 
-        ck_tile::reference_gemm_gpu(
-            K, kbatch, max_accumulated_value);
+        const auto rtol_atol =
+            calculate_rtol_atol(
+                K, kbatch, max_accumulated_value);
         pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
     }
 
@@ -447,8 +463,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
         dump_gemm_json_results(arg_parser.get_str("jsonfile"),
diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp
index e193c58915..3508c0705e 100644
--- a/include/ck_tile/core/numeric/bfloat16.hpp
+++ b/include/ck_tile/core/numeric/bfloat16.hpp
@@ -6,6 +6,7 @@
 #include "ck_tile/core/numeric/half.hpp"
 #include "ck_tile/core/numeric/integral_constant.hpp"
 #include "ck_tile/core/numeric/numeric.hpp"
+#include "ck_tile/core/numeric/ext_vector_base.hpp"
 #if CK_TILE_USE_LLVM_BUILTIN_BF16
 #include 
 #endif
@@ -440,4 +441,62 @@ CK_TILE_HOST_DEVICE constexpr bf16x2_t fp32x2_to_bf16x2(const fp32x2_t& x)
     return bf16x2_t{float_to_bf16(x.x), float_to_bf16(x.y)};
 }
 
+// Available on gfx94x (gfx942, gfx950) and later
+CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
+{
+#if defined(__gfx94__) && CK_TILE_USE_LLVM_BUILTIN_BF16
+    return __builtin_convertvector(fp32x2_t{a, b}, bf16x2_t);
+#else
+    return fp32x2_to_bf16x2(fp32x2_t{a, b});
+#endif
+}
+
+// Packed bf16x2 to fp32x2 conversion
+CK_TILE_HOST_DEVICE constexpr fp32x2_t bf16x2_to_fp32x2(bf16x2_t x)
+{
+#if CK_TILE_USE_LLVM_BUILTIN_BF16
+    return __builtin_convertvector(x, fp32x2_t);
+#else
+    uint32_t packed = bit_cast<uint32_t>(x);
+    float f0        = bit_cast<float>(packed << 16);
+    float f1        = bit_cast<float>(packed & 0xFFFF0000u);
+    return fp32x2_t{f0, f1};
+#endif
+}
+
+#ifndef CK_TILE_TF32_USE_PACKED_CVT
+#define CK_TILE_TF32_USE_PACKED_CVT 1
+#endif
+
+template 
+CK_TILE_DEVICE void convert_float_to_bf16_pairs(const ext_vector_t& reg_f32,
+                                                ext_vector_t& reg_bf16_big,
+                                                ext_vector_t& reg_bf16_small)
+{
+#if defined(__gfx94__) && CK_TILE_TF32_USE_PACKED_CVT && CK_TILE_USE_LLVM_BUILTIN_BF16
+    static_assert(VecSize % 2 == 0, "VecSize must be even for packed operations");
+
+#pragma unroll
+    for(int i = 0; i < VecSize; i += 2)
+    {
+        fp32x2_t orig = {reg_f32[i], reg_f32[i + 1]};
+
+        bf16x2_t big_pair   = cvt_pk_bf16_f32(orig[0], orig[1]);
+        fp32x2_t big_f32    = bf16x2_to_fp32x2(big_pair);
+        fp32x2_t diff       = orig - big_f32;
+        bf16x2_t small_pair = cvt_pk_bf16_f32(diff[0], diff[1]);
+
+        reinterpret_cast<bf16x2_t*>(&reg_bf16_big)[i / 2]   = big_pair;
+        reinterpret_cast<bf16x2_t*>(&reg_bf16_small)[i / 2] = small_pair;
+    }
+#else
+#pragma unroll
+    for(int i = 0; i < VecSize; i++)
+    {
+        reg_bf16_big[i]   = float_to_bf16(reg_f32[i]);
+        reg_bf16_small[i] = float_to_bf16(reg_f32[i] - bf16_to_float(reg_bf16_big[i]));
+    }
+#endif
+}
+
 } // namespace ck_tile
diff --git a/include/ck_tile/core/numeric/ext_vector_base.hpp b/include/ck_tile/core/numeric/ext_vector_base.hpp
new file mode 100644
index 0000000000..a0c7d2248d
--- /dev/null
+++ b/include/ck_tile/core/numeric/ext_vector_base.hpp
@@ -0,0 +1,80 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core/numeric/integer.hpp"
+#include "ck_tile/core/utility/type_traits.hpp"
+
+#include 
+
+namespace ck_tile {
+
+// this structure is used to pick up the  type inside
+// using xxx =  __attribute__((ext_vector_type(N)));
+// because clang only allow native type + bool in this term (custom type will fail)
+// overload this structure to let proper  type
+
+template 
+struct native_t
+{
+    using type = remove_cvref_t;
+};
+
+// we name this ext_vector purposely, because the clang ext_vector_type extension only accepts
+// literal basic types; to construct an ext_vector_type you must be very careful, or you will get
+// many compiler errors, e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2)));  -> will
+// have a compiler error
+namespace impl {
+
+template 
+struct ext_vector;
+
+template 
+struct ext_vector::type>>>
+{
+    static constexpr index_t N = N_;
+    // struct type is not supported for ext_vector
+    using value_type = typename native_t::type;
+    static_assert(!std::is_class_v);
+    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
+};
+
+template 
+struct ext_vector::type>>>
+{
+    static constexpr index_t N = N_;
+    // struct type is not supported for ext_vector
+    using value_type = typename native_t::type::type;
+    static_assert(!std::is_class_v);
+    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
+};
+
+template 
+struct ext_vector::type>>>
+{
+    static constexpr index_t N = Vs_ * N_;
+    using value_type           = typename native_t>::type;
+    static_assert(!std::is_class_v);
+    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
+};
+
+template 
+struct ext_vector::type>>>
+{
+    static constexpr index_t N = Vs_ * N_;
+    using value_type           = typename native_t>::type::type;
+    static_assert(!std::is_class_v);
+    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
+};
+
+} // namespace impl
+
+template 
+using ext_vector_t = typename impl::ext_vector::type;
+
+} // namespace ck_tile
diff --git a/include/ck_tile/core/numeric/numeric.hpp b/include/ck_tile/core/numeric/numeric.hpp
index b2bd628685..e78bba32e1 100644
--- a/include/ck_tile/core/numeric/numeric.hpp
+++ b/include/ck_tile/core/numeric/numeric.hpp
@@ -9,6 +9,11 @@
 
 namespace ck_tile {
 
+// TF32 tag type: 1 sign bit, 8 exponent bits, 10 mantissa bits (see numeric_traits)
+struct tf32_t
+{
+};
+
 // this struct has the information of
 // 1. limit of a certain type, simliar to std::numeric_limits
 // 2. some pre-defined value, zero, one...
@@ -101,6 +106,25 @@ struct numeric_traits
     using bitwise_type                  = uint32_t;
 };
 
+template <>
+struct numeric_traits<tf32_t>
+{
+    static constexpr int exp            = 8;
+    static constexpr int mant           = 10;
+    static constexpr int bias           = 127;
+    static constexpr uint32_t nan_mask  = 0x7F800000;
+    static constexpr uint32_t head_mask = 0xFF800000;
+    static constexpr uint32_t mant_mask = 0x7FFFFF;
+    static constexpr uint32_t exp_mask  = 0xFF;
+    static constexpr uint32_t abs_mask  = 0x7FFFFFFF;
+    static constexpr uint32_t Inf       = 0x7F800000;
+    static constexpr uint32_t NegInf    = 0xFF800000;
+    static constexpr uint32_t NaN       = 0x7F800001;
+    static constexpr uint32_t Neg0      = 0x80000000;
+    static constexpr int PackedSize     = 1;
+    using bitwise_type                  = uint32_t;
+};
+
 } // namespace ck_tile
 
 #define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)                                       \
diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp
index 634b845725..da5579f5f0 100644
--- a/include/ck_tile/core/numeric/type_convert.hpp
+++ b/include/ck_tile/core/numeric/type_convert.hpp
@@ -57,6 +57,44 @@ CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
 CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
 CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
 
+static constexpr uint32_t float32_exponent_mask = 0x7f800000u;
+
+enum class tf32_rounding_mode
+{
+    trunc = 0, // truncate
+    rne   = 1, // round to nearest even (RTNE)
+};
+
+template 
+CK_TILE_HOST_DEVICE constexpr float float_to_tf32(float x)
+{
+    uint32_t i = bit_cast<uint32_t>(x);
+    if constexpr(rounding == tf32_rounding_mode::rne)
+    {
+        // RTNE rounding.
+        if((i & float32_exponent_mask) != float32_exponent_mask)
+        {
+            // Add rounding bias for round-to-nearest-even (RTNE) before truncating:
+            //  - 0xfff is the rounding bias corresponding to the 13 fraction bits that
+            //    will be discarded.
+            //  - (i >> 13) & 1 extracts the least significant of those discarded bits and
+            //    adding it implements "ties to even" (round half-way cases to even).
+            i += 0xfff + ((i >> 13) & 1);
+        }
+    }
+    // Zero out the lowest 13 fraction bits to form the TF32-like value.
+    i &= 0xFFFFE000u;
+    return bit_cast<float>(i);
+}
+
+template , bool> = false>
+CK_TILE_HOST_DEVICE constexpr float type_convert(float x)
+{
+    return float_to_tf32(x);
+}
+
 CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
 CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
 CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp
index 756bc7f6fc..9425595b08 100644
--- a/include/ck_tile/core/numeric/vector_type.hpp
+++ b/include/ck_tile/core/numeric/vector_type.hpp
@@ -5,7 +5,7 @@
 
 #include "ck_tile/core/config.hpp"
 #include "ck_tile/core/container/array.hpp"
-#include "ck_tile/core/numeric/integer.hpp"
+#include "ck_tile/core/numeric/ext_vector_base.hpp"
 #include "ck_tile/core/numeric/integral_constant.hpp"
 #include "ck_tile/core/numeric/float8.hpp"
 #include "ck_tile/core/numeric/half.hpp"
@@ -13,77 +13,9 @@
 #include "ck_tile/core/numeric/pk_int4.hpp"
 #include "ck_tile/core/numeric/pk_fp4.hpp"
 #include "ck_tile/core/numeric/e8m0.hpp"
-#include "ck_tile/core/utility/type_traits.hpp"
 
 namespace ck_tile {
 
-// this structure is used to pick up the  type inside
-// using xxx =  __attribute__((ext_vector_type(N)));
-// because clang only allow native type + bool in this term (custom type will fail)
-// overload this structure to let proper  type
-
-template 
-struct native_t
-{
-    using type = remove_cvref_t;
-};
-
-// we name this as ext_vector purposely, because clang ext_vector_type extention only accept literay
-// basic type to construct a ext_vector_type you must be very careful using this, or will have lot
-// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2)));  -> will
-// have compiler error
-namespace impl {
-
-template 
-struct ext_vector;
-
-template 
-struct ext_vector::type>>>
-{
-    static constexpr index_t N = N_;
-    // struct type is not supported for ext_vector
-    using value_type = typename native_t::type;
-    static_assert(!std::is_class_v);
-    using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
-};
-
-template 
-struct ext_vector::type>>>
-{
-    static constexpr index_t N = N_;
-    // struct type is not supported for ext_vector
-    using value_type = typename native_t::type::type;
-    static_assert(!std::is_class_v);
-    using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
-};
-
-template 
-struct ext_vector::type>>>
-{
-    static constexpr index_t N = Vs_ * N_;
-    using value_type           = typename native_t>::type;
-    static_assert(!std::is_class_v);
-    using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
-};
-
-template 
-struct ext_vector::type>>>
-{
-    static constexpr index_t N = Vs_ * N_;
-    using value_type           = typename native_t>::type::type;
-    static_assert(!std::is_class_v);
-    using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
-};
-
-} // namespace impl
-
-template 
-using ext_vector_t = typename impl::ext_vector::type;
-
 // by default, any type will result in a vector_size=1 with scalar_type=T traits.
 // ... unless we have other vector_traits specialization
 template 
diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp
index c11d180839..7e0c0886bb 100644
--- a/include/ck_tile/core/utility/type_traits.hpp
+++ b/include/ck_tile/core/utility/type_traits.hpp
@@ -112,6 +112,11 @@ CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
 #pragma clang diagnostic pop
 }
 
+// Template ternary: if Cond == Match, use TrueType, else FalseType
+// Usage: if_select_t<T, int, float, double> evaluates to float if T==int, else double
+template <typename Cond, typename Match, typename TrueType, typename FalseType>
+using if_select_t = std::conditional_t<std::is_same_v<Cond, Match>, TrueType, FalseType>;
+
 template 
 struct is_any_of : std::false_type
 {
diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp
index bf4ec4ee94..458a725379 100644
--- a/include/ck_tile/host/check_err.hpp
+++ b/include/ck_tile/host/check_err.hpp
@@ -58,6 +58,7 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1
                             F16,
                             BF16,
                             F32,
+                            tf32_t,
                             pk_fp4_t,
                             pk_fp4_raw_t,
                             pk_int4_t,
@@ -76,8 +77,9 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1
         compute_error = std::pow(2, -numeric_traits::mant) * 0.5;
     }
 
-    static_assert(is_any_of::value,
-                  "Warning: Unhandled OutDataType for setting up the relative threshold!");
+    static_assert(
+        is_any_of::value,
+        "Warning: Unhandled OutDataType for setting up the relative threshold!");
 
     double output_error = 0;
     if constexpr(is_any_of::value)
@@ -90,8 +92,9 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1
     }
     double midway_error = std::max(compute_error, output_error);
 
-    static_assert(is_any_of::value,
-                  "Warning: Unhandled AccDataType for setting up the relative threshold!");
+    static_assert(
+        is_any_of::value,
+        "Warning: Unhandled AccDataType for setting up the relative threshold!");
 
     double acc_error = 0;
     if constexpr(is_any_of::value)
@@ -129,6 +132,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
                             F16,
                             BF16,
                             F32,
+                            tf32_t,
                             pk_fp4_t,
                             pk_fp4_raw_t,
                             pk_int4_t,
@@ -151,8 +155,9 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
         compute_error = std::pow(2, discrete_expo - numeric_traits::mant) * 0.5;
     }
 
-    static_assert(is_any_of::value,
-                  "Warning: Unhandled OutDataType for setting up the absolute threshold!");
+    static_assert(
+        is_any_of::value,
+        "Warning: Unhandled OutDataType for setting up the absolute threshold!");
 
     double output_error = 0;
     if constexpr(is_any_of::value)
@@ -168,8 +173,9 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
     }
     double midway_error = std::max(compute_error, output_error);
 
-    static_assert(is_any_of::value,
-                  "Warning: Unhandled AccDataType for setting up the absolute threshold!");
+    static_assert(
+        is_any_of::value,
+        "Warning: Unhandled AccDataType for setting up the absolute threshold!");
 
     double acc_error = 0;
     if constexpr(is_any_of::value)
diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp
index 80f93cfd1b..2ed5a14e6d 100644
--- a/include/ck_tile/host/reference/reference_gemm.hpp
+++ b/include/ck_tile/host/reference/reference_gemm.hpp
@@ -4,11 +4,11 @@
 #pragma once
 
 #include 
-#include 
 #include 
 
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/host_tensor.hpp"
+#include "ck_tile/host/device_prop.hpp"
 
 namespace ck_tile {
 
@@ -447,24 +447,34 @@ CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor& a_m_k,
     std::cout << std::endl;
 }
 
-template 
-CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
-                                 const HostTensor& b_k_n,
-                                 HostTensor& c_m_n,
-                                 const AElementOp& a_element_op     = {},
-                                 const BElementOp& b_element_op     = {},
-                                 const ACCElementOp& acc_element_op = {})
+CK_TILE_HOST void
+reference_gemm(const HostTensor>& a_m_k,
+               const HostTensor>& b_k_n,
+               HostTensor& c_m_n,
+               const AElementOp& a_element_op     = {},
+               const BElementOp& b_element_op     = {},
+               const ACCElementOp& acc_element_op = {})
 {
+    if constexpr(std::is_same_v || std::is_same_v)
+        static_assert(std::is_same_v,
+                      "ADataType and BDataType must be the same");
+    using ADataTypeCompute = ADataType_;
+    using ADataTypeBuf     = if_select_t;
+    using BDataTypeBuf     = if_select_t;
+
     const std::size_t M = a_m_k.get_length(0);
     const std::size_t N = b_k_n.get_length(1);
     const std::size_t K = a_m_k.get_length(1);
 
+    const bool is_gfx950 = (ck_tile::get_device_name() == "gfx950");
+
     auto f_mn = [&](auto m, auto n) {
         AccDataType v_acc = 0;
 
@@ -472,7 +482,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
         {
             AccDataType v_a;
             AccDataType v_b;
-            if constexpr(std::is_same_v)
+            if constexpr(std::is_same_v)
             {
                 // HostTensor automatically handles packed indexing: a_m_k(m,k) divides offset by
                 // PackedSize So a_m_k(m,0) and a_m_k(m,1) return the same packed byte
@@ -481,7 +491,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
                 const float unpacked    = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
                 v_a = ck_tile::type_convert(a_element_op(unpacked));
             }
-            else if constexpr(std::is_same_v)
+            else if constexpr(std::is_same_v)
             {
                 // HostTensor automatically handles packed indexing
                 const pk_int4_t pk_val  = a_m_k(m, k);
@@ -493,7 +503,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
             {
                 v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k)));
             }
-            if constexpr(std::is_same_v)
+            if constexpr(std::is_same_v)
             {
                 // HostTensor automatically handles packed indexing
                 const pk_fp4_t pk_val   = b_k_n(k, n);
@@ -501,7 +511,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
                 const float unpacked    = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
                 v_b = ck_tile::type_convert(b_element_op(unpacked));
             }
-            else if constexpr(std::is_same_v)
+            else if constexpr(std::is_same_v)
             {
                 // HostTensor automatically handles packed indexing
                 const pk_int4_t pk_val  = b_k_n(k, n);
@@ -513,7 +523,36 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
             {
                 v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n)));
             }
-            v_acc += v_a * v_b;
+
+            if constexpr(std::is_same_v)
+            {
+                if(is_gfx950)
+                {
+                    // gfx950: use 3x bf16 emulation
+                    bf16_t v_a_bf16_big   = ck_tile::type_convert(v_a);
+                    bf16_t v_a_bf16_small = ck_tile::type_convert(
+                        v_a - type_convert(v_a_bf16_big));
+                    bf16_t v_b_bf16_big   = ck_tile::type_convert(v_b);
+                    bf16_t v_b_bf16_small = ck_tile::type_convert(
+                        v_b - type_convert(v_b_bf16_big));
+
+                    v_acc += ck_tile::type_convert(v_a_bf16_big) *
+                                 ck_tile::type_convert(v_b_bf16_small) +
+                             ck_tile::type_convert(v_a_bf16_small) *
+                                 ck_tile::type_convert(v_b_bf16_big) +
+                             ck_tile::type_convert(v_a_bf16_big) *
+                                 ck_tile::type_convert(v_b_bf16_big);
+                }
+                else
+                {
+                    // Other architectures: tf32 not supported or handled via fp32 fallback
+                    v_acc += v_a * v_b;
+                }
+            }
+            else
+            {
+                v_acc += v_a * v_b;
+            }
         }
 
         c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc));
@@ -764,15 +803,15 @@ reference_gemm_multiple_d(const HostTensor& a_m_k,
     make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
 }
 
-template 
-__global__ void naive_gemm_kernel(ADataType* A,
-                                  BDataType* B,
+__global__ void naive_gemm_kernel(if_select_t* A,
+                                  if_select_t* B,
                                   CDataType* C,
                                   ck_tile::index_t M,
                                   ck_tile::index_t N,
@@ -781,6 +820,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
                                   ck_tile::index_t strideB,
                                   ck_tile::index_t strideC)
 {
+    if constexpr(std::is_same_v || std::is_same_v)
+        static_assert(std::is_same_v,
+                      "ADataType and BDataType must be the same");
+    using ADataTypeCompute = ADataType_;
+    // ADataTypeBuf: buffer/storage type (fp32 when tf32)
+    using ADataTypeBuf = if_select_t;
+    using BDataTypeBuf = if_select_t;
+
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     int row = idx / N; // Compute row index
     int col = idx % N; // Compute column index
@@ -790,8 +837,8 @@ __global__ void naive_gemm_kernel(ADataType* A,
         AccDataType acc = 0.0;
         for(int k = 0; k < K; ++k)
         {
-            constexpr index_t packed_size_a = ck_tile::numeric_traits::PackedSize;
-            constexpr index_t packed_size_b = ck_tile::numeric_traits::PackedSize;
+            constexpr index_t packed_size_a = ck_tile::numeric_traits::PackedSize;
+            constexpr index_t packed_size_b = ck_tile::numeric_traits::PackedSize;
             // Adjust indexing based on matrix layout
             int a_index = (std::is_same_v)
                               ? row * strideA + k
@@ -802,7 +849,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
 
             AccDataType v_a;
             AccDataType v_b;
-            if constexpr(std::is_same_v)
+            if constexpr(std::is_same_v)
             {
                 const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                 if(k % 2 == 1)
@@ -810,7 +857,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
                 else
                     v_a = fp32_val.lo;
             }
-            else if constexpr(std::is_same_v)
+            else if constexpr(std::is_same_v)
             {
                 const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
                 if(k % 2 == 1)
@@ -822,7 +869,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
             {
                 v_a = ck_tile::type_convert(A[a_index]);
             }
-            if constexpr(std::is_same_v)
+            if constexpr(std::is_same_v)
             {
                 const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                 if(k % 2 == 1)
@@ -830,7 +877,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
                 else
                     v_b = fp32_val.lo;
             }
-            else if constexpr(std::is_same_v)
+            else if constexpr(std::is_same_v)
             {
                 const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
                 if(k % 2 == 1)
@@ -842,7 +889,33 @@ __global__ void naive_gemm_kernel(ADataType* A,
             {
                 v_b = ck_tile::type_convert(B[b_index]);
             }
-            acc += v_a * v_b;
+
+            if constexpr(std::is_same_v)
+            {
+#ifdef CK_GFX950_SUPPORT
+                // NOTE(review): guarded by the CK_GFX950_SUPPORT build flag, not a runtime
+                // arch check — any build with this flag takes the 3x bf16 emulation path;
+                // confirm this matches the host reference's runtime is_gfx950 dispatch.
+                bf16_t v_a_bf16_big = ck_tile::type_convert(v_a);
+                bf16_t v_a_bf16_small =
+                    ck_tile::type_convert(v_a - type_convert(v_a_bf16_big));
+                bf16_t v_b_bf16_big = ck_tile::type_convert(v_b);
+                bf16_t v_b_bf16_small =
+                    ck_tile::type_convert(v_b - type_convert(v_b_bf16_big));
+
+                acc += ck_tile::type_convert(v_a_bf16_big) *
+                           ck_tile::type_convert(v_b_bf16_small) +
+                       ck_tile::type_convert(v_a_bf16_small) *
+                           ck_tile::type_convert(v_b_bf16_big) +
+                       ck_tile::type_convert(v_a_bf16_big) *
+                           ck_tile::type_convert(v_b_bf16_big);
+#else
+                // Other architectures: use fp32 fallback
+                acc += v_a * v_b;
+#endif
+            }
+            else
+            {
+                acc += v_a * v_b;
+            }
         }
 
         int c_index = (std::is_same_v)
@@ -852,15 +925,15 @@ __global__ void naive_gemm_kernel(ADataType* A,
     }
 }
 
-template 
-__global__ void blockwise_gemm_kernel(ADataType* A,
-                                      BDataType* B,
+__global__ void blockwise_gemm_kernel(if_select_t* A,
+                                      if_select_t* B,
                                       CDataType* C,
                                       ck_tile::index_t M,
                                       ck_tile::index_t N,
@@ -874,6 +947,14 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
                                       float* scale_A_ptr,
                                       float* scale_B_ptr)
 {
+    if constexpr(std::is_same_v || std::is_same_v)
+        static_assert(std::is_same_v,
+                      "ADataType and BDataType must be the same");
+    using ADataTypeCompute = ADataType_;
+    // ADataTypeBuf: buffer/storage type (fp32 when tf32)
+    using ADataTypeBuf = if_select_t;
+    using BDataTypeBuf = if_select_t;
+
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     int row = idx / N; // Compute row index
     int col = idx % N; // Compute column index
@@ -902,8 +983,8 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
                                       (k / scale_granularity_k) * scale_B_stride];
             }
 
-            constexpr index_t packed_size_a = ck_tile::numeric_traits::PackedSize;
-            constexpr index_t packed_size_b = ck_tile::numeric_traits::PackedSize;
+            constexpr index_t packed_size_a = ck_tile::numeric_traits::PackedSize;
+            constexpr index_t packed_size_b = ck_tile::numeric_traits::PackedSize;
             // Adjust indexing based on matrix layout
             int a_index = (std::is_same_v)
                               ? row * strideA + k
@@ -914,7 +995,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
 
             AccDataType v_a;
             AccDataType v_b;
-            if constexpr(std::is_same_v)
+            if constexpr(std::is_same_v)
             {
                 const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                 if(k % 2 == 1)
@@ -922,7 +1003,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
                 else
                     v_a = fp32_val.lo;
             }
-            else if constexpr(std::is_same_v)
+            else if constexpr(std::is_same_v)
             {
                 const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
                 if(k % 2 == 1)
@@ -935,7 +1016,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
                 v_a = ck_tile::type_convert(A[a_index]);
             }
 
-            if constexpr(std::is_same_v)
+            if constexpr(std::is_same_v)
             {
                 const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                 if(k % 2 == 1)
@@ -943,7 +1024,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
                 else
                     v_b = fp32_val.lo;
             }
-            else if constexpr(std::is_same_v)
+            else if constexpr(std::is_same_v)
             {
                 const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
                 if(k % 2 == 1)
@@ -955,7 +1036,33 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
             {
                 v_b = ck_tile::type_convert(B[b_index]);
             }
-            acc_temp += v_a * v_b;
+
+            if constexpr(std::is_same_v)
+            {
+#ifdef CK_GFX950_SUPPORT
+                // gfx950: use 3x bf16 emulation
+                bf16_t v_a_bf16_big = ck_tile::type_convert(v_a);
+                bf16_t v_a_bf16_small =
+                    ck_tile::type_convert(v_a - type_convert(v_a_bf16_big));
+                bf16_t v_b_bf16_big = ck_tile::type_convert(v_b);
+                bf16_t v_b_bf16_small =
+                    ck_tile::type_convert(v_b - type_convert(v_b_bf16_big));
+
+                acc_temp += ck_tile::type_convert(v_a_bf16_big) *
+                                ck_tile::type_convert(v_b_bf16_small) +
+                            ck_tile::type_convert(v_a_bf16_small) *
+                                ck_tile::type_convert(v_b_bf16_big) +
+                            ck_tile::type_convert(v_a_bf16_big) *
+                                ck_tile::type_convert(v_b_bf16_big);
+#else
+                // Other architectures: use fp32 fallback
+                acc_temp += v_a * v_b;
+#endif
+            }
+            else
+            {
+                acc_temp += v_a * v_b;
+            }
         }
         // final accumulation
         acc += acc_temp * scale_A * scale_B;
@@ -974,8 +1081,8 @@ template 
-void reference_gemm_gpu(ADataType* a_ptr,
-                        BDataType* b_ptr,
+void reference_gemm_gpu(if_select_t* a_ptr,
+                        if_select_t* b_ptr,
                         CDataType* c_ptr,
                         index_t M,
                         index_t N,
@@ -1002,8 +1109,8 @@ template 
-void reference_blockwise_gemm_gpu(ADataType* a_ptr,
-                                  BDataType* b_ptr,
+void reference_blockwise_gemm_gpu(if_select_t* a_ptr,
+                                  if_select_t* b_ptr,
                                   CDataType* c_ptr,
                                   index_t M,
                                   index_t N,
@@ -1040,15 +1147,15 @@ void reference_blockwise_gemm_gpu(ADataType* a_ptr,
     return;
 }
 
-template 
-void reference_batched_gemm_gpu(ADataType* a_ptr,
-                                BDataType* b_ptr,
+void reference_batched_gemm_gpu(if_select_t* a_ptr,
+                                if_select_t* b_ptr,
                                 CDataType* c_ptr,
                                 index_t M,
                                 index_t N,
@@ -1061,18 +1168,29 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
                                 index_t batch_stride_C,
                                 index_t batch_count)
 {
+    using ADataTypeBuf = if_select_t;
+    using BDataTypeBuf = if_select_t;
+
+    using ADataTypeCompute = ADataType_;
+    using BDataTypeCompute = BDataType_;
+
     int totalElements      = M * N;
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
     for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
     {
-        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
-        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
-        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
-        naive_gemm_kernel
-            <<>>(
-                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
+        ADataTypeBuf* d_ATemp = a_ptr + batch_id * batch_stride_A;
+        BDataTypeBuf* d_BTemp = b_ptr + batch_id * batch_stride_B;
+        CDataType* d_CTemp    = c_ptr + batch_id * batch_stride_C;
+        naive_gemm_kernel<<>>(
+            d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
     }
 
     return;
diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
index 8dad775ecf..3639c811fd 100644
--- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
@@ -89,19 +89,32 @@ struct CShuffleEpilogue
                                                remove_cvref_t,
                                                remove_cvref_t>>;
 
-    using ADataType = remove_cvref_t{}, AsDataTypeTuple>>;
-    using BDataType = remove_cvref_t{}, BsDataTypeTuple>>;
+    // ADataTypeCompute: compute type from Problem (may be tf32_t for TF32 mode)
+    using ADataTypeCompute = remove_cvref_t{}, AsDataTypeTuple>>;
+    using BDataTypeCompute = remove_cvref_t{}, BsDataTypeTuple>>;
 
-    using ATypeToUse = std::conditional_t ||
-                                              std::is_same_v,
-                                          BDataType,
-                                          ADataType>;
+    // ADataTypeBuf: buffer/storage type (fp32 when tf32)
+    using ADataTypeBuf = if_select_t;
+    using BDataTypeBuf = if_select_t;
+
+    // For warp gemm selection: use tf32_t if compute type was tf32_t
+    // For pk_int4/pk_fp4: use the other data type
+    using ATypeToUse =
+        std::conditional_t,
+                           tf32_t,
+                           std::conditional_t ||
+                                                  std::is_same_v,
+                                              BDataTypeBuf,
+                                              ADataTypeBuf>>;
     // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
-    using BTypeToUse = std::conditional_t ||
-                                              std::is_same_v ||
-                                              sizeof(BDataType) < sizeof(ADataType),
-                                          ADataType,
-                                          BDataType>;
+    using BTypeToUse =
+        std::conditional_t,
+                           tf32_t,
+                           std::conditional_t ||
+                                                  std::is_same_v ||
+                                                  sizeof(BDataTypeBuf) < sizeof(ADataTypeBuf),
+                                              ADataTypeBuf,
+                                              BDataTypeBuf>>;
 
     using ELayout                          = remove_cvref_t;
     using CDElementwise                    = remove_cvref_t;
@@ -137,7 +150,7 @@ struct CShuffleEpilogue
     [[nodiscard]] CK_TILE_HOST static const std::string GetName()
     {
         // clang-format off
-        return concat('_', "CShuffleEpilogue", 
+        return concat('_', "CShuffleEpilogue",
                       concat('x', MWave, NWave),
                       concat('x', MPerXdl, NPerXdl, KPerXdl),
                       VectorSizeC,
@@ -440,8 +453,8 @@ struct CShuffleEpilogue
                 constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
                 // BlockedLayout
                 // this branch is for original a16w4
-                if constexpr(is_950 || is_any_of::value ||
-                             is_any_of::value)
+                if constexpr(is_950 || is_any_of::value ||
+                             is_any_of::value)
                 {
                     if constexpr(EightWave)
                     {
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp
index 4cca604ff1..463f149a65 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp
@@ -229,15 +229,6 @@ CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
     return result;
 }
 
-CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
-{
-    bf16x2_t result;
-    asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
-                 : [result] "=v"(result)
-                 : [a] "v"(a), [b] "v"(b));
-    return result;
-}
-
 CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
 {
     fp32x2_t result;
@@ -856,7 +847,7 @@ struct BlockFmhaFwdV3Pipeline
                 }
                 else
                 {
-                    auto casted                           = detail::cvt_pk_bf16_f32(x, y);
+                    auto casted                           = ck_tile::cvt_pk_bf16_f32(x, y);
                     sp(sp_reg_idx).p.thread_buf_[idx]     = casted.x;
                     sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
                 }
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
index 5a26d72c11..f43bcbc4b1 100644
--- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
@@ -49,6 +49,7 @@ struct GemmPipelineAgBgCrImplBase
     // that only work for certain K warp tile sizes based on data type size:
     // - For 1-byte types (fp8/bf8): K warp tile <= 64
     // - For 2-byte types (fp16/bf16): K warp tile <= 32
+    // - For 4-byte types (float/tf32): transpose load not supported
     static constexpr bool is_a_load_tr = []() {
         using WarpTile                  = typename BlockGemmShape::WarpTile;
         constexpr index_t kKWarpTile    = WarpTile::at(number<2>{});
@@ -57,6 +58,8 @@ struct GemmPipelineAgBgCrImplBase
             return false;
         else if constexpr(std::is_same_v)
             return false;
+        else if constexpr(sizeof(ADataType) >= 4)
+            return false; // 4-byte types (float/tf32) don't support transpose load
         else if constexpr(kKWarpTile > kMaxKWarpTile)
             return false;
         else
@@ -71,6 +74,8 @@ struct GemmPipelineAgBgCrImplBase
             return false;
         else if constexpr(std::is_same_v)
             return false;
+        else if constexpr(sizeof(BDataType) >= 4)
+            return false; // 4-byte types (float/tf32) don't support transpose load
         else if constexpr(kKWarpTile > kMaxKWarpTile)
             return false;
         else
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
index 4e06ac8f7b..f45b0ffb2c 100644
--- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
@@ -909,26 +909,28 @@ struct UniversalGemmPipelineAgBgCrPolicy
             : vector_size * 4 == thread_elements              ? WGAttrNumAccessEnum::Quad
                                                               : WGAttrNumAccessEnum::Invalid;
 
-        using ADataType = remove_cvref_t;
-        using BDataType = remove_cvref_t;
-        using ATypeToUse =
-            std::conditional_t, BDataType, ADataType>;
+        using ADataType       = remove_cvref_t;
+        using BDataType       = remove_cvref_t;
+        using ComputeDataType = remove_cvref_t;
+
+        using ATypeToUse = if_select_t;
         using BTypeToUse = std::conditional_t ||
                                                   std::is_same_v ||
                                                   sizeof(BDataType) < sizeof(ADataType),
                                               ADataType,
                                               BDataType>;
 
-        using WarpGemm = WarpGemmDispatcher;
+        using WarpGemm =
+            WarpGemmDispatcher,
+                               if_select_t,
+                               typename Problem::CDataType,
+                               WarpTile::at(I0),
+                               WarpTile::at(I1),
+                               WarpTile::at(I2),
+                               Problem::TransposeC,
+                               false,
+                               Problem::UseStructuredSparsity,
+                               wg_attr_num_access>;
 
         using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy;
+        using ADataType       = remove_cvref_t;
+        using BDataType       = remove_cvref_t;
+
         // Determine compute types to use
         // This logic defaults to A/B DataType, but if one of them is packed falls back to the other
         // If both are packed, it falls back to the explicitly defined ComputeDataType in the
         // problem It might be a good idea to use ComputeDataType anyway, but that would break how
         // this behaviour used to work
-        using ATypeToUse = mixed_prec_compute_type_from_input_t;
-        using BTypeToUse = mixed_prec_compute_type_from_input_t;
-
+        using ATypeToUse =
+            mixed_prec_compute_type_from_input_t;
+        using BTypeToUse =
+            mixed_prec_compute_type_from_input_t;
         constexpr index_t WaveSize = get_warp_size();
         constexpr index_t KLane    = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
         // When BDataType is pk_int4_t, it is internally converted to fp8 for computation.
         constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse);
         constexpr auto NumAccess     = static_cast(max(1, KLaneBytes / 16));
-        using WarpGemm               = WarpGemmDispatcher;
+        // For tf32 mode, use tf32_t for warp gemm; otherwise use original types
+        using WarpGemm =
+            WarpGemmDispatcher,
+                               if_select_t,
+                               typename Problem::CDataType,
+                               WarpTile::at(I0),
+                               WarpTile::at(I1),
+                               WarpTile::at(I2),
+                               Problem::TransposeC,
+                               false,
+                               false,
+                               NumAccess>;
 
         using BlockWeightPreshufflePolicy =
             BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy>;
 
+// tf32
+// On gfx950: uses 3x bf16 MFMA emulation (no native xf32 support)
+
+#if defined(CK_GFX950_SUPPORT)
+// gfx950: tf32 emulated using 3x bf16 MFMA
+using WarpGemmMfmaTf32Tf32F32M32N32K16Native = WarpGemmImpl>>;
+
+using WarpGemmMfmaTf32Tf32F32M16N16K32Native = WarpGemmImpl>>;
+
+template 
+using WarpGemmMfmaTf32Tf32F32M32N32K16 = WarpGemmImpl,
+    AttrNumAccess>>;
+
+template 
+using WarpGemmMfmaTf32Tf32F32M16N16K32 = WarpGemmImpl,
+    AttrNumAccess>>;
+#endif
+
 // fp16
 
 using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
index bc591ae740..eb2f9c96f0 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
@@ -190,6 +190,141 @@ struct WarpGemmAttributeMfmaImplF32F32F32M32N32K2
     }
 };
 
+// tf32/xf32 emulation on gfx950 using 3x bf16 MFMA
+// Algorithm: split float into bf16_big and bf16_small, then compute:
+//   out = A_big * B_big + A_small * B_big + A_big * B_small
+// This provides tf32-like precision using bf16 hardware
+
+// V_MFMA_F32_32x32x16_XF32 emulated on gfx950 using 3x bf16 32x32x16
+template 
+struct WarpGemmAttributeMfmaImplF32F32F32M32N32K16Tf32Gfx950
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+
+    // Input: 8 floats for K=16 (each lane holds 8 elements, kABKPerLane=8)
+    using AVecType = ext_vector_t;
+    using BVecType = ext_vector_t;
+    using CVecType = ext_vector_t;
+
+    static constexpr index_t kM = 32;
+    static constexpr index_t kN = 32;
+    static constexpr index_t kK = 16;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 32;
+    static constexpr index_t kBNLane     = 32;
+    static constexpr index_t kABKLane    = 2;
+    static constexpr index_t kABKPerLane = 8;
+
+    static constexpr index_t kCMLane     = 2;
+    static constexpr index_t kCNLane     = 32;
+    static constexpr index_t kCM0PerLane = 4;
+    static constexpr index_t kCM1PerLane = 4;
+
+    // c_vec += a_vec * b_vec
+    template 
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant = {}) const
+    {
+#if defined(__gfx950__)
+        // Convert float to bf16 pairs using packed instructions
+        ext_vector_t a_big, a_small, b_big, b_small;
+        convert_float_to_bf16_pairs<8>(a_vec, a_big, a_small);
+        convert_float_to_bf16_pairs<8>(b_vec, b_big, b_small);
+
+        // Run 3 bf16 MFMAs: small*big, big*small, big*big
+        c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_small, b_big, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_big, b_small, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_big, b_big, c_vec, 0, 0, 0);
+#else
+        ck_tile::ignore = c_vec;
+        ck_tile::ignore = a_vec;
+        ck_tile::ignore = b_vec;
+#endif
+    }
+
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+        CVecType c_vec{0.f};
+        (*this)(c_vec, a_vec, b_vec);
+        return c_vec;
+    }
+};
+
+// V_MFMA_F32_16x16x32_XF32 emulated on gfx950 using 3x bf16 16x16x32
+template 
+struct WarpGemmAttributeMfmaImplF32F32F32M16N16K32Tf32Gfx950
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+
+    // Input: 8 floats for K=32 (each lane holds 8 elements, kABKPerLane=8)
+    using AVecType = ext_vector_t;
+    using BVecType = ext_vector_t;
+    using CVecType = ext_vector_t;
+
+    static constexpr index_t kM = 16;
+    static constexpr index_t kN = 16;
+    static constexpr index_t kK = 32;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 16;
+    static constexpr index_t kBNLane     = 16;
+    static constexpr index_t kABKLane    = 4;
+    static constexpr index_t kABKPerLane = 8;
+
+    static constexpr index_t kCMLane     = 4;
+    static constexpr index_t kCNLane     = 16;
+    static constexpr index_t kCM0PerLane = 1;
+    static constexpr index_t kCM1PerLane = 4;
+
+    // c_vec += a_vec * b_vec
+    template 
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant = {}) const
+    {
+#if defined(__gfx950__)
+        // Convert float to bf16 pairs using packed instructions
+        ext_vector_t a_big, a_small, b_big, b_small;
+        convert_float_to_bf16_pairs<8>(a_vec, a_big, a_small);
+        convert_float_to_bf16_pairs<8>(b_vec, b_big, b_small);
+
+        // Run 3 bf16 MFMAs: small*big, big*small, big*big
+        c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_small, b_big, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_big, b_small, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_big, b_big, c_vec, 0, 0, 0);
+#else
+        ck_tile::ignore = c_vec;
+        ck_tile::ignore = a_vec;
+        ck_tile::ignore = b_vec;
+#endif
+    }
+
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+        CVecType c_vec{0.f};
+        (*this)(c_vec, a_vec, b_vec);
+        return c_vec;
+    }
+};
+
 // V_MFMA_F32_16x16x32_BF16
 template 
 struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
index 081ff5150d..94e0494aac 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
@@ -40,6 +40,22 @@ template<> struct Dispatcher { using Typ
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M32N32K8; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
+
+// tf32 (on gfx950: uses 3x bf16 MFMA emulation)
+// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
+#if defined(CK_GFX950_SUPPORT)
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+// TF32 16x16x32 for weight preshuffle pipeline (emulated via 3x bf16 16x16x32 MFMA; gfx950 has no native TF32 MFMA)
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K32<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K32; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K32; };
+#endif
+// Note: For gfx11/gfx12 and other architectures that don't support tf32,
+// these dispatchers are not defined. Code using tf32 should be guarded
+// by CK_ENABLE_TF32 or CK_GFX950_SUPPORT macros.
 // fp16
 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
 template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
diff --git a/test/ck_tile/data_type/CMakeLists.txt b/test/ck_tile/data_type/CMakeLists.txt
index 17df115b9d..357e2f2721 100644
--- a/test/ck_tile/data_type/CMakeLists.txt
+++ b/test/ck_tile/data_type/CMakeLists.txt
@@ -7,6 +7,8 @@ endif()
 if(GPU_TARGETS MATCHES "gfx95")
     add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
     add_gtest_executable(test_ck_tile_mx_scale test_mx_scale.cpp)
+    add_gtest_executable(test_ck_tile_tf32 test_tf32.cpp)
+    add_gtest_executable(test_ck_tile_bf16_f32_convert test_bf16_f32_convert.cpp)
 endif()
 
 if(CK_USE_OCP_FP8 OR CK_USE_FNUZ_FP8)
diff --git a/test/ck_tile/data_type/test_bf16_f32_convert.cpp b/test/ck_tile/data_type/test_bf16_f32_convert.cpp
new file mode 100644
index 0000000000..d74d8bc072
--- /dev/null
+++ b/test/ck_tile/data_type/test_bf16_f32_convert.cpp
@@ -0,0 +1,248 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "gtest/gtest.h"
+#include 
+#include 
+
+#include 
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+
+using ck_tile::bf16_to_float;
+using ck_tile::bf16x2_t;
+using ck_tile::bfloat16_t;
+using ck_tile::bit_cast;
+using ck_tile::float_to_bf16;
+using ck_tile::fp32x2_t;
+
+// =====================================================================
+// Tests for bf16x2_to_fp32x2 (host-side, always available)
+// =====================================================================
+
+TEST(Bf16F32Convert, Bf16x2ToFp32x2_BasicValues)
+{
+    auto a = float_to_bf16(1.0f);
+    auto b = float_to_bf16(-2.5f);
+
+    bf16x2_t packed{a, b};
+    fp32x2_t result = ck_tile::bf16x2_to_fp32x2(packed);
+
+    EXPECT_FLOAT_EQ(result[0], bf16_to_float(a));
+    EXPECT_FLOAT_EQ(result[1], bf16_to_float(b));
+}
+
+TEST(Bf16F32Convert, Bf16x2ToFp32x2_Zeros)
+{
+    auto pos_zero = float_to_bf16(0.0f);
+    auto neg_zero = float_to_bf16(-0.0f);
+
+    bf16x2_t packed{pos_zero, neg_zero};
+    fp32x2_t result = ck_tile::bf16x2_to_fp32x2(packed);
+
+    EXPECT_FLOAT_EQ(result[0], 0.0f);
+    EXPECT_TRUE(std::signbit(result[1]));
+    EXPECT_FLOAT_EQ(result[1], -0.0f);
+}
+
+TEST(Bf16F32Convert, Bf16x2ToFp32x2_LargeSmall)
+{
+    auto big   = float_to_bf16(65504.0f);
+    auto small = float_to_bf16(0.00390625f);
+
+    bf16x2_t packed{big, small};
+    fp32x2_t result = ck_tile::bf16x2_to_fp32x2(packed);
+
+    EXPECT_FLOAT_EQ(result[0], bf16_to_float(big));
+    EXPECT_FLOAT_EQ(result[1], bf16_to_float(small));
+}
+
+TEST(Bf16F32Convert, Bf16x2ToFp32x2_RoundTrip)
+{
+    const float test_values[] = {1.0f, -1.0f, 0.5f, 3.14f, 100.0f, -42.0f, 0.001f};
+    for(float v : test_values)
+    {
+        auto bf        = float_to_bf16(v);
+        float expected = bf16_to_float(bf);
+
+        bf16x2_t packed{bf, bf};
+        fp32x2_t result = ck_tile::bf16x2_to_fp32x2(packed);
+
+        EXPECT_FLOAT_EQ(result[0], expected) << "v=" << v;
+        EXPECT_FLOAT_EQ(result[1], expected) << "v=" << v;
+    }
+}
+
+// =====================================================================
+// Tests for fp32x2_to_bf16x2 (host-side)
+// =====================================================================
+
+TEST(Bf16F32Convert, Fp32x2ToBf16x2_BasicValues)
+{
+    fp32x2_t input{1.5f, -3.0f};
+    bf16x2_t result = ck_tile::fp32x2_to_bf16x2(input);
+
+    EXPECT_FLOAT_EQ(bf16_to_float(result[0]), bf16_to_float(float_to_bf16(1.5f)));
+    EXPECT_FLOAT_EQ(bf16_to_float(result[1]), bf16_to_float(float_to_bf16(-3.0f)));
+}
+
+// =====================================================================
+// Device tests for cvt_pk_bf16_f32 and convert_float_to_bf16_pairs
+// =====================================================================
+
+// Host-visible output of kernel_cvt_pk_bf16_f32: one packed bf16 pair per
+// input float pair (r0 from in[2*idx], r1 from in[2*idx + 1]).
+struct CvtPkBf16F32Result
+{
+    bfloat16_t r0; // first lane of the bf16x2 result
+    bfloat16_t r1; // second lane of the bf16x2 result
+};
+
+__global__ void kernel_cvt_pk_bf16_f32(const float* in, CvtPkBf16F32Result* out, int n)
+{
+    int idx = threadIdx.x;
+    if(idx < n)
+    {
+        bf16x2_t result = ck_tile::cvt_pk_bf16_f32(in[2 * idx], in[2 * idx + 1]);
+        out[idx].r0     = result[0];
+        out[idx].r1     = result[1];
+    }
+}
+
+TEST(Bf16F32Convert, CvtPkBf16F32_Device)
+{
+    const std::vector host_in = {1.0f, -1.0f, 0.0f, 3.14f, 100.0f, -0.5f, 42.0f, 0.001f};
+    const int num_pairs              = host_in.size() / 2;
+
+    ck_tile::DeviceMem in_buf(host_in.size() * sizeof(float));
+    ck_tile::DeviceMem out_buf(num_pairs * sizeof(CvtPkBf16F32Result));
+    in_buf.ToDevice(host_in.data());
+
+    kernel_cvt_pk_bf16_f32<<<1, num_pairs>>>(
+        static_cast(in_buf.GetDeviceBuffer()),
+        static_cast(out_buf.GetDeviceBuffer()),
+        num_pairs);
+    (void)hipDeviceSynchronize();
+
+    std::vector host_out(num_pairs);
+    out_buf.FromDevice(host_out.data());
+
+    for(int i = 0; i < num_pairs; i++)
+    {
+        float ref0 = bf16_to_float(float_to_bf16(host_in[2 * i]));
+        float ref1 = bf16_to_float(float_to_bf16(host_in[2 * i + 1]));
+        EXPECT_FLOAT_EQ(bf16_to_float(host_out[i].r0), ref0) << "pair=" << i << " elem=0";
+        EXPECT_FLOAT_EQ(bf16_to_float(host_out[i].r1), ref1) << "pair=" << i << " elem=1";
+    }
+}
+
+// =====================================================================
+// Device test for convert_float_to_bf16_pairs
+// =====================================================================
+
+template 
+struct Bf16PairsResult
+{
+    bfloat16_t big[VecSize];
+    bfloat16_t small_val[VecSize];
+};
+
+template 
+__global__ void kernel_convert_float_to_bf16_pairs(const float* in, Bf16PairsResult* out)
+{
+    using float_vec_t = ck_tile::ext_vector_t;
+    using bf16_vec_t  = ck_tile::ext_vector_t;
+
+    float_vec_t reg_f32;
+    for(int i = 0; i < VecSize; i++)
+        reg_f32[i] = in[i];
+
+    bf16_vec_t reg_big, reg_small;
+    ck_tile::convert_float_to_bf16_pairs(reg_f32, reg_big, reg_small);
+
+    for(int i = 0; i < VecSize; i++)
+    {
+        out[0].big[i]       = reg_big[i];
+        out[0].small_val[i] = reg_small[i];
+    }
+}
+
+template 
+void test_convert_float_to_bf16_pairs_device()
+{
+    static_assert(VecSize >= 2 && VecSize % 2 == 0);
+
+    std::vector host_in(VecSize);
+    // Use diverse values: mix of exact and non-exact bf16 representable numbers
+    const float base_vals[] = {1.1f, -2.3f, 0.7f, 100.1f, -0.001f, 42.42f, 3.14f, -7.77f};
+    for(int i = 0; i < VecSize; i++)
+        host_in[i] = base_vals[i % 8];
+
+    ck_tile::DeviceMem in_buf(VecSize * sizeof(float));
+    ck_tile::DeviceMem out_buf(sizeof(Bf16PairsResult));
+    in_buf.ToDevice(host_in.data());
+
+    kernel_convert_float_to_bf16_pairs
+        <<<1, 1>>>(static_cast(in_buf.GetDeviceBuffer()),
+                   static_cast*>(out_buf.GetDeviceBuffer()));
+    (void)hipDeviceSynchronize();
+
+    Bf16PairsResult host_out;
+    out_buf.FromDevice(&host_out);
+
+    for(int i = 0; i < VecSize; i++)
+    {
+        float orig  = host_in[i];
+        float big_f = bf16_to_float(host_out.big[i]);
+
+        // big should match scalar float_to_bf16
+        float ref_big = bf16_to_float(float_to_bf16(orig));
+        EXPECT_FLOAT_EQ(big_f, ref_big) << "VecSize=" << VecSize << " i=" << i;
+
+        // small should match float_to_bf16(orig - big)
+        float ref_small = bf16_to_float(float_to_bf16(orig - ref_big));
+        float small_f   = bf16_to_float(host_out.small_val[i]);
+        EXPECT_FLOAT_EQ(small_f, ref_small) << "VecSize=" << VecSize << " i=" << i;
+
+        // big + small should be closer to orig than big alone
+        float reconstructed = big_f + small_f;
+        EXPECT_LE(std::fabs(reconstructed - orig), std::fabs(big_f - orig) + 1e-10f)
+            << "VecSize=" << VecSize << " i=" << i;
+    }
+}
+
+// Instantiate the device big/small split check for vector widths 2, 4 and 8.
+TEST(Bf16F32Convert, ConvertFloatToBf16Pairs_Vec2) { test_convert_float_to_bf16_pairs_device<2>(); }
+TEST(Bf16F32Convert, ConvertFloatToBf16Pairs_Vec4) { test_convert_float_to_bf16_pairs_device<4>(); }
+TEST(Bf16F32Convert, ConvertFloatToBf16Pairs_Vec8) { test_convert_float_to_bf16_pairs_device<8>(); }
+
+// =====================================================================
+// 3x BF16 multiply-accumulate precision test
+// =====================================================================
+
+TEST(Bf16F32Convert, ThreeBf16MulAccPrecision)
+{
+    // Verify that a_big*b_big + a_small*b_big + a_big*b_small is more precise
+    // than a single bf16(a)*bf16(b) for non-exact values
+    const float test_pairs[][2] = {
+        {1.1f, 2.3f}, {3.14f, -2.71f}, {0.123f, 456.789f}, {-100.1f, 0.99f}};
+
+    for(const auto& pair : test_pairs)
+    {
+        float a = pair[0];
+        float b = pair[1];
+
+        float a_big_f   = bf16_to_float(float_to_bf16(a));
+        float a_small_f = bf16_to_float(float_to_bf16(a - a_big_f));
+        float b_big_f   = bf16_to_float(float_to_bf16(b));
+        float b_small_f = bf16_to_float(float_to_bf16(b - b_big_f));
+
+        float exact       = a * b;
+        float single_bf16 = a_big_f * b_big_f;
+        float three_bf16  = a_big_f * b_big_f + a_small_f * b_big_f + a_big_f * b_small_f;
+
+        float err_single = std::fabs(exact - single_bf16);
+        float err_three  = std::fabs(exact - three_bf16);
+
+        EXPECT_LE(err_three, err_single + 1e-10f)
+            << "a=" << a << " b=" << b << " exact=" << exact << " single=" << single_bf16
+            << " three=" << three_bf16;
+    }
+}
diff --git a/test/ck_tile/data_type/test_tf32.cpp b/test/ck_tile/data_type/test_tf32.cpp
new file mode 100644
index 0000000000..f7f0c390fc
--- /dev/null
+++ b/test/ck_tile/data_type/test_tf32.cpp
@@ -0,0 +1,86 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "gtest/gtest.h"
+#include 
+#include 
+#include 
+
+#include "ck_tile/core.hpp"
+
+using ck_tile::bit_cast;
+using ck_tile::numeric_traits;
+using ck_tile::tf32_rounding_mode;
+using ck_tile::tf32_t;
+using ck_tile::type_convert;
+
+static uint32_t to_bits(float x) { return bit_cast(x); }
+static float from_bits(uint32_t i) { return bit_cast(i); }
+
+TEST(ConvertTest, NumericTraits)
+{
+    EXPECT_EQ(numeric_traits::exp, 8);
+    EXPECT_EQ(numeric_traits::mant, 10);
+    EXPECT_EQ(numeric_traits::bias, 127);
+    EXPECT_EQ(numeric_traits::PackedSize, 1);
+}
+
+// Truncation-mode float -> tf32: the low 13 mantissa bits are dropped; sign,
+// exponent and the top 10 mantissa bits pass through unchanged.
+// NOTE(review): the `type_convert(...)` and `std::numeric_limits` calls below
+// lost their template arguments during extraction (presumably
+// `type_convert<tf32_t>(...)` with the truncating rounding mode and
+// `std::numeric_limits<float>`) -- restore from the upstream PR before building.
+TEST(ConvertTest, ToTf32Trunc)
+{
+    // exact values (low 13 bits already zero)
+    EXPECT_EQ(to_bits(type_convert(1.0f)), 0x3F800000u);  // 1.0f
+    EXPECT_EQ(to_bits(type_convert(-1.0f)), 0xBF800000u); // -1.0f
+    EXPECT_EQ(to_bits(type_convert(0.0f)), 0x00000000u);  // +0.0f
+    EXPECT_EQ(to_bits(type_convert(-0.0f)), 0x80000000u); // -0.0f
+    EXPECT_EQ(to_bits(type_convert(2.0f)), 0x40000000u);  // 2.0f
+    EXPECT_EQ(to_bits(type_convert(0.5f)), 0x3F000000u);  // 0.5f
+
+    // truncation zeros the low 13 mantissa bits
+    EXPECT_EQ(to_bits(type_convert(1.1f)), 0x3F8CC000u); // 1.1f (0x3F8CCCCD)
+    EXPECT_EQ(to_bits(type_convert(3.14159265358979323846f)),
+              0x40490000u); // pi (0x40490FDB)
+    EXPECT_EQ(to_bits(type_convert(123.456f)),
+              0x42F6E000u);                                        // 123.456f (0x42F6E979)
+    EXPECT_EQ(to_bits(type_convert(-3.14f)), 0xC048E000u); // -3.14f (0xC048F5C3)
+
+    // special values
+    EXPECT_EQ(to_bits(type_convert(std::numeric_limits::infinity())), 0x7F800000u);
+    EXPECT_EQ(to_bits(type_convert(-std::numeric_limits::infinity())), 0xFF800000u);
+    EXPECT_TRUE(std::isnan(type_convert(std::numeric_limits::quiet_NaN())));
+    EXPECT_EQ(to_bits(type_convert(std::numeric_limits::denorm_min())), 0x00000000u);
+
+    // property: low 13 bits must be zero, top 19 bits preserved
+    for(float val : {1.0f, 1.5f, 2.0f, 0.1f, 100.0f, -42.5f, 1e10f, 1e-10f})
+    {
+        uint32_t orig = to_bits(val);
+        uint32_t tf32 = to_bits(type_convert(val));
+
+        EXPECT_EQ(tf32 & 0xFFFFE000u, tf32) << "val=" << val;
+        EXPECT_EQ(orig & 0xFFFFE000u, tf32) << "val=" << val;
+    }
+}
+
+// Round-to-nearest-even float -> tf32: values past the 13-bit midpoint round
+// up into mantissa bit 13; infinities and NaN pass through unchanged.
+// NOTE(review): the multi-line `type_convert(` calls below lost long template
+// argument lists during extraction (presumably including a
+// `tf32_rounding_mode` round-to-nearest-even parameter) -- restore from the
+// upstream PR before building.
+TEST(ConvertTest, ToTf32Rtne)
+{
+    // exact values (low 13 bits already zero)
+    EXPECT_EQ(to_bits(type_convert(1.0f)),
+              0x3F800000u); // 1.0f
+    EXPECT_EQ(to_bits(type_convert(-1.0f)),
+              0xBF800000u); // -1.0f
+    EXPECT_EQ(to_bits(type_convert(0.0f)),
+              0x00000000u); // +0.0f
+
+    // past midpoint (bit12 + bit11 set) -> rounds up
+    float val = from_bits(0x3F801800u);
+    EXPECT_EQ(to_bits(type_convert(val)), 0x3F802000u);
+
+    // special values (keep the same as float)
+    EXPECT_EQ(to_bits(type_convert(
+                  std::numeric_limits::infinity())),
+              0x7F800000u); // infinity in float is 0x7F800000
+    EXPECT_EQ(to_bits(type_convert(
+                  -std::numeric_limits::infinity())),
+              0xFF800000u); // negative infinity in float is 0xFF800000
+    EXPECT_TRUE(std::isnan(type_convert(
+        std::numeric_limits::quiet_NaN()))); // quiet NaN in float is 0x7FC00000
+}
diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp
index 2c258b5bb9..573cc08510 100644
--- a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp
+++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp
@@ -46,8 +46,8 @@ test_cshuffle_epilogue_kernel(const typename Problem::AccDataType* __restrict__
     __shared__ char smem[Epilogue::GetSmemSize()];
 
     // Create accumulator tile with GEMM accumulator distribution (matches BlockGemm)
-    using WG = ck_tile::WarpGemmDispatcher
 >;
 
+// TF32 (gfx950 only): 3x bf16 MFMA emulation, uses float buffers with tf32_t compute type
+// Tile: 128x128x64, Warp tile: 32x32x16
+using KernelTypesTf32Mem = ::testing::Types<
+    //         ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, K_TileSize, Scheduler, PipelineType
+    std::tuple<    Row,     Row,     Row,      TF32,      TF32,         F32,       F32,        I128,        I128,         I64,        I32,        I32,        I16, Intrawave,         Mem>,
+    std::tuple<    Row,     Row,     Row,      TF32,      TF32,         F32,       F32,        I128,        I128,         I64,        I32,        I32,        I16, Interwave,         Mem>,
+    std::tuple<    Row,     Col,     Row,      TF32,      TF32,         F32,       F32,        I128,        I128,         I64,        I32,        I32,        I16, Intrawave,         Mem>,
+    std::tuple<    Row,     Col,     Row,      TF32,      TF32,         F32,       F32,        I128,        I128,         I64,        I32,        I32,        I16, Interwave,         Mem>
+>;
+
 // clang-format on
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_prec_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_prec_types.hpp
index c758737897..1490a23b0d 100644
--- a/test/ck_tile/gemm/test_gemm_pipeline_prec_types.hpp
+++ b/test/ck_tile/gemm/test_gemm_pipeline_prec_types.hpp
@@ -13,3 +13,5 @@ using BF16 = ck_tile::bf16_t;
 using BF8  = ck_tile::bf8_t;
 
 using I4 = ck_tile::pk_int4_t;
+
+using TF32 = ck_tile::tf32_t;
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_tf32_mem.cpp b/test/ck_tile/gemm/test_gemm_pipeline_tf32_mem.cpp
new file mode 100644
index 0000000000..966d80a156
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_tf32_mem.cpp
@@ -0,0 +1,22 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "test_gemm_pipeline_kernel_types.hpp"
+#include "test_gemm_pipeline_util.hpp"
+#include "gtest/gtest.h"
+
+// Typed-test fixture for the TF32 memory-bound GEMM pipeline; overrides
+// check_data_type() to opt in (the tf32 path is gated per-arch elsewhere).
+// NOTE(review): the template parameter list and the base-class template
+// argument were stripped during extraction (presumably
+// `template <typename Tuple>` / `TestCkTileGemmPipeline<Tuple>`) -- verify
+// against the upstream PR.
+template 
+class TestCkTileGemmPipelineTf32Mem
+    : public TestCkTileGemmPipeline>
+{
+    public:
+    static constexpr bool check_data_type() { return true; }
+};
+
+#define TEST_SUITE_NAME TestCkTileGemmPipelineTf32Mem
+
+TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesTf32Mem);
+
+#include "test_gemm_pipeline_ut_cases.inc"
+
+#undef TEST_SUITE_NAME
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
index 1dd9288a66..a4f06bed67 100644
--- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
@@ -135,6 +135,10 @@ class TestCkTileGemmPipeline : public ::testing::Test
     static constexpr bool Persistent =
         ck_tile::tuple_element_or_default_t::value;
 
+    // TF32 uses tf32_t as compute type but float as buffer/storage type
+    using ADataTypeBuf = ck_tile::if_select_t;
+    using BDataTypeBuf = ck_tile::if_select_t;
+
     protected:
     template 
     void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
@@ -183,12 +187,16 @@ class TestCkTileGemmPipeline : public ::testing::Test
                                                                      NumWaveGroup,
                                                                      preshuffle>;
 
-        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem;
+        using UniversalGemmProblem =
+            ck_tile::UniversalGemmPipelineProblem;
 
         using GemmPipeline =
             typename GemmPipelineTypeSelector::pipeline;
@@ -304,24 +312,23 @@ class TestCkTileGemmPipeline : public ::testing::Test
         ck_tile::index_t stride_C =
             ck_tile::get_default_stride(M, N, StrideC, is_row_major(CLayout{}));
 
-        ck_tile::HostTensor a_m_k(
+        ck_tile::HostTensor a_m_k(
             ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{})));
-        ck_tile::HostTensor b_k_n(
+        ck_tile::HostTensor b_k_n(
             ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{})));
         ck_tile::HostTensor c_m_n_dev_result(
             ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
 
-        ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11939}(a_m_k);
-        ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11940}(b_k_n);
+        ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11939}(a_m_k);
+        ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11940}(b_k_n);
 
         ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
         ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
         ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
 
-        if constexpr(std::is_same_v)
+        if constexpr(std::is_same_v)
         {
-            // Permute vector pk_i4x4 data for device implementation
-            ck_tile::HostTensor b_k_n_dev = b_k_n;
+            ck_tile::HostTensor b_k_n_dev = b_k_n;
             permute_vectors_i4x4_b(b_k_n_dev);
             b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
         }
diff --git a/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp b/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp
index 899221547f..a1b43460c1 100644
--- a/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp
+++ b/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp
@@ -20,6 +20,12 @@ struct DataTypeTraits
     static constexpr const char* name = "fp32";
 };
 
+template <>
+struct DataTypeTraits
+{
+    static constexpr const char* name = "tf32";
+};
+
 template <>
 struct DataTypeTraits
 {