-
Notifications
You must be signed in to change notification settings - Fork 815
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[QST] Epilogue Broadcast: Adapter vs GemmUniversal
#1459
Comments
As a follow-up, I am trying to implement the above using epilogue visitor trees (EVT) and am encountering two problems:
I tried to adapt the Stream-K with broadcast example with the above changes (a simpler EVT and …). Below is the full script:
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "cute/tensor.hpp"
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// CUTLASS_CHECK(status): evaluates `status` (a cutlass::Status expression) exactly
// once; if it is not cutlass::Status::kSuccess, prints the status string and the
// source line to std::cerr and terminates the process with EXIT_FAILURE.
//
// The body is wrapped in do { ... } while (0) so that `CUTLASS_CHECK(x);` is a single
// statement and composes with unbraced if/else. The temporary has a deliberately
// unusual name so that a caller argument spelled e.g. `error` cannot end up
// initializing the temporary from itself.
#define CUTLASS_CHECK(status)                                                      \
  do {                                                                             \
    cutlass::Status cutlass_check_status_ = (status);                              \
    if (cutlass_check_status_ != cutlass::Status::kSuccess) {                      \
      std::cerr << "Got cutlass error: "                                           \
                << cutlassGetStatusString(cutlass_check_status_) << " at: "        \
                << __LINE__ << std::endl;                                          \
      exit(EXIT_FAILURE);                                                          \
    }                                                                              \
  } while (0)
