/* ************************************************************************
 *
 * MIT License
 *
 * Copyright (C) 2024-2025 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * ************************************************************************ */
#pragma once
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

// HIPBLASLT_HIPVEC_ACCESS(x): version-dependent accessor applied to expression x.
//   - HIP_VERSION_MAJOR < 7 : expands to the `.data` member of x.
//   - otherwise             : expands to x itself.
// The macro is only defined when __HIP_PLATFORM_AMD__ is defined.
// NOTE(review): presumably x is a HIP vector-type value whose storage access
// changed in HIP 7 -- confirm against the HIP headers in use.
// The argument is parenthesized so that non-trivial expressions such as `*p`
// or `a + b` bind as a whole before `.data` (or any trailing operator) applies.
#if defined(__HIP_PLATFORM_AMD__)
#if HIP_VERSION_MAJOR < 7
#define HIPBLASLT_HIPVEC_ACCESS(x) (x).data
#else
#define HIPBLASLT_HIPVEC_ACCESS(x) (x)
#endif
#endif

// TRANSFORM_FUNC_NAME_HELPER pastes its arguments into a single identifier of the form
//   Transform_<DType>_<SType>_<RowMajA><RowMajB><RowMajC>_<ThreadsM>_<ThreadsN>_VW_<VectorWidth>
// e.g. (S, S, 1, 1, 1, 16, 16, 1) -> Transform_S_S_111_16_16_VW_1.
// Arguments are pasted verbatim (## suppresses macro expansion of its operands).
#define TRANSFORM_FUNC_NAME_HELPER(                                           \
    DType, SType, RowMajA, RowMajB, RowMajC, ThreadsM, ThreadsN, VectorWidth) \
    Transform_##DType##_##SType##_##RowMajA##RowMajB##RowMajC##_##ThreadsM##_##ThreadsN##_VW_##VectorWidth
// TRANSFORM_FUNC_NAME is the entry point: the extra level of indirection lets
// arguments that are themselves macros be expanded before the helper pastes them.
#define TRANSFORM_FUNC_NAME(DType, SType, RowMajA, RowMajB, RowMajC, ThreadsM, ThreadsN, VW) \
    TRANSFORM_FUNC_NAME_HELPER(DType, SType, RowMajA, RowMajB, RowMajC, ThreadsM, ThreadsN, VW)
// DTYPE(X) yields the identifier DType<X> (e.g. DTYPE(S) -> DTypeS), again via a
// helper so that a macro argument is expanded before pasting.
#define DTYPE_HELPER(DTypeStr) DType##DTypeStr
#define DTYPE(DTypeStr) DTYPE_HELPER(DTypeStr)
// STRINGIFY turns its argument into a string literal without expanding it;
// TO_STRING expands the argument first and then stringifies the result.
#define STRINGIFY(x) #x
#define TO_STRING(x) STRINGIFY(x)

// Element-type aliases. The suffix after "DType" is the token passed to DTYPE(),
// so DTYPE(S) names DTypeS, DTYPE(H) names DTypeH, and so on.
typedef float        DTypeS; // suffix S: float
typedef _Float16     DTypeH; // suffix H: _Float16
typedef hip_bfloat16 DTypeBF16; // suffix BF16: hip_bfloat16
typedef int32_t      DTypeI32; // suffix I32: int32_t
typedef int8_t       DTypeI8; // suffix I8: int8_t

