
[QST] Epilogue Broadcast: Adapter vs GemmUniversal #1459

Open
jeromeku opened this issue Apr 7, 2024 · 4 comments
Labels
question Question

Comments


jeromeku commented Apr 7, 2024

What is your question?
I am trying to understand the behavior of a GEMM with a column-broadcast bias-vector epilogue.
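Concretely, with ElementWiseOp = Identity, BinaryOp = plus, and beta = 0 (my configuration below), the result I expect is, in my own pseudocode notation:

// Expected column-wise broadcast (pseudocode, my notation): the same bias
// value v[m], from a length-M vector v, is added across every column of row m.
// for each (m, n):
//   D[m][n] = (A @ B)[m][n] + v[m];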

When defining a device GemmUniversalWithBroadcast with the following config:

using DType = cutlass::half_t;
using ElementWiseOp = cutlass::epilogue::thread::Identity<DType>;
using BinaryOp = cutlass::plus<DType>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;

constexpr int stages = 3;
using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;

using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
    DType,
    DType,
    DType,
    DType,
    DType,
    8,
    ElementWiseOp,
    BinaryOp>;

using GemmUniversal = cutlass::gemm::device::GemmUniversalWithBroadcast<
    DType,
    LayoutA,
    DType, LayoutB,
    DType, LayoutC,
    DType,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    EpilogueOutputOp,
    ThreadBlockSwizzle,
    stages>;

I get a core dump whenever I try to run the above with M != N. Running with M == N, I get the correct GEMM result, but the bias is broadcast incorrectly (row-wise instead of column-wise).

When I run the above using GemmUniversalAdapter as the device handle, the op runs for all M and N, and the epilogue op is performed correctly. However, the A and B inputs are transposed, because of an internal transpose the adapter performs.

Questions

  • How do I properly instantiate / use GemmUniversalWithBroadcast?
  • Why does the GemmUniversalAdapter transpose layouts internally?
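For completeness, one sanity check I plan to add to the repro is gating the launch on can_implement (the standard device-level validity check), in case the segfault reflects an unsupported configuration that initialize does not reject:

// Sketch: inside test(), before gemm_op.initialize(arguments, ...).
// can_implement() reports whether `arguments` is a supported configuration.
cutlass::Status can = Gemm::can_implement(arguments);
CUTLASS_CHECK(can);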

Repro

Here is a simple script that reproduces the above:

  • GemmUniversalWithBroadcast fails to run with M != N.
  • GemmUniversalWithBroadcast runs with M == N, but the epilogue is incorrect.
  • GemmUniversalAdapter runs, but with operands A and B transposed.
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/functional.h"

#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

#define CUTLASS_CHECK(status)                                                                    \
  {                                                                                              \
    cutlass::Status error = status;                                                              \
    if (error != cutlass::Status::kSuccess)                                                      \
    {                                                                                            \
      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
                << std::endl;                                                                    \
      exit(EXIT_FAILURE);                                                                        \
    }                                                                                            \
  }

using DType = cutlass::half_t;
using ElementWiseOp = cutlass::epilogue::thread::Identity<DType>;
using BinaryOp = cutlass::plus<DType>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;

constexpr int stages = 3;
using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;

using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
    DType,
    DType,
    DType,
    DType,
    DType,
    8,
    ElementWiseOp,
    BinaryOp>;

using GemmKernel =
    typename cutlass::gemm::kernel::DefaultGemmWithBroadcast<
        DType, LayoutA, cutlass::ComplexTransform::kNone, 8, // A operand
        DType, LayoutB, cutlass::ComplexTransform::kNone, 8, // B operand
        DType, LayoutC,
        DType,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,
        cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>,
        EpilogueOutputOp,
        ThreadBlockSwizzle,
        stages,
        cutlass::arch::OpMultiplyAdd>::GemmKernel;

using GemmUniversal = cutlass::gemm::device::GemmUniversalWithBroadcast<
    DType,
    LayoutA,
    DType, LayoutB,
    DType, LayoutC,
    DType,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    EpilogueOutputOp,
    ThreadBlockSwizzle,
    stages>;

using GemmAdapter = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

template <typename Gemm>
void test(int M = 8, int N = 4, int K = 8, bool verbose = true, int batch_count = 1,
          cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
          DType alpha = DType(1.0), DType beta = DType(0.0))
{
  cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord(M, N, K);
  cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;
  cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;
  cutlass::HostTensor<EpilogueOutputOp::ElementZ, typename Gemm::LayoutC> tensor_Z;
  cutlass::HostTensor<EpilogueOutputOp::ElementVector, typename Gemm::LayoutC> tensor_Broadcast;

  tensor_A.resize({problem_size.m(), problem_size.k()});
  tensor_B.resize({problem_size.k(), problem_size.n()});
  tensor_Z.resize({problem_size.m(), problem_size.n()});
  tensor_Broadcast.resize({problem_size.m(), 1});
  cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), tensor_A.capacity());
  cutlass::reference::host::BlockFill(tensor_B.host_data(), tensor_B.capacity(), typename Gemm::ElementB(1.0));
  // cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), tensor_B.capacity());
  cutlass::reference::host::BlockFill(tensor_Z.host_data(), tensor_Z.capacity(), EpilogueOutputOp::ElementZ(0.0));
  cutlass::reference::host::BlockFillSequential(tensor_Broadcast.host_data(), tensor_Broadcast.capacity());

  tensor_A.sync_device();
  tensor_B.sync_device();
  tensor_Z.sync_device();
  tensor_Broadcast.sync_device();
  if (verbose)
  {
    std::cout << "tensor_A:\n"
              << tensor_A.host_view() << std::endl;
    std::cout << "tensor_B:\n"
              << tensor_B.host_view() << std::endl;
    std::cout << "tensor_Broadcast:\n"
              << tensor_Broadcast.host_view() << std::endl;
  }

  typename Gemm::Arguments arguments{
      mode,
      problem_size,
      batch_count,
      {alpha, beta},
      tensor_A.device_data(),
      tensor_B.device_data(),
      nullptr, // C
      tensor_Z.device_data(),
      tensor_Broadcast.device_data(),
      nullptr,                             // T
      problem_size.m() * problem_size.k(), // batch stride A
      problem_size.n() * problem_size.k(), // batch stride B
      problem_size.m() * problem_size.n(), // batch stride C
      problem_size.m() * problem_size.n(), // batch stride Z
      problem_size.m(),                    // batch stride broadcast
      problem_size.m() * problem_size.n(), // batch stride T
      tensor_A.layout().stride(0),         // stride A
      tensor_B.layout().stride(0),         // stride B
      tensor_Z.layout().stride(0),         // stride C
      tensor_Z.layout().stride(0),         // stride Z
      0,                                   // This must be zero for broadcast
      tensor_Z.layout().stride(0),         // stride T
  };

  Gemm gemm_op;

  size_t workspace_size = Gemm::get_workspace_size(arguments);

  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cutlass::Status status = gemm_op.initialize(arguments, workspace.get());

  CUTLASS_CHECK(status);

  status = gemm_op();

  CUTLASS_CHECK(status);
  tensor_Z.sync_host();
  std::cout << "tensor_Z:\n"
            << tensor_Z.host_view() << std::endl;
}
int main()
{
  int M = 8;
  int N = 8;
  int K = 8;

  // NOTE: Running with `GemmUniversalWithBroadcast` will segfault if M != N
  std::cout << "GemmUniversalWithBroadcast" << std::endl;
  test<GemmUniversal>(M, N, K);
  std::cout << " ----------------------- " << std::endl;

  std::cout << "GemmAdapterBroadcast" << std::endl;
  test<GemmAdapter>(M, N, K);
}
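
For reference, here is the host-side check I compare the script's output against (a plain row-major sketch of the expected D = A @ B + broadcast(bias); the function is my own helper, not a CUTLASS API, and assumes batch_count == 1):

// Host reference for the expected result. Row-major only: A is M x K,
// B is K x N, bias holds M elements (one per row of D).
void reference_gemm_bias(int M, int N, int K,
                         DType const *A, DType const *B,
                         DType const *bias, DType *D)
{
  for (int m = 0; m < M; ++m)
  {
    for (int n = 0; n < N; ++n)
    {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k)
      {
        acc += float(A[m * K + k]) * float(B[k * N + n]);
      }
      // Column-wise broadcast: the same bias value is added across row m.
      D[m * N + n] = DType(acc + float(bias[m]));
    }
  }
}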

jeromeku commented Apr 7, 2024

As a follow-up, I am trying to implement the above using an epilogue visitor tree (EVT).

I am encountering two problems:

  • Constructing the EVT args: I get a no instance of constructor "cute::tuple<T...>" matches the argument list compiler error when constructing the arguments for a DefaultGemmWithVisitor kernel with a simple EVT that does a column-wise bias broadcast. I've followed the nested structure of the EVT, which consists of a store node and an EVTCompute node containing the bias, accumulator, and compute nodes, but clearly there is still a mistake.
  • Using GemmUniversalAdapter to construct the arguments with a DefaultGemmWithVisitor kernel that uses a threadblock swizzle other than stream-K (as opposed to using the GEMM kernel directly, per above), I get
cutlass/include/cutlass/gemm/kernel/gemm_universal.h(78): here is inaccessible
          detected during instantiation of class "cutlass::gemm::device::GemmUniversalAdapter<GemmKernel_, std::enable_if_t<<expression>, void>> [with GemmKernel_=EVTKernel]" 

I tried to tweak the stream-K-with-broadcast example with the above changes (a simpler EVT and GemmIdentityThreadblockSwizzle).

Below is the full script (it stops after constructing the EVT arguments, which is where the compile error occurs):

#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/functional.h"

#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"

#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

#include "cute/tensor.hpp"
using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////

#define CUTLASS_CHECK(status)                                                                          \
    {                                                                                                  \
        cutlass::Status error = status;                                                                \
        if (error != cutlass::Status::kSuccess)                                                        \
        {                                                                                              \
            std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
                      << std::endl;                                                                    \
            exit(EXIT_FAILURE);                                                                        \
        }                                                                                              \
    }

using DType = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;

constexpr int stages = 3;
using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;

// EVT
constexpr int Alignment = 128 / cutlass::sizeof_bits_v<DType>;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;          // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;    // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 4;                                     // Number of global->shared pipeline stages used in the GEMM mainloop
constexpr int EVTEpilogueStages = 1;                             // Number of epilogue stages in EVT

using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
    ThreadblockShape,
    WarpShape,
    DType,
    Alignment,
    EVTEpilogueStages>;

using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, DType, DType,
    cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
    Compute0,
    Scale,
    Accum>;

using D = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
    cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;

using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
    D,
    EVTCompute0>;

using EVTKernel =
    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        DType, LayoutA, cutlass::ComplexTransform::kNone, Alignment,
        DType, LayoutB, cutlass::ComplexTransform::kNone, Alignment,
        DType, LayoutC, Alignment,
        DType,
        DType,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EVTD,
        ThreadBlockSwizzle,
        NumStages,
        cutlass::arch::OpMultiplyAdd,
        EVTEpilogueStages>::GemmKernel;

using DeviceGemmEVT = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;

template <typename Gemm>
void test(int M = 8, int N = 4, int K = 8, bool verbose = true, int batch_count = 1,
          cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
          DType alpha = DType(1.0), DType beta = DType(0.0))
{
    cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord(M, N, K);
    cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;
    cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;
    cutlass::HostTensor<DType, typename Gemm::LayoutC> tensor_Z;
    cutlass::HostTensor<DType, typename Gemm::LayoutC> tensor_Broadcast;

    tensor_A.resize({problem_size.m(), problem_size.k()});
    tensor_B.resize({problem_size.k(), problem_size.n()});
    tensor_Z.resize({problem_size.m(), problem_size.n()});
    tensor_Broadcast.resize({problem_size.m(), 1});
    cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), tensor_A.capacity());
    cutlass::reference::host::BlockFill(tensor_B.host_data(), tensor_B.capacity(), DType(1.0));

    cutlass::reference::host::BlockFill(tensor_Z.host_data(), tensor_Z.capacity(), DType(0.0));
    cutlass::reference::host::BlockFillSequential(tensor_Broadcast.host_data(), tensor_Broadcast.capacity());

    tensor_A.sync_device();
    tensor_B.sync_device();
    tensor_Z.sync_device();
    tensor_Broadcast.sync_device();
    if (verbose)
    {
        std::cout << "tensor_A:\n"
                  << tensor_A.host_view() << std::endl;
        std::cout << "tensor_B:\n"
                  << tensor_B.host_view() << std::endl;
        std::cout << "tensor_Broadcast:\n"
                  << tensor_Broadcast.host_view() << std::endl;
    }
    typename EVTD::Arguments callback_args{
        {
            {},                                                                                  // Compute0
            {tensor_Broadcast.device_data(), DType(0), {_0{}, _1{}, int32_t(problem_size.m())}}, // bias / scale
            {}                                                                                   // Accum
        },                                                                                       // EvtCompute0
        {tensor_Z.device_data(), {problem_size.n(), _1{}, problem_size.mn().product()}},         // D
    };
    typename EVTKernel::Arguments evtArgs{
        cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
        problem_size,                            // problem_size
        1,                                       // batch count / splitk slices
        callback_args,                           // argument of EVT callbacks
        tensor_A.device_data(),                  // ptr_A
        tensor_B.device_data(),                  // ptr_B
        nullptr,                                 // ptr_C (unused)
        nullptr,                                 // ptr_D (unused)
        problem_size.mk().product(),             // batch_stride_A
        problem_size.nk().product(),             // batch_stride_B
        0,                                       // batch_stride_C (unused)
        0,                                       // batch_stride_D (unused)
        tensor_A.layout().stride(0),             // stride_a
        tensor_B.layout().stride(0),             // stride_b
        0,                                       // stride_c (unused)
        0                                        // stride_d (unused)
    };
}
int main()
{
    int M = 8;
    int N = 8;
    int K = 8;

    std::cout << "GemmEVT" << std::endl;
    test<EVTKernel>(M, N, K);
}

@hgyhungry

Not an expert, but I recently ran into exactly the same problem when crafting my own custom epilogue visitor tree. Here is what I think.

For your first problem (the cute tuple type error), change

using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

To

using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, int32_t>>; // <-- the only change is the type of the last parameter, which should be a runtime integer such as int32_t or int64_t

Do the same with the type D: its stride components that receive runtime values must also be runtime integer types. See the sketch below.
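
For reference, here is my guess at a fully consistent combination, also mirroring example 47's argument nesting (child nodes first, the compute node last). The int32_t/int64_t choices just match the runtime values being passed; treat this as a sketch, not a verified fix:

using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, int32_t>>; // dynamic batch stride

using D = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
    cute::Stride<int64_t, cute::Int<1>, int64_t>>;      // dynamic row / batch strides

// Callback arguments matching Sm80EVT<Compute0, Scale, Accum>: the child
// arguments (Scale, Accum) come first, the node op (Compute0) last.
typename EVTD::Arguments callback_args{
    {
        {tensor_Broadcast.device_data(), DType(0),
         {_1{}, _0{}, int32_t(problem_size.m())}},       // Scale
        {},                                              // Accum
        {}                                               // Compute0
    },                                                   // EVTCompute0
    {tensor_Z.device_data(),
     {int64_t(problem_size.n()), _1{},
      int64_t(problem_size.mn().product())}},            // D
};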

@hgyhungry

For your second problem: the reason the example works well with ThreadblockSwizzleStreamK but doesn't work with GemmIdentityThreadblockSwizzle<> is that, underneath the DefaultGemmWithVisitor class, this difference in the swizzle parameter dispatches to different GemmKernel types, as described here:

class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature> :

To be more specific, example 47 uses GemmWithEpilogueVisitorStreamk (link) as EVTKernelStreamK, whereas your example will use GemmWithEpilogueVisitor (link) as EVTKernel.
With that clear, it appears GemmWithEpilogueVisitor contains some issues, while GemmWithEpilogueVisitorStreamK works just fine. I had my own fix, but I suggest you post your whole error message and let a maintainer take a look.
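
In the meantime, a way to stay on the working path is to keep the stream-K swizzle from example 47, so that DefaultGemmWithVisitor dispatches to GemmWithEpilogueVisitorStreamk (a drop-in replacement for the ThreadBlockSwizzle alias in your repro):

#include "cutlass/gemm/threadblock/threadblock_swizzle_streamk.h"

// Selecting the stream-K dispatch path exercised by example 47.
using ThreadBlockSwizzle = cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;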

@thakkarV (Collaborator)

@hwu36