// ---------------------------------------------------------------------------------------
// Compile-time configuration of the EVT GEMM: fp16 (cutlass::half_t) operands, row-major
// A/B/C, an SM80 tensor-core kernel built by DefaultGemmWithVisitor, whose epilogue is the
// visitor tree  D = Scale * Accum  (Compute0 = cutlass::multiplies over children
// Scale and Accum, stored by the aux-store visitor D).
// ---------------------------------------------------------------------------------------
using DType = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
// NOTE(review): `stages` is not referenced by any visible code; the kernel below is
// configured with NumStages instead.
constexpr int stages = 3;
// Identity (non-Stream-K) threadblock swizzle.
using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;
// EVT
// Elements per 128-bit vectorized access (8 for a 16-bit half_t); used for A, B and C.
constexpr int Alignment = 128 / cutlass::sizeof_bits_v<DType>;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop
constexpr int EVTEpilogueStages = 1; // Number of epilogue stages in EVT
// Thread-to-output-tile mapping shared by every visitor that touches the output tile.
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
DType,
Alignment,
EVTEpilogueStages>;
// Leaf: the mainloop accumulator fragment.
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
// Leaf: column-broadcast vector. Stride modes are presumably (M, N, L): unit stride along
// M, 0 along N, 0 along the batch mode. All three modes are static (cute::Int), so the
// runtime Arguments for this node must use exactly {_1{}, _0{}, _0{}}; a runtime integer
// or a different static value in any mode will not convert.
using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, cutlass::half_t,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
// Node op: elementwise multiply of its children's outputs, rounded to nearest.
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, DType, DType,
cutlass::FloatRoundStyle::round_to_nearest>;
// Subtree: Compute0(Scale, Accum), i.e. Scale * Accum.
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
Compute0,
Scale,
Accum>;
// Root store. Stride modes presumably (M, N, L): runtime leading dimension along M, unit
// stride along N (row-major output), static zero batch stride. Because the batch stride
// is cute::Int<0>, the Arguments must pass _0{} there, and only a single batch is
// addressable.
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;
// Full epilogue tree: store D(EVTCompute0).
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
D,
EVTCompute0>;
// GEMM kernel with the EVT epilogue. The two consecutive DType arguments after
// (DType, LayoutC, Alignment) are presumably the accumulator and epilogue element types,
// i.e. accumulation happens in half precision (TODO confirm against DefaultGemmWithVisitor).
using EVTKernel =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
DType, LayoutA, cutlass::ComplexTransform::kNone, Alignment,
DType, LayoutB, cutlass::ComplexTransform::kNone, Alignment,
DType, LayoutC, Alignment,
DType,
DType,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTD,
ThreadBlockSwizzle,
NumStages,
cutlass::arch::OpMultiplyAdd,
EVTEpilogueStages>::GemmKernel;
// Device-level handle wrapping EVTKernel.
using DeviceGemmEVT = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
template <typename Gemm>
void test(int M = 8, int N = 4, int K = 8, bool verbose = true, int batch_count = 1,
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
DType alpha = DType(1.0), DType beta = DType(0.0))
{
cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord(M, N, K);
cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;
cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;
cutlass::HostTensor<DType, typename Gemm::LayoutC> tensor_Z;
cutlass::HostTensor<DType, typename Gemm::LayoutC> tensor_Broadcast;
tensor_A.resize({problem_size.m(), problem_size.k()});
tensor_B.resize({problem_size.k(), problem_size.n()});
tensor_Z.resize({problem_size.m(), problem_size.n()});
tensor_Broadcast.resize({problem_size.m(), 1});
cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), tensor_A.capacity());
cutlass::reference::host::BlockFill(tensor_B.host_data(), tensor_B.capacity(), DType(1.0));
cutlass::reference::host::BlockFill(tensor_Z.host_data(), tensor_Z.capacity(), DType(0.0));
cutlass::reference::host::BlockFillSequential(tensor_Broadcast.host_data(), tensor_Broadcast.capacity());
tensor_A.sync_device();
tensor_B.sync_device();
tensor_Z.sync_device();
tensor_Broadcast.sync_device();
if (verbose)
{
std::cout << "tensor_A:\n"
<< tensor_A.host_view() << std::endl;
std::cout << "tensor_B:\n"
<< tensor_B.host_view() << std::endl;
std::cout << "tensor_Broadcast:\n"
<< tensor_Broadcast.host_view() << std::endl;
}
typename EVTD::Arguments callback_args{
{
{}, // Compute0
{tensor_Broadcast.device_data(), DType(0), {_0{}, _1{}, int32_t(problem_size.m())}}, // bias / scale
{} // Accum
}, // EvtCompute0
{tensor_Z.device_data(), {problem_size.n(), _1{}, problem_size.mn().product()}}, // D
};
typename EVTKernel::Arguments evtArgs{
cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
problem_size, // problem_size
1, // batch count / splitk slices
callback_args, // argument of EVT callbacks
tensor_A.device_data(), // ptr_A
tensor_B.device_data(), // ptr_B
nullptr, // ptr_C (unused)
nullptr, // ptr_D (unused)
problem_size.mk().product(), // batch_stride_A
problem_size.nk().product(), // batch_stride_B
0, // batch_stride_C (unused)
0, // batch_stride_D (unused)
tensor_A.layout().stride(0), // stride_a
tensor_B.layout().stride(0), // stride_b
0, // stride_c (unused)
0 // stride_d (unused)
};
}
/// Program entry: announces the run, then calls test<EVTKernel> on an 8 x 8 x 8 problem
/// (all other test() parameters keep their defaults).
int main()
{
  constexpr int kProblemM = 8;
  constexpr int kProblemN = 8;
  constexpr int kProblemK = 8;

  std::cout << "GemmEVT" << std::endl;
  test<EVTKernel>(kProblemM, kProblemN, kProblemK);
}
Not an expert, but I recently ran into exactly the same problem when crafting my custom epilogue visitor tree. Here is what I think:
To
Do the same with type |
For your second problem, the reason why the example works well with […] — to be more specific, example 47 uses […]
What is your question?
I am trying to understand the behavior of a GEMM with a column-broadcast bias-vector epilogue.
When defining a device-level `GemmUniversalWithBroadcast` with the following config, I get a core dump whenever I try to run the above with `M != N`. Running with `M == N`, I get the correct GEMM, but the epilogue is broadcast incorrectly (row-wise instead of column-wise).
When I run the above using `GemmUniversalAdapter` as the device handle, the op runs for all `M` and `N`. However, the `A` and `B` inputs are transposed because of an internal transpose that the adapter does, while the epilogue op is performed correctly.
Questions
1. […] `GemmUniversalWithBroadcast`?
2. […] `GemmUniversalAdapter` transpose layouts internally?

Repro
Here is a simple script for reproducing the above:
- `GemmUniversalWithBroadcast` fails to run with `M != N`.
- `GemmUniversalWithBroadcast` runs with `M == N`, but the epilogue is incorrect.
- `GemmUniversalAdapter` runs, but with operands `A` and `B` transposed.

The text was updated successfully, but these errors were encountered: