[BUG] Python EVT Pytorch Emitter Broken #1462

Open
jeromeku opened this issue Apr 8, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@jeromeku (Contributor) commented Apr 8, 2024

Describe the bug
The Python PyTorch emitter does not output functioning code when compiling a Gemm with an EVT (epilogue visitor tree).

Steps/Code to reproduce bug
The script below reproduces the bug.

Set jit to True when calling cutlass.emit.pytorch to compile the generated extension (which fails); set it to False to write out the generated source for inspection (see additional context as well).

import torch
import cutlass
from cutlass import Tensor as FakeTensor

print_module = True

m = 8
n = 8
k = 8

type_A = torch.float16
type_B = torch.float16
type_C = torch.float16
type_D = torch.float16

tensor_A = torch.arange(m * k, dtype=type_A, device="cuda").reshape(m, k)
tensor_B = torch.ones(n * k, dtype=type_B, device="cuda").reshape(k, n)
tensor_C = torch.zeros(m * n, dtype=type_C, device="cuda").reshape(m, n)
tensor_D = torch.zeros_like(tensor_C)

plan = cutlass.op.Gemm(
    element=torch.float16,
    layout=cutlass.LayoutType.RowMajor,
    element_accumulator=torch.float32,
)

def epilogue_scale(accum, scale):
    D = scale * accum
    return D

# Construct inputs and outputs
scale = torch.arange(m, dtype=type_C, device="cuda").reshape(m, 1)
examples_tensors = {
    "accum": FakeTensor(
        element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor
    ),
    "scale": scale,
    "D": tensor_D,
}

epilogue_visitor = cutlass.epilogue.trace(epilogue_scale, examples_tensors)
visitor_args = {"scale": scale, "D": tensor_D}

plan.epilogue_visitor = epilogue_visitor

# This works: compile and run the kernel directly via plan.run
plan.run(
    tensor_A,
    tensor_B,
    tensor_C,
    tensor_D,
    visitor_args=visitor_args,
    print_module=print_module,
)

binary_op = torch.mul
ref_D = binary_op(tensor_A @ tensor_B, scale)
print(f"ref_D =\n {ref_D}")
print(f"tensor_D =\n {tensor_D}")
print(f"(tensor_D - ref_D).abs().max() = {(tensor_D - ref_D).abs().max()}")

# The call below does not work. Set jit to False to write out the generated
# code, which is incorrect (see additional context).
op = plan.construct()
mod = cutlass.emit.pytorch(
    op, name="epilogue_broadcast", cc=plan.cc, sourcedir="epilogue", jit=True
)

Expected behavior
Expect the jitted PyTorch module to work like the non-PyTorch path (plan.run above, which compiles and runs the kernel directly through the PyCUDA / C interface).
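
For reference, a minimal sketch of the usage I would expect from the emitted module, assuming the mod.run(A, B, C) interface that cutlass.emit.pytorch provides for non-EVT kernels; note there is no parameter through which to pass the EVT operand scale, which is part of the problem:

# Hedged sketch: expected usage of the jitted module, assuming the standard
# mod.run(A, B, C) interface emitted for non-EVT kernels. No parameter exists
# for the EVT operand `scale`, which is part of what is broken.
D = mod.run(tensor_A, tensor_B, tensor_C)
assert torch.allclose(D, (tensor_A @ tensor_B) * scale)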

Environment details (please complete the following information):

  • GPU: A6000
  • nvidia-cutlass: 3.5.0

Additional Context
Below is the generated extension module (with jit set to False).

Issues:

  • The host functions use DeviceKernel as if it were a device-level operator (gemm_op.initialize(...), gemm_op()), but the alias points at the kernel-level GemmKernel type, and no device-level wrapper is generated.
  • Even though the EVT is declared (EVTD), none of the interface functions accept arguments for the visitor (e.g. scale); the arguments struct is still populated with the plain {alpha, beta} epilogue parameters. A hedged sketch of the missing arguments structure follows the generated code below.
// This file was automatically generated by the CUTLASS 3.5.0 Python interface (https://github.com/nvidia/cutlass/python)

#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"

#include "cutlass/gemm_coord.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

// helper function allocating the memory
void *device_memory_allocation(size_t size, int device_id = 0)
{
    if (size > 0)
    {
        torch::Device device(torch::kCUDA, device_id);
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
        at::Tensor device_tensor = torch::empty({
                                                    (long)size,
                                                },
                                                options);
        return reinterpret_cast<void *>(device_tensor.data_ptr());
    }
    else
    {
        return nullptr;
    }
}

#include "cutlass/gemm/device/gemm_universal.h"

using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
    cutlass::gemm::GemmShape<256, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::half_t,
    8,
    1 /* epilogue stages */
    >;

using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, cutlass::half_t, float,
    cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
    Compute0,
    Scale,
    Accum>;

using D = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
    cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;

using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
    D,
    EVTCompute0>;

// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8
using cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base =
    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
        cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
        cutlass::half_t, cutlass::layout::RowMajor, 8,
        float,
        float,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<256, 128, 32>,
        cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>,
        EVTD,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
        3,
        cutlass::arch::OpMultiplyAdd,
        1 /* epilogue stages */
        >::GemmKernel;

// Define named type
struct cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_type : public cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base
{
};

using DeviceKernel = cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_type;
using ElementCompute = typename DeviceKernel::ElementC;

cutlass::Status epilogue_broadcast_kernel_run(int M, int N, int K,
                                              const DeviceKernel::ElementA *A, const DeviceKernel::ElementB *B, const DeviceKernel::ElementC *C, DeviceKernel::ElementC *D,
                                              ElementCompute alpha, ElementCompute beta)
{

    typename DeviceKernel::Arguments arguments{
        cutlass::gemm::GemmUniversalMode::kGemm,
        {M, N, K}, // problem size
        1,
        {alpha, beta},
        A,
        B,
        C,
        D,
        0,
        0,
        0,
        0,                                               // batch strides
        DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
        DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
        DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
        DeviceKernel::LayoutC::packed({M, N}).stride(0)  // ldd
    };

    size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    DeviceKernel gemm_op;
    cutlass::Status status = gemm_op.initialize(arguments,
                                                workspace.get(),
                                                nullptr); // CUDA stream

    if (status != cutlass::Status::kSuccess)
    {
        return status;
    }

    status = gemm_op();
    return status;
}

at::Tensor epilogue_broadcast_kernel(const at::Tensor &A, const at::Tensor &B, at::optional<const at::Tensor> C, float alpha, float beta)
{
    int M = A.size(0);
    int N = B.size(1);
    int K = A.size(1);

    typename DeviceKernel::ElementC *ptrC = (C == at::nullopt) ? nullptr : reinterpret_cast<typename DeviceKernel::ElementC *>(C->contiguous().data_ptr());
    at::Tensor D = B.new_empty({M, N}, torch::kF16);

    cutlass::Status status = epilogue_broadcast_kernel_run(M, N, K,
                                                           reinterpret_cast<typename DeviceKernel::ElementA *>(A.contiguous().data_ptr()),
                                                           reinterpret_cast<typename DeviceKernel::ElementB *>(B.contiguous().data_ptr()),
                                                           ptrC,
                                                           reinterpret_cast<typename DeviceKernel::ElementC *>(D.contiguous().data_ptr()),
                                                           ElementCompute(alpha), ElementCompute(beta));

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
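
For comparison, here is a hedged sketch of the callback arguments the EVT epilogue would need. The nesting is an assumption based on CUTLASS's tree-visitor convention (child nodes first, node op last), and the pointer names are illustrative, not part of the emitter's output:

// Hedged sketch, not generated output: with an EVT, the kernel arguments would
// need to carry a nested EVTD::Arguments tree in place of the plain {alpha, beta}
// epilogue parameters. The nesting mirrors the tree declared above,
// EVTD = D(Compute0(Scale, Accum)); scale_ptr and D_ptr are illustrative names.
typename EVTD::Arguments callback_args{
    {                                             // EVTCompute0
        {scale_ptr, cutlass::half_t(0),
         {cute::_1{}, cute::_0{}, cute::_0{}}},   // Scale: pointer, null default, stride
        {},                                       // Accum: no arguments
        {}                                        // Compute0: no arguments
    },
    {D_ptr, {int64_t(N), cute::_1{}, cute::_0{}}} // D: pointer and row-major stride
};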
@jeromeku added the "? - Needs Triage" and "bug" labels on Apr 8, 2024
@jackkosaian (Contributor)

We haven't yet done the plumbing to emit the correct EVT arguments structures when creating a PyTorch extension for a kernel that uses EVT. Apologies that this hasn't been better documented and that there is no clear error indicating the lack of support.

@jeromeku (Contributor, Author)

@jackkosaian Thanks for the response.

Are there any examples or documentation on how to properly construct arguments for an EVT, other than the streamk example?

Moreover, I'm having trouble with the different epilogue interfaces (see #1459) for a relatively simple example. Would appreciate any help!