extern "C" {
// Prototype of kernel Transform_S_S_111_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=1, RowMajB=1, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 1, 1, 1, 16, 16, 1)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_111_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=1, RowMajB=1, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 1, 1, 1, 16, 16, 4)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_110_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=1, RowMajB=1, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 1, 1, 0, 16, 16, 1)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_110_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=1, RowMajB=1, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 1, 1, 0, 16, 16, 4)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_101_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=1, RowMajB=0, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 1, 0, 1, 16, 16, 1)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_101_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=1, RowMajB=0, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 1, 0, 1, 16, 16, 4)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_100_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=1, RowMajB=0, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 1, 0, 0, 16, 16, 1)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_100_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=1, RowMajB=0, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 1, 0, 0, 16, 16, 4)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_011_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=0, RowMajB=1, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 0, 1, 1, 16, 16, 1)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_011_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=0, RowMajB=1, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 0, 1, 1, 16, 16, 4)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_010_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=0, RowMajB=1, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 0, 1, 0, 16, 16, 1)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_010_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=0, RowMajB=1, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 0, 1, 0, 16, 16, 4)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_001_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=0, RowMajB=0, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 0, 0, 1, 16, 16, 1)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_001_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=0, RowMajB=0, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 0, 0, 1, 16, 16, 4)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_000_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=0, RowMajB=0, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 0, 0, 0, 16, 16, 1)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_S_S_000_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=S, SType=S, RowMajA=0, RowMajB=0, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(S) expands to DTypeS.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(S, S, 0, 0, 0, 16, 16, 4)(DTYPE(S) * c,
                                                              const DTYPE(S) * a,
                                                              const DTYPE(S) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_H_H_111_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=H, SType=H, RowMajA=1, RowMajB=1, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(H) expands to DTypeH.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(H, H, 1, 1, 1, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_H_H_111_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=H, SType=H, RowMajA=1, RowMajB=1, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(H) expands to DTypeH.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(H, H, 1, 1, 1, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_H_H_110_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=H, SType=H, RowMajA=1, RowMajB=1, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(H) expands to DTypeH.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(H, H, 1, 1, 0, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_H_H_110_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=H, SType=H, RowMajA=1, RowMajB=1, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(H) expands to DTypeH.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(H, H, 1, 1, 0, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_H_H_101_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=H, SType=H, RowMajA=1, RowMajB=0, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(H) expands to DTypeH.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(H, H, 1, 0, 1, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_H_H_101_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=H, SType=H, RowMajA=1, RowMajB=0, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(H) expands to DTypeH.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(H, H, 1, 0, 1, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_H_H_100_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=H, SType=H, RowMajA=1, RowMajB=0, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(H) expands to DTypeH.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(H, H, 1, 0, 0, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_H_H_100_16_16_VW_4 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=H, SType=H, RowMajA=1, RowMajB=0, RowMajC=0,
// ThreadsM=16, ThreadsN=16, VectorWidth=4. DTYPE(H) expands to DTypeH.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(H, H, 1, 0, 0, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of kernel Transform_H_H_011_16_16_VW_1 (name built by TRANSFORM_FUNC_NAME).
// Macro arguments: DType=H, SType=H, RowMajA=0, RowMajB=1, RowMajC=1,
// ThreadsM=16, ThreadsN=16, VectorWidth=1. DTYPE(H) expands to DTypeH.
// alpha/beta are each passed both by value and through a pointer; which one the
// kernel reads is not visible from this declaration.
__global__ void TRANSFORM_FUNC_NAME(H, H, 0, 1, 1, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_H_011_16_16_VW_4
// (DType=H, SType=H, RowMajA/B/C=0/1/1, ThreadsM/N=16/16, VectorWidth=4).
// c is the only writable pointer; a and b are const. alpha/beta are DTYPE(H) and are supplied
// both by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, H, 0, 1, 1, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_H_010_16_16_VW_1
// (DType=H, SType=H, RowMajA/B/C=0/1/0, ThreadsM/N=16/16, VectorWidth=1).
// c is the only writable pointer; a and b are const. alpha/beta are DTYPE(H) and are supplied
// both by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, H, 0, 1, 0, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_H_010_16_16_VW_4
// (DType=H, SType=H, RowMajA/B/C=0/1/0, ThreadsM/N=16/16, VectorWidth=4).
// c is the only writable pointer; a and b are const. alpha/beta are DTYPE(H) and are supplied
// both by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, H, 0, 1, 0, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_H_001_16_16_VW_1
// (DType=H, SType=H, RowMajA/B/C=0/0/1, ThreadsM/N=16/16, VectorWidth=1).
// c is the only writable pointer; a and b are const. alpha/beta are DTYPE(H) and are supplied
// both by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, H, 0, 0, 1, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_H_001_16_16_VW_4
// (DType=H, SType=H, RowMajA/B/C=0/0/1, ThreadsM/N=16/16, VectorWidth=4).
// c is the only writable pointer; a and b are const. alpha/beta are DTYPE(H) and are supplied
// both by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, H, 0, 0, 1, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_H_000_16_16_VW_1
// (DType=H, SType=H, RowMajA/B/C=0/0/0, ThreadsM/N=16/16, VectorWidth=1).
// c is the only writable pointer; a and b are const. alpha/beta are DTYPE(H) and are supplied
// both by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, H, 0, 0, 0, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_H_000_16_16_VW_4
// (DType=H, SType=H, RowMajA/B/C=0/0/0, ThreadsM/N=16/16, VectorWidth=4).
// c is the only writable pointer; a and b are const. alpha/beta are DTYPE(H) and are supplied
// both by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, H, 0, 0, 0, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(H) alpha,
                                                              const DTYPE(H) * alphaPtr,
                                                              DTYPE(H) beta,
                                                              const DTYPE(H) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_111_16_16_VW_1
// (DType=H, SType=S, RowMajA/B/C=1/1/1, ThreadsM/N=16/16, VectorWidth=1).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 1, 1, 1, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_111_16_16_VW_4
// (DType=H, SType=S, RowMajA/B/C=1/1/1, ThreadsM/N=16/16, VectorWidth=4).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 1, 1, 1, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_110_16_16_VW_1
// (DType=H, SType=S, RowMajA/B/C=1/1/0, ThreadsM/N=16/16, VectorWidth=1).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 1, 1, 0, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_110_16_16_VW_4
// (DType=H, SType=S, RowMajA/B/C=1/1/0, ThreadsM/N=16/16, VectorWidth=4).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 1, 1, 0, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_101_16_16_VW_1
// (DType=H, SType=S, RowMajA/B/C=1/0/1, ThreadsM/N=16/16, VectorWidth=1).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 1, 0, 1, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_101_16_16_VW_4
// (DType=H, SType=S, RowMajA/B/C=1/0/1, ThreadsM/N=16/16, VectorWidth=4).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 1, 0, 1, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_100_16_16_VW_1
// (DType=H, SType=S, RowMajA/B/C=1/0/0, ThreadsM/N=16/16, VectorWidth=1).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 1, 0, 0, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_100_16_16_VW_4
// (DType=H, SType=S, RowMajA/B/C=1/0/0, ThreadsM/N=16/16, VectorWidth=4).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 1, 0, 0, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_011_16_16_VW_1
// (DType=H, SType=S, RowMajA/B/C=0/1/1, ThreadsM/N=16/16, VectorWidth=1).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 0, 1, 1, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_011_16_16_VW_4
// (DType=H, SType=S, RowMajA/B/C=0/1/1, ThreadsM/N=16/16, VectorWidth=4).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 0, 1, 1, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_010_16_16_VW_1
// (DType=H, SType=S, RowMajA/B/C=0/1/0, ThreadsM/N=16/16, VectorWidth=1).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 0, 1, 0, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_010_16_16_VW_4
// (DType=H, SType=S, RowMajA/B/C=0/1/0, ThreadsM/N=16/16, VectorWidth=4).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 0, 1, 0, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_001_16_16_VW_1
// (DType=H, SType=S, RowMajA/B/C=0/0/1, ThreadsM/N=16/16, VectorWidth=1).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 0, 0, 1, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_001_16_16_VW_4
// (DType=H, SType=S, RowMajA/B/C=0/0/1, ThreadsM/N=16/16, VectorWidth=4).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 0, 0, 1, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_000_16_16_VW_1
// (DType=H, SType=S, RowMajA/B/C=0/0/0, ThreadsM/N=16/16, VectorWidth=1).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 0, 0, 0, 16, 16, 1)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_H_S_000_16_16_VW_4
// (DType=H, SType=S, RowMajA/B/C=0/0/0, ThreadsM/N=16/16, VectorWidth=4).
// c (writable) and const a/b point to DTYPE(H); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(H, S, 0, 0, 0, 16, 16, 4)(DTYPE(H) * c,
                                                              const DTYPE(H) * a,
                                                              const DTYPE(H) * b,
                                                              DTYPE(S) alpha,
                                                              const DTYPE(S) * alphaPtr,
                                                              DTYPE(S) beta,
                                                              const DTYPE(S) * betaPtr,
                                                              uint32_t numRows,
                                                              uint32_t numCols,
                                                              uint32_t ldA,
                                                              uint32_t ldB,
                                                              uint32_t ldC,
                                                              uint32_t batchStride,
                                                              bool     transA,
                                                              bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_BF16_S_111_16_16_VW_1
// (DType=BF16, SType=S, RowMajA/B/C=1/1/1, ThreadsM/N=16/16, VectorWidth=1).
// c (writable) and const a/b point to DTYPE(BF16); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 1, 1, 1, 16, 16, 1)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_BF16_S_111_16_16_VW_4
// (DType=BF16, SType=S, RowMajA/B/C=1/1/1, ThreadsM/N=16/16, VectorWidth=4).
// c (writable) and const a/b point to DTYPE(BF16); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 1, 1, 1, 16, 16, 4)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_BF16_S_110_16_16_VW_1
// (DType=BF16, SType=S, RowMajA/B/C=1/1/0, ThreadsM/N=16/16, VectorWidth=1).
// c (writable) and const a/b point to DTYPE(BF16); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 1, 1, 0, 16, 16, 1)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// Prototype of the HIP kernel that TRANSFORM_FUNC_NAME names Transform_BF16_S_110_16_16_VW_4
// (DType=BF16, SType=S, RowMajA/B/C=1/1/0, ThreadsM/N=16/16, VectorWidth=4).
// c (writable) and const a/b point to DTYPE(BF16); alpha/beta are DTYPE(S) and are supplied both
// by value and via alphaPtr/betaPtr. This is a declaration only; no body is given here.
// NOTE(review): RowMaj* presumably flag row-major storage of A/B/C - confirm at the definition.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 1, 1, 0, 16, 16, 4)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=1, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_BF16_S_101_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 1, 0, 1, 16, 16, 1)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=1, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_BF16_S_101_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 1, 0, 1, 16, 16, 4)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=1, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_BF16_S_100_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 1, 0, 0, 16, 16, 1)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=1, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_BF16_S_100_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 1, 0, 0, 16, 16, 4)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=0, RowMajB=1, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_BF16_S_011_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 0, 1, 1, 16, 16, 1)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=0, RowMajB=1, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_BF16_S_011_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 0, 1, 1, 16, 16, 4)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=0, RowMajB=1, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_BF16_S_010_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 0, 1, 0, 16, 16, 1)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=0, RowMajB=1, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_BF16_S_010_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 0, 1, 0, 16, 16, 4)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=0, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_BF16_S_001_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 0, 0, 1, 16, 16, 1)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=0, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_BF16_S_001_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 0, 0, 1, 16, 16, 4)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=0, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_BF16_S_000_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 0, 0, 0, 16, 16, 1)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=BF16, SType=S,
// RowMajA=0, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_BF16_S_000_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(BF16, S, 0, 0, 0, 16, 16, 4)(DTYPE(BF16) * c,
                                                                 const DTYPE(BF16) * a,
                                                                 const DTYPE(BF16) * b,
                                                                 DTYPE(S) alpha,
                                                                 const DTYPE(S) * alphaPtr,
                                                                 DTYPE(S) beta,
                                                                 const DTYPE(S) * betaPtr,
                                                                 uint32_t numRows,
                                                                 uint32_t numCols,
                                                                 uint32_t ldA,
                                                                 uint32_t ldB,
                                                                 uint32_t ldC,
                                                                 uint32_t batchStride,
                                                                 bool     transA,
                                                                 bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=1, RowMajB=1, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_I8_S_111_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 1, 1, 1, 16, 16, 1)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=1, RowMajB=1, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_I8_S_111_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 1, 1, 1, 16, 16, 4)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=1, RowMajB=1, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_I8_S_110_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 1, 1, 0, 16, 16, 1)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=1, RowMajB=1, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_I8_S_110_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 1, 1, 0, 16, 16, 4)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=1, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_I8_S_101_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 1, 0, 1, 16, 16, 1)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=1, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_I8_S_101_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 1, 0, 1, 16, 16, 4)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=1, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_I8_S_100_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 1, 0, 0, 16, 16, 1)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=1, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_I8_S_100_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 1, 0, 0, 16, 16, 4)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=0, RowMajB=1, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_I8_S_011_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 0, 1, 1, 16, 16, 1)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=0, RowMajB=1, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_I8_S_011_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 0, 1, 1, 16, 16, 4)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=0, RowMajB=1, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_I8_S_010_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 0, 1, 0, 16, 16, 1)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=0, RowMajB=1, RowMajC=0, ThreadsM=16, ThreadsN=16, VectorWidth=4; token pasting
