You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Expected behavior
Expect the jitted pytorch module to work per the non-pytorch version (using plan.run, which compiles and runs the kernel directly through pycuda / C interface).
Environment details (please complete the following information):
GPU: A6000
nvidia-cutlass: 3.5.0
Additional Context
Below is the generated extension module (with jit set to False).
Issues:
The code refers to DeviceKernel but none is generated
Even though the EVT is declared, none of the interface functions provide args for the visitor func
// This file was automatically generated by the CUTLASS 3.5.0 Python interface (https://github.com/nvidia/cutlass/python)
#include<cuda_runtime.h>
#include<torch/extension.h>
#include<ATen/ATen.h>
#include<ATen/cuda/CUDAContext.h>
#include"cutlass/cutlass.h"
#include"cutlass/util/device_memory.h"
#include"cutlass/gemm_coord.h"
#include"cutlass/numeric_types.h"
#include"cutlass/arch/arch.h"
#include"cutlass/arch/mma.h"
#include"cutlass/layout/matrix.h"
#include"cutlass/gemm/device/gemm.h"
#include"cutlass/gemm/device/gemm_universal_adapter.h"
#include"cutlass/gemm/kernel/default_gemm_universal.h"
#include"cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

// Helper: allocate `size` bytes of device memory on `device_id` by creating a
// torch int8 tensor, so the allocation goes through PyTorch's caching
// allocator instead of raw cudaMalloc.
// Returns nullptr when size == 0.
// NOTE(review): the backing at::Tensor is destroyed when this function
// returns, so the returned pointer refers to storage the caching allocator may
// recycle — confirm the generated callers do not rely on it persisting.
void *device_memory_allocation(size_t size, int device_id = 0)
{
    if (size > 0)
    {
        torch::Device device(torch::kCUDA, device_id);
        torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
        at::Tensor device_tensor = torch::empty({(long)size}, options);
        return reinterpret_cast<void *>(device_tensor.data_ptr());
    }
    else
    {
        return nullptr;
    }
}
#include"cutlass/gemm/device/gemm_universal.h"using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::half_t,
8,
1/* epilogue stages */
>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, cutlass::half_t,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, cutlass::half_t, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
Compute0,
Scale,
Accum>;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
D,
EVTCompute0>;
// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8using cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
float,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
EVTD,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd,
1/* epilogue stages */
>::GemmKernel;
// Define named typestructcutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_type : publiccutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base
{
};
using DeviceKernel = cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_type;
using ElementCompute = typename DeviceKernel::ElementC;
// Configures and launches the GEMM for an M x N x K problem.
// A, B: input operands (row-major half). C: optional source operand. D: output.
// Returns the CUTLASS status of initialization or launch.
cutlass::Status epilogue_broadcast_kernel_run(int M, int N, int K,
                                              const DeviceKernel::ElementA *A, const DeviceKernel::ElementB *B, const DeviceKernel::ElementC *C, DeviceKernel::ElementC *D,
                                              ElementCompute alpha, ElementCompute beta)
{
    typename DeviceKernel::Arguments arguments{
        cutlass::gemm::GemmUniversalMode::kGemm,
        {M, N, K},    // problem size
        1,            // batch count
        // NOTE(review): the EVT epilogue declared above expects a visitor
        // argument tree (scale pointer, aux-store pointer, ...) here, not a
        // plain {alpha, beta} — this is the second bug from the report: "none
        // of the interface functions provide args for the visitor func".
        {alpha, beta},
        A,
        B,
        C,
        D,
        0, 0, 0, 0,   // batch strides
        DeviceKernel::LayoutA::packed({M, K}).stride(0),  // lda
        DeviceKernel::LayoutB::packed({K, N}).stride(0),  // ldb
        DeviceKernel::LayoutC::packed({M, N}).stride(0),  // ldc
        DeviceKernel::LayoutC::packed({M, N}).stride(0)   // ldd
    };

    size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    DeviceKernel gemm_op;
    cutlass::Status status = gemm_op.initialize(arguments,
                                                workspace.get(),
                                                nullptr);  // CUDA stream
    if (status != cutlass::Status::kSuccess)
    {
        return status;
    }
    status = gemm_op();
    return status;
}
// PyTorch-facing entry point: D = epilogue(A @ B [, C]) with half outputs.
// A: (M, K), B: (K, N); C optional. Throws via TORCH_CHECK on kernel failure.
at::Tensor epilogue_broadcast_kernel(const at::Tensor &A, const at::Tensor &B, at::optional<const at::Tensor> C, float alpha, float beta)
{
    int M = A.size(0);
    int N = B.size(1);
    int K = A.size(1);

    // FIX(review): hold the contiguous tensors in locals for the duration of
    // the launch. The original took data_ptr() from C->contiguous() — a
    // temporary whose storage could be released before the GEMM consumed it.
    at::Tensor A_contig = A.contiguous();
    at::Tensor B_contig = B.contiguous();
    at::Tensor C_contig;
    typename DeviceKernel::ElementC *ptrC = nullptr;
    if (C != at::nullopt)
    {
        C_contig = C->contiguous();
        ptrC = reinterpret_cast<typename DeviceKernel::ElementC *>(C_contig.data_ptr());
    }

    // Output is freshly allocated and already contiguous; the original's
    // D.contiguous() was at best a no-op and at worst would have redirected
    // the kernel's writes into a discarded copy.
    at::Tensor D = B.new_empty({M, N}, torch::kF16);

    cutlass::Status status = epilogue_broadcast_kernel_run(
        M, N, K,
        reinterpret_cast<typename DeviceKernel::ElementA *>(A_contig.data_ptr()),
        reinterpret_cast<typename DeviceKernel::ElementB *>(B_contig.data_ptr()),
        ptrC,
        reinterpret_cast<typename DeviceKernel::ElementC *>(D.data_ptr()),
        ElementCompute(alpha), ElementCompute(beta));
    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
The text was updated successfully, but these errors were encountered:
We haven't yet done the plumbing to emit the correct EVT arguments structures for creating a PyTorch extension for a kernel that uses EVT. Apologies that this hasn't been better documented and lacks a clear error indicating the lack of support.
Describe the bug
The Python PyTorch emitter does not output functioning code when compiling a Gemm with an EVT.
Steps/Code to reproduce bug
The script below reproduces the bug.
The script below reproduces the bug.
Switch `jit` to `True` when calling `cutlass.emit.pytorch` to see the generated code (see additional context, as well).
Expected behavior
Expect the jitted PyTorch module to work per the non-PyTorch version (using `plan.run`, which compiles and runs the kernel directly through pycuda / C interface).
Environment details (please complete the following information):
GPU: A6000
nvidia-cutlass: 3.5.0
Additional Context
Below is the generated extension module (with `jit` set to `False`).
Issues:
1. The code refers to `DeviceKernel`, but none is generated.
2. Even though the EVT is declared, none of the interface functions provide args for the visitor func.
The text was updated successfully, but these errors were encountered: