Elementwise Operator Family¶
1. Overview¶
The TileOPs elementwise operator family covers 72 ops across 12 sub-categories, with fp8 support, torch.compile integration, kernel caching, and performance tuning.
Scope (72 ops)¶
| Category | Ops | Count |
|---|---|---|
| unary_math | exp, log, sqrt, rsqrt, abs, neg, reciprocal, sign, sin, cos, floor, ceil, round, trunc, erf, log1p, expm1 | 17 |
| binary_arith | add, sub, mul, div, remainder, pow, floor_divide, lerp, maximum, minimum | 10 |
| activation | relu, gelu, silu, sigmoid, tanh, leaky_relu, elu, selu, hardswish, hardsigmoid, hardtanh, softplus, mish, prelu | 14 |
| fused_gated_activation | silu_and_mul (SwiGLU), gelu_and_mul, gelu_tanh_and_mul | 3 |
| comparison | eq, ne, gt, lt, ge, le | 6 |
| bitwise | bitwise_and, bitwise_or, bitwise_xor, bitwise_not | 4 |
| logical | logical_not, logical_and, logical_or | 3 |
| special_elementwise | where, clamp, masked_fill, nan_to_num, isnan, isinf, isfinite | 7 |
| rope | neox, non_neox, llama31, yarn, longrope | 5 |
| dropout | dropout | 1 |
| alibi | ALiBi position encoding | 1 |
| sinusoidal | Sinusoidal position encoding | 1 |
| Total | | 72 |
Supported Dtypes¶
fp32, bf16, fp16, fp8_e4m3fn, fp8_e5m2
2. Design Decisions¶
2.1 Kernel Template Strategy¶
Ops are grouped by input signature into template base classes. Each concrete op is a 3-5 line subclass overriding only op_func. Ops with user-configurable parameters or non-standard signatures are implemented independently.
Template to Op Mapping:
| Template | Signature | Ops | Count |
|---|---|---|---|
| UnaryKernel | op_func(x) → y | unary_math (17); activation (9): relu, gelu, silu, sigmoid, tanh, hardswish, hardsigmoid, mish, selu; logical_not (1); bitwise_not (1); isnan, isinf, isfinite (3) | 31 |
| BinaryKernel | op_func(a, b) → y, with output_dtype and broadcast | binary_arith (10); comparison (6); logical_and, logical_or (2); bitwise_and, bitwise_or, bitwise_xor (3) | 21 |
| FusedGatedKernel | activation_func(gate) * value, input (M, 2N) split | silu_and_mul, gelu_and_mul, gelu_tanh_and_mul | 3 |
| Independent | Custom signatures, no shared template base class | leaky_relu, elu, hardtanh, softplus, prelu, where, clamp, masked_fill, nan_to_num, alibi, sinusoidal | 11 |
| Total | | | 66 |
rope (5 ops) and dropout (1 op) live in separate files (rope.py, dropout.py) due to fundamentally different kernel patterns (complex rotary addressing, PRNG state management).
Notes:
- hardsigmoid and selu have fixed constants (not user-configurable), hardcoded in op_func, and are classified under UnaryKernel.
- softplus has user-configurable beta/threshold → independent implementation.
- Comparison ops output int8 at the kernel level; the Op layer casts to torch.bool.
2.2 File Organization¶
tileops/kernels/
├── elementwise.py # All template base classes + concrete kernel subclasses + independent kernels (~3,000 lines)
├── rope.py # 5 rope variants
├── dropout.py # Dropout with PRNG
tileops/ops/
├── elementwise.py # All template base Op classes + concrete op subclasses + custom_op registration (~1,800 lines)
tests/
├── ops/test_special_elementwise.py
├── test_elementwise_fp8.py
├── test_elementwise_compile.py
├── test_elementwise_independent_fp8.py
├── test_elementwise_strategy_bench.py
├── test_elementwise_caching_autotune.py
benchmarks/ops/
├── bench_unary_elementwise.py
├── bench_binary_elementwise.py
├── bench_independent_elementwise.py
├── bench_elementwise_fp8.py
Single elementwise.py for both kernels and ops. Tests split by concern (correctness, fp8, compile, strategy, caching). Benchmarks split by kernel template.
2.3 Re-export Strategy¶
- tileops/kernels/__init__.py: re-exports only 3 base classes (UnaryKernel, BinaryKernel, FusedGatedKernel).
- tileops/ops/__init__.py: re-exports only 3 base Op classes (UnaryOp, BinaryOp, FusedGatedOp).
- Concrete classes are imported directly: from tileops.kernels.elementwise import ReluKernel.
2.4 Concrete Kernel Class per Op¶
Each op is an explicit class, chosen over factory functions for: IDE navigation, isinstance checks, traceback clarity, and per-op default_config / autotune_configs overrides.
2.5 torch.compile Support: custom_op Registration¶
All 66 ops support torch.compile via @torch.library.custom_op registration.
Registry mechanism: _OP_REGISTRY = weakref.WeakValueDictionary() keyed by id(instance). Each Op instance registers itself at construction time; the custom_op wrapper looks up the instance by key and delegates to _eager_forward.
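The registry mechanism can be sketched in plain Python (class and function names here are illustrative, not the actual codebase symbols):

```python
import weakref

# Sketch of the instance registry described above: custom_op wrappers
# receive only a plain-int instance_key, look up the Op instance in a
# WeakValueDictionary, and delegate to its eager implementation. The
# weak values mean the registry never keeps dead Op instances alive.

_OP_REGISTRY = weakref.WeakValueDictionary()


class OpSketch:
    def __init__(self):
        self.instance_key = id(self)
        _OP_REGISTRY[self.instance_key] = self

    def _eager_forward(self, x):
        return x * 2  # hypothetical op body


def custom_op_wrapper(x, instance_key):
    # What the registered custom_op body does: registry lookup, then
    # delegation to the instance's eager forward.
    instance = _OP_REGISTRY[instance_key]
    return instance._eager_forward(x)
```

Keying by `id(instance)` keeps the custom_op signature tensor-plus-scalars, which is what torch.compile can trace through.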
Factory functions (7 total, registered at module load time):
| Factory | Ops | Signature |
|---|---|---|
| _register_unary_custom_op | 31 unary ops | (x, instance_key) → Tensor |
| _register_binary_custom_op | 21 binary ops | (a, b, out_shape, instance_key) → Tensor |
| _register_fused_gated_custom_op | 3 fused gated ops | (x, M, N, instance_key) → Tensor |
| _register_prelu_custom_op | prelu | (x, weight, instance_key) → Tensor |
| _register_where_custom_op | where | (cond, x, y, instance_key) → Tensor |
| _register_masked_fill_custom_op | masked_fill | (x, mask, instance_key) → Tensor |
| _register_generative_custom_op | alibi, sinusoidal | (device_carrier, num_a, num_b, instance_key) → Tensor |
Each Op class's forward dispatches through _wrapped(...) when torch.compile is active, falling back to _eager_forward otherwise.
3. Kernel Architecture¶
3.1 Shape Convention¶
Unary ops: Arbitrary shape, flattened to 1D (N_total,) by the Op layer before passing to kernel.
Binary ops: Stride-based offset with N-dim broadcast (equivalent to PyTorch TensorIterator semantics).
Op layer processing:
1. torch.broadcast_shapes(a.shape, b.shape) → output shape
2. Align dimensions (left-pad with 1), compute broadcast strides (stride=0 for broadcast dims)
3. Dimension coalescing (coalesce_broadcast_dims): merge adjacent dims with compatible broadcast patterns to minimize divmod count
4. Both inputs are flattened to 1D via .contiguous().view(-1); (coalesced_shape, a_strides, b_strides) are passed as compile-time constants
5. Kernel computes output via flat index (contiguous), maps to input offsets through divmod chain
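Step 5 can be sketched in plain Python (the function name is illustrative; in the kernel the shapes and strides are compile-time constants and this loop unrolls):

```python
# Map an output flat index to an input's flat offset via a divmod
# chain over the coalesced shape. Broadcast dims carry stride 0, so
# their coordinate contributes nothing to the input offset.

def input_offset(flat_idx, coalesced_shape, strides):
    offset = 0
    rem = flat_idx
    # Walk dims from innermost to outermost.
    for size, stride in zip(reversed(coalesced_shape), reversed(strides)):
        rem, coord = divmod(rem, size)
        offset += coord * stride
    return offset
```

For a bias add coalesced to shape (B*S, D), the contiguous input uses strides (D, 1) and the bias uses (0, 1), so the bias offset reduces to `flat_idx % D`.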
FusedGated ops: Input (M, 2N) → 2D grid, no broadcast.
3.2 Memory Access Pattern: No Shared Memory¶
Elementwise ops are pure memory-bandwidth bound (arithmetic intensity ≈ 0). The optimal path is:
- No shared memory allocation or budget constraints
- T.Parallel maps threads + 128-bit vectorized coalesced access (uint4 loads/stores)
- LegalizeSafeMemoryAccess compiler pass auto-inserts boundary guards
3.3 Three Kernel Strategies¶
Strategy 1: direct — Compiler-optimized¶
Each thread processes 1 element. TileLang may auto-vectorize. Least code, but compiler behavior not fully controllable.
def _make_unary_direct(N, dtype, op_func, output_dtype=None, threads=256):
    out_dtype = output_dtype or dtype

    @T.prim_func
    def main(x: T.Tensor((N,), dtype), y: T.Tensor((N,), out_dtype)):
        with T.Kernel(T.ceildiv(N, threads), threads=threads) as bx:
            for i in T.Parallel(threads):
                y[bx * threads + i] = op_func(x[bx * threads + i])

    return main
Strategy 2: explicit_parallel — Explicit vectorization control¶
Each thread processes num_per_thread elements. T.Parallel(threads, num_per_thread) explicitly guides the compiler for vectorized access.
def _make_unary_explicit(N, dtype, op_func, output_dtype=None, threads=256, num_per_thread=8):
    out_dtype = output_dtype or dtype
    block_size = threads * num_per_thread

    @T.prim_func
    def main(x: T.Tensor((N,), dtype), y: T.Tensor((N,), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_size), threads=threads) as bx:
            for i, j in T.Parallel(threads, num_per_thread):
                idx = (bx * threads + i) * num_per_thread + j
                y[idx] = op_func(x[idx])

    return main
Strategy 3: register_copy — Explicit register staging¶
Explicit T.alloc_fragment + T.copy for global↔register transfers. T.copy generates vectorized load/store instructions.
def _make_unary_regcopy(N, dtype, op_func, output_dtype=None, threads=256, num_per_thread=8):
    out_dtype = output_dtype or dtype
    block_size = threads * num_per_thread

    @T.prim_func
    def main(x: T.Tensor((N,), dtype), y: T.Tensor((N,), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_size), threads=threads) as bx:
            x_reg = T.alloc_fragment((block_size,), dtype)
            y_reg = T.alloc_fragment((block_size,), out_dtype)
            T.copy(x[bx * block_size:(bx + 1) * block_size], x_reg)
            for i, j in T.Parallel(threads, num_per_thread):
                y_reg[i * num_per_thread + j] = op_func(x_reg[i * num_per_thread + j])
            T.copy(y_reg, y[bx * block_size:(bx + 1) * block_size])

    return main
Binary Kernel: Stride-Based Offset¶
Binary ops use 1D output flat index + stride-based offset for N-dim broadcast. All shape/stride values are compile-time constants; the divmod chain is unrolled at compilation.
All three strategies (direct, explicit_parallel, register_copy) are supported. When register_copy is requested on broadcast inputs (ndim > 1), it silently downgrades to explicit_parallel because T.copy requires contiguous slices.
Strategy Defaults (H200 benchmarks)¶
| Kernel Type | Default Strategy | Rationale |
|---|---|---|
| UnaryKernel | register_copy | H200: register_copy wins for fp16/bf16 across all tested shapes |
| BinaryKernel | explicit_parallel | Broadcast inputs incompatible with register_copy; explicit_parallel is most general |
| FusedGatedKernel | explicit_parallel | H200: ~2× faster than direct (e.g. silu_and_mul: 3.04 TB/s vs 1.50 TB/s) |
| fp8 (all types) | explicit_parallel | register_copy unreliable for 8-bit fragments |
Config Defaults¶
Strategy-aware num_per_thread defaults (H200 benchmarks):
| Strategy | fp32 | fp16/bf16 | fp8 |
|---|---|---|---|
| explicit_parallel | 4 | 4 | 16 |
| register_copy | 4 | 8 | — |
| direct | — | — | — |
Rationale:
- fp32 × 4B × 4 = 16B = 128-bit uint4 alignment
- fp16/bf16 under explicit_parallel: npt=4 yields 42% bandwidth gain over npt=8
- fp16/bf16 under register_copy: npt=8 × 2B = 16B = 128-bit vectorized loads
- fp8: npt=16 × 1B = 16B = 128-bit alignment
threads fixed at 256 for all strategies. Elementwise ops have low register pressure; 256 threads achieves good occupancy on SM80+.
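The alignment arithmetic behind these defaults can be checked with a tiny helper (the function name is illustrative, not from the codebase):

```python
# num_per_thread is sized so each thread moves 16 bytes (128 bits),
# the widest vectorized access (uint4). Note the benchmark-driven
# exception above: fp16/bf16 under explicit_parallel uses npt=4
# despite the pure alignment math suggesting 8.

def npt_for_128bit(dtype_size_bytes):
    return 16 // dtype_size_bytes

assert npt_for_128bit(4) == 4    # fp32: 4 x 4B = 16B
assert npt_for_128bit(2) == 8    # fp16/bf16 (register_copy): 8 x 2B = 16B
assert npt_for_128bit(1) == 16   # fp8: 16 x 1B = 16B
```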
3.4 Template Base Classes¶
UnaryKernel¶
class UnaryKernel(Kernel):
    supported_archs = [80, 86, 89, 90]
    STRATEGIES = ["direct", "explicit_parallel", "register_copy"]
    DEFAULT_STRATEGY = "register_copy"
    OUTPUT_DTYPE = None       # subclass override (e.g. torch.bool for predicates)
    SUPPORTED_DTYPES = None   # subclass override to restrict input dtypes

    @staticmethod
    def op_func(x):
        raise NotImplementedError

    def __init__(self, N_total, dtype, strategy=None, config=None, tune=False):
        # dtype validation
        # fp8: override strategy to explicit_parallel
        # _build_kernel → _make_unary_{direct,explicit,regcopy}
        # init_config → cache _compiled_fn
        ...

    def forward(self, x):
        return self._compiled_fn(x)
Key implementation details:
- _build_kernel dispatches to the appropriate _make_unary_* factory based on strategy
- init_config caches _compiled_fn by pre-applying config params (threads, npt) to the kernel builder
- fp8 handling: _wrap_fp8_accumulation wraps op_func with cast-in/cast-out; _fp8_output_dtype signals the Op layer to apply post-cast
- Autotune override catches serialization failures (op_func closures aren't pickleable) and falls back to default_config
BinaryKernel¶
class BinaryKernel(Kernel):
    supported_archs = [80, 86, 89, 90]
    STRATEGIES = ["direct", "explicit_parallel", "register_copy"]
    DEFAULT_STRATEGY = "explicit_parallel"
    OUTPUT_DTYPE = None       # "int8" for comparison/logical ops
    SUPPORTED_DTYPES = None

    @staticmethod
    def op_func(a, b):
        raise NotImplementedError

    def __init__(self, N_total, dtype, coalesced_shape, a_strides, b_strides,
                 a_numel, b_numel, strategy=None, config=None, tune=False):
        # dtype validation
        # register_copy auto-downgrade on broadcast inputs
        # _build_kernel → _make_binary_{direct,explicit,register_copy}
        # init_config → cache _compiled_fn
        ...

    def forward(self, a, b):
        return self._compiled_fn(a, b)
- Comparison ops use OUTPUT_DTYPE = "int8" at the kernel level; the Op layer (_BoolOutputBinaryOp) casts to torch.bool
- register_copy silently downgrades to explicit_parallel when any input has broadcast dimensions (ndim > 1)
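The strategy-resolution rules scattered through this section (fp8 override, broadcast downgrade) can be sketched as one selection function (the helper name is hypothetical, not a codebase symbol):

```python
# Sketch of the strategy-selection rules described above: fp8 forces
# explicit_parallel, and register_copy downgrades to explicit_parallel
# when any input carries broadcast dimensions, because T.copy needs a
# contiguous slice.

def resolve_strategy(requested, is_fp8=False, has_broadcast=False):
    if is_fp8:
        return "explicit_parallel"      # register_copy unreliable for 8-bit fragments
    if requested == "register_copy" and has_broadcast:
        return "explicit_parallel"      # silent downgrade
    return requested
```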
FusedGatedKernel¶
class FusedGatedKernel(Kernel):
    supported_archs = [80, 86, 89, 90]
    STRATEGIES = ["direct", "explicit_parallel"]
    DEFAULT_STRATEGY = "explicit_parallel"
    SUPPORTED_DTYPES = None

    @staticmethod
    def activation_func(x):
        raise NotImplementedError

    def __init__(self, M, N, dtype, strategy=None, config=None, tune=False):
        # _build_kernel → _make_fused_gated_{direct,explicit}
        # init_config → cache _compiled_fn
        ...

    def forward(self, x):
        return self._compiled_fn(x)
- Input: single (M, 2N) tensor (gate + value packed); output: (M, N)
- No register_copy: FusedGated has a 2D grid pattern incompatible with 1D fragment copy
- Single tensor interface matches vLLM/FlashInfer/xformers convention
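The (M, 2N) split semantics can be written as a plain-Python reference (the real op is a single fused kernel; this is only the per-row math):

```python
import math

# Reference semantics of silu_and_mul: a row packs [gate | value] of
# width 2N; the output row is silu(gate) * value of width N.

def silu(x):
    return x / (1.0 + math.exp(-x))

def silu_and_mul_ref(row):
    n = len(row) // 2
    gate, value = row[:n], row[n:]
    return [silu(g) * v for g, v in zip(gate, value)]
```

This matches the SwiGLU convention where the gate half comes first; swapping the halves would silently change results, which is why the packing order is part of the interface.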
Independent Ops (11 ops)¶
Each has its own builder function, __init__, and forward. Shared patterns:
- _make_*_kernel with is_fp8 branching: fp8 uses explicit_parallel with fp16 accumulation; non-fp8 uses register_copy
- _compiled_fn caching in init_config
- _fp8_output_dtype for Op-layer post-cast
| Op | Extra Parameters | Builder |
|---|---|---|
| leaky_relu | negative_slope (scalar) | _make_leaky_relu_kernel |
| elu | alpha (scalar) | _make_elu_kernel |
| hardtanh | min_val, max_val (scalars) | _make_hardtanh_kernel |
| softplus | beta, threshold (scalars) | _make_softplus_kernel |
| prelu | weight (tensor) | _make_prelu_kernel |
| where | — | _make_where_kernel |
| clamp | min_val, max_val (scalars) | _make_clamp_kernel |
| masked_fill | fill_value (scalar) | _make_masked_fill_kernel |
| nan_to_num | nan, posinf, neginf (scalars) | _make_nan_to_num_kernel |
| alibi | — | _make_alibi_kernel |
| sinusoidal | — | _make_sinusoidal_kernel |
4. fp8 Support¶
4.1 Accumulation Strategy¶
fp8 dtypes lack precision for intermediate computation. All fp8 kernels use fp16 accumulation:
| Input dtype | Accumulation | Output | Post-cast (Op layer) |
|---|---|---|---|
| e4m3fn | fp16 | e4m3fn (saturating T.Cast) | None (kernel output is already e4m3fn) |
| e5m2 | fp16 | fp16 (no cast in kernel) | Non-saturating .to(e5m2) preserving Inf/NaN |
The distinction: e4m3fn has no Inf representation, so saturating cast is correct. e5m2 supports Inf/NaN, so the final cast must be non-saturating (PyTorch .to() semantics).
4.2 Implementation¶
Shared helper _wrap_fp8_accumulation(base_op, dtype, dtype_str, arity):
- Wraps any op_func with fp8 cast-in/cast-out logic
- Applied uniformly in _get_effective_op_func for all three template kernels
- Non-fp8 dtypes pass through unchanged
Scalar clamping _clamp_to_dtype_range(value, dtype):
- Prevents TVM FloatImm range-check failures for scalar literals exceeding fp8 range (e.g., fill_value=1e4 into e4m3fn whose max is 448)
- NaN passed through; Inf mapped to dtype max/min
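The clamp behavior can be sketched for e4m3fn (max finite value 448, as stated above; the function name and constant are illustrative, specialized from the generic `_clamp_to_dtype_range`):

```python
import math

# Sketch of scalar clamping for e4m3fn: out-of-range scalar literals
# are clamped so TVM's FloatImm range check passes. NaN passes through
# unchanged; +/-Inf maps to the dtype max/min.

_E4M3FN_MAX = 448.0

def clamp_to_e4m3fn_range(value):
    if math.isnan(value):
        return value
    if value > _E4M3FN_MAX:
        return _E4M3FN_MAX       # also covers +Inf
    if value < -_E4M3FN_MAX:
        return -_E4M3FN_MAX      # also covers -Inf
    return value
```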
Config for fp8:
- npt=16 (1B × 16 = 128-bit alignment)
- Strategy forced to explicit_parallel (register_copy unreliable for 8-bit fragments)
- Autotune grid: threads_opts=[128,256,512], npt_opts=[16,32]
5. Op Layer Architecture¶
5.1 Broadcast Utility: coalesce_broadcast_dims¶
The core utility for binary ops. Converts arbitrary N-dim broadcast into minimal (coalesced_shape, a_strides, b_strides).
Algorithm:
1. torch.broadcast_shapes → output shape
2. Right-align, left-pad with 1
3. Compute C-contiguous strides (stride=0 for broadcast dims)
4. Merge adjacent dims where all operands have compatible stride patterns
5. Remove trivial size-1 groups
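The five steps above can be sketched in plain Python (a simplified reconstruction from the description, not the actual codebase function):

```python
# Sketch of coalesce_broadcast_dims: broadcast two shapes, build
# stride-0 broadcast strides, drop size-1 dims, and merge adjacent
# dims whose stride patterns stay compatible for BOTH operands.

def coalesce_broadcast_dims(a_shape, b_shape):
    # Steps 1-2: broadcast (NumPy rules), right-align, left-pad with 1.
    ndim = max(len(a_shape), len(b_shape))
    a = (1,) * (ndim - len(a_shape)) + tuple(a_shape)
    b = (1,) * (ndim - len(b_shape)) + tuple(b_shape)
    if any(da != db and 1 not in (da, db) for da, db in zip(a, b)):
        raise ValueError("shapes are not broadcastable")
    out = tuple(max(da, db) for da, db in zip(a, b))

    # Step 3: C-contiguous strides per input, stride 0 on broadcast dims.
    def strides(shape):
        s, acc = [0] * ndim, 1
        for i in range(ndim - 1, -1, -1):
            if shape[i] != 1:
                s[i] = acc
            acc *= shape[i]
        return s

    sa, sb = strides(a), strides(b)

    # Steps 4-5: skip size-1 dims, then merge adjacent dims where, for
    # both operands, outer_stride == inner_stride * inner_size (this
    # also merges 0-stride broadcast runs, since 0 == 0 * size).
    dims = [i for i in range(ndim) if out[i] != 1] or [ndim - 1]
    shape_c = [out[dims[0]]]
    sa_c, sb_c = [sa[dims[0]]], [sb[dims[0]]]
    for i in dims[1:]:
        if sa_c[-1] == sa[i] * out[i] and sb_c[-1] == sb[i] * out[i]:
            shape_c[-1] *= out[i]
            sa_c[-1], sb_c[-1] = sa[i], sb[i]
        else:
            shape_c.append(out[i])
            sa_c.append(sa[i])
            sb_c.append(sb[i])
    return out, tuple(shape_c), tuple(sa_c), tuple(sb_c)
```

The merge test is a single comparison per operand, so the whole pass is O(ndim) and happens once at Op construction, not per forward().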
Coalescing effect:
| Scenario | shapes | Coalesced ndim | Divmod count |
|---|---|---|---|
| same-shape | (B,S,D) + (B,S,D) | 1 | 0 |
| bias add | (B,S,D) + (1,1,D) | 2 | 1 |
| scaling | (B,S,D) + (B,S,1) | 2 | 1 |
| attn mask | (B,H,S,S) + (1,1,S,S) | 2 | 1 |
| interleaved | (B,H,S,S) + (B,1,1,S) | 3 | 2 |
| outer product | (M,1) + (1,N) | 2 | 1 |
5.2 Template Op Classes¶
UnaryOp:
class UnaryOp(Op):
    def forward(self, x):
        orig_shape = x.shape
        x = x.contiguous().reshape(-1)   # flatten to 1D
        y = self.kernel(x)               # kernel handles boundary
        y = y.reshape(orig_shape)
        return _apply_fp8_post_cast(y, self.kernel)  # e5m2 non-saturating cast
BinaryOp:
class BinaryOp(Op):
    def __init__(self, a_shape, b_shape, dtype, ...):
        out_shape, coalesced_shape, a_strides, b_strides = coalesce_broadcast_dims(a_shape, b_shape)
        self.kernel = self.kernel_cls(N_total, dtype, coalesced_shape, a_strides, b_strides, ...)

    def forward(self, a, b):
        a_flat = a.contiguous().view(-1)
        b_flat = b.contiguous().view(-1)
        y = self.kernel(a_flat, b_flat)
        return y.reshape(self.out_shape)
FusedGatedOp: Input (M, 2N) → kernel → output (M, N). No reshape needed.
All Op classes include validation: CUDA device check, dtype matching, numel matching.
5.3 torch.compile Dispatch¶
Each Op's forward checks for _wrapped:
- If _wrapped is set: dispatch through custom_op (supports torch.compile tracing)
- Otherwise: call _eager_forward directly
Instance-based registry (_OP_REGISTRY) enables the custom_op wrapper to look up the correct Op instance at runtime.
6. Kernel Caching¶
Each kernel instance pre-compiles its kernel function in init_config:
def init_config(self, config=None, tune=False):
    super().init_config(config, tune)
    cfg = self.config
    self._compiled_fn = self.kernel(cfg["threads"], cfg["num_per_thread"])
This avoids JIT lookup overhead on every forward() call. The compiled function is cached per-instance, not globally.
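The compile-once, call-many pattern can be sketched in plain Python (names are hypothetical; `_compile` stands in for the TileLang JIT):

```python
# Sketch of per-instance kernel caching: compilation happens exactly
# once in init_config, and every forward() is a plain attribute call
# with no cache lookup on the hot path.

class CachedKernelSketch:
    def __init__(self, config):
        self.config = config
        self.compile_count = 0
        self.init_config()

    def _compile(self, threads, num_per_thread):
        # Stand-in for JIT compilation: returns a callable specialized
        # to the given config.
        self.compile_count += 1
        return lambda xs: [x * 2 for x in xs]  # hypothetical op body

    def init_config(self):
        cfg = self.config
        self._compiled_fn = self._compile(cfg["threads"], cfg["num_per_thread"])

    def forward(self, xs):
        return self._compiled_fn(xs)
```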
Autotune serialization: op_func/activation_func closures aren't pickleable. The autotune override catches serialization failures and falls back to default_config.