// yields Transform_I8_S_010_16_16_VW_4, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 0, 1, 0, 16, 16, 4)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
// HIP kernel declaration (no body here). TRANSFORM_FUNC_NAME arguments: DType=I8, SType=S,
// RowMajA=0, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VectorWidth=1; token pasting
// yields Transform_I8_S_001_16_16_VW_1, provided none of the arguments is itself a macro.
// `c` is the only pointer to non-const data; `a`, `b`, `alphaPtr` and `betaPtr` point to const.
// The scalars are passed both by value (alpha, beta) and by pointer (alphaPtr, betaPtr).
// NOTE(review): how the by-value and by-pointer scalars are selected is not visible here.
__global__ void TRANSFORM_FUNC_NAME(I8, S, 0, 0, 1, 16, 16, 1)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I8, S, 0, 0, 1, 16, 16, 4), i.e. DType=I8, SType=S,
 * RowMajA=0, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VW=4.
 * Token pasting yields Transform_I8_S_001_16_16_VW_4 (assuming I8 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I8, S, 0, 0, 1, 16, 16, 4)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I8, S, 0, 0, 0, 16, 16, 1), i.e. DType=I8, SType=S,
 * RowMajA=0, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VW=1.
 * Token pasting yields Transform_I8_S_000_16_16_VW_1 (assuming I8 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I8, S, 0, 0, 0, 16, 16, 1)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I8, S, 0, 0, 0, 16, 16, 4), i.e. DType=I8, SType=S,
 * RowMajA=0, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VW=4.
 * Token pasting yields Transform_I8_S_000_16_16_VW_4 (assuming I8 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I8, S, 0, 0, 0, 16, 16, 4)(DTYPE(I8) * c,
                                                               const DTYPE(I8) * a,
                                                               const DTYPE(I8) * b,
                                                               DTYPE(S) alpha,
                                                               const DTYPE(S) * alphaPtr,
                                                               DTYPE(S) beta,
                                                               const DTYPE(S) * betaPtr,
                                                               uint32_t numRows,
                                                               uint32_t numCols,
                                                               uint32_t ldA,
                                                               uint32_t ldB,
                                                               uint32_t ldC,
                                                               uint32_t batchStride,
                                                               bool     transA,
                                                               bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 1, 1, 1, 16, 16, 1), i.e. DType=I32, SType=S,
 * RowMajA=1, RowMajB=1, RowMajC=1, ThreadsM=16, ThreadsN=16, VW=1.
 * Token pasting yields Transform_I32_S_111_16_16_VW_1 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 1, 1, 1, 16, 16, 1)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 1, 1, 1, 16, 16, 4), i.e. DType=I32, SType=S,
 * RowMajA=1, RowMajB=1, RowMajC=1, ThreadsM=16, ThreadsN=16, VW=4.
 * Token pasting yields Transform_I32_S_111_16_16_VW_4 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 1, 1, 1, 16, 16, 4)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 1, 1, 0, 16, 16, 1), i.e. DType=I32, SType=S,
 * RowMajA=1, RowMajB=1, RowMajC=0, ThreadsM=16, ThreadsN=16, VW=1.
 * Token pasting yields Transform_I32_S_110_16_16_VW_1 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 1, 1, 0, 16, 16, 1)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 1, 1, 0, 16, 16, 4), i.e. DType=I32, SType=S,
 * RowMajA=1, RowMajB=1, RowMajC=0, ThreadsM=16, ThreadsN=16, VW=4.
 * Token pasting yields Transform_I32_S_110_16_16_VW_4 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 1, 1, 0, 16, 16, 4)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 1, 0, 1, 16, 16, 1), i.e. DType=I32, SType=S,
 * RowMajA=1, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VW=1.
 * Token pasting yields Transform_I32_S_101_16_16_VW_1 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 1, 0, 1, 16, 16, 1)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 1, 0, 1, 16, 16, 4), i.e. DType=I32, SType=S,
 * RowMajA=1, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VW=4.
 * Token pasting yields Transform_I32_S_101_16_16_VW_4 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 1, 0, 1, 16, 16, 4)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 1, 0, 0, 16, 16, 1), i.e. DType=I32, SType=S,
 * RowMajA=1, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VW=1.
 * Token pasting yields Transform_I32_S_100_16_16_VW_1 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 1, 0, 0, 16, 16, 1)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 1, 0, 0, 16, 16, 4), i.e. DType=I32, SType=S,
 * RowMajA=1, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VW=4.
 * Token pasting yields Transform_I32_S_100_16_16_VW_4 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 1, 0, 0, 16, 16, 4)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 0, 1, 1, 16, 16, 1), i.e. DType=I32, SType=S,
 * RowMajA=0, RowMajB=1, RowMajC=1, ThreadsM=16, ThreadsN=16, VW=1.
 * Token pasting yields Transform_I32_S_011_16_16_VW_1 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 0, 1, 1, 16, 16, 1)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 0, 1, 1, 16, 16, 4), i.e. DType=I32, SType=S,
 * RowMajA=0, RowMajB=1, RowMajC=1, ThreadsM=16, ThreadsN=16, VW=4.
 * Token pasting yields Transform_I32_S_011_16_16_VW_4 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 0, 1, 1, 16, 16, 4)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 0, 1, 0, 16, 16, 1), i.e. DType=I32, SType=S,
 * RowMajA=0, RowMajB=1, RowMajC=0, ThreadsM=16, ThreadsN=16, VW=1.
 * Token pasting yields Transform_I32_S_010_16_16_VW_1 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 0, 1, 0, 16, 16, 1)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 0, 1, 0, 16, 16, 4), i.e. DType=I32, SType=S,
 * RowMajA=0, RowMajB=1, RowMajC=0, ThreadsM=16, ThreadsN=16, VW=4.
 * Token pasting yields Transform_I32_S_010_16_16_VW_4 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 0, 1, 0, 16, 16, 4)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 0, 0, 1, 16, 16, 1), i.e. DType=I32, SType=S,
 * RowMajA=0, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VW=1.
 * Token pasting yields Transform_I32_S_001_16_16_VW_1 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 0, 0, 1, 16, 16, 1)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 0, 0, 1, 16, 16, 4), i.e. DType=I32, SType=S,
 * RowMajA=0, RowMajB=0, RowMajC=1, ThreadsM=16, ThreadsN=16, VW=4.
 * Token pasting yields Transform_I32_S_001_16_16_VW_4 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 0, 0, 1, 16, 16, 4)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 0, 0, 0, 16, 16, 1), i.e. DType=I32, SType=S,
 * RowMajA=0, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VW=1.
 * Token pasting yields Transform_I32_S_000_16_16_VW_1 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 0, 0, 0, 16, 16, 1)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
/*
 * Forward declaration of the HIP kernel named by
 * TRANSFORM_FUNC_NAME(I32, S, 0, 0, 0, 16, 16, 4), i.e. DType=I32, SType=S,
 * RowMajA=0, RowMajB=0, RowMajC=0, ThreadsM=16, ThreadsN=16, VW=4.
 * Token pasting yields Transform_I32_S_000_16_16_VW_4 (assuming I32 and S are
 * not themselves macros -- TODO confirm).
 *
 * c is the only pointer to non-const data (presumably the output -- confirm
 * in the kernel definition); a and b point to const data. alpha and beta are
 * each supplied both by value and via alphaPtr/betaPtr; which form the kernel
 * reads is not visible from this declaration.
 */
__global__ void TRANSFORM_FUNC_NAME(I32, S, 0, 0, 0, 16, 16, 4)(DTYPE(I32) * c,
                                                                const DTYPE(I32) * a,
                                                                const DTYPE(I32) * b,
                                                                DTYPE(S) alpha,
                                                                const DTYPE(S) * alphaPtr,
                                                                DTYPE(S) beta,
                                                                const DTYPE(S) * betaPtr,
                                                                uint32_t numRows,
                                                                uint32_t numCols,
                                                                uint32_t ldA,
                                                                uint32_t ldB,
                                                                uint32_t ldC,
                                                                uint32_t batchStride,
                                                                bool     transA,
                                                                bool     transB);
}
