diff --git a/Makefile b/Makefile index 62c1e5f1a68f..522330034149 100644 --- a/Makefile +++ b/Makefile @@ -548,6 +548,7 @@ SOURCE_FILES = \ Module.cpp \ ModulusRemainder.cpp \ Monotonic.cpp \ + MultiRamp.cpp \ ObjectInstanceRegistry.cpp \ OffloadGPULoops.cpp \ OptimizeShuffles.cpp \ @@ -753,6 +754,7 @@ HEADER_FILES = \ Module.h \ ModulusRemainder.h \ Monotonic.h \ + MultiRamp.h \ ObjectInstanceRegistry.h \ OffloadGPULoops.h \ OptimizeShuffles.h \ diff --git a/apps/iir_blur/Makefile b/apps/iir_blur/Makefile index 92ed5d2a5b0b..5dd3b1200cc6 100644 --- a/apps/iir_blur/Makefile +++ b/apps/iir_blur/Makefile @@ -10,11 +10,11 @@ $(GENERATOR_BIN)/iir_blur.generator: iir_blur_generator.cpp $(GENERATOR_DEPS) $(BIN)/%/iir_blur.a: $(GENERATOR_BIN)/iir_blur.generator @mkdir -p $(@D) - $< -g iir_blur -f iir_blur -o $(BIN)/$* target=$*-no_runtime + $< -g iir_blur -f iir_blur -e $(GENERATOR_OUTPUTS) -o $(BIN)/$* target=$*-no_runtime $(BIN)/%/iir_blur_auto_schedule.a: $(GENERATOR_BIN)/iir_blur.generator @mkdir -p $(@D) - $< -g iir_blur -f iir_blur_auto_schedule -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016 + $< -g iir_blur -f iir_blur_auto_schedule -e $(GENERATOR_OUTPUTS) -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/runtime.a: $(GENERATOR_BIN)/iir_blur.generator @mkdir -p $(@D) diff --git a/apps/iir_blur/iir_blur_generator.cpp b/apps/iir_blur/iir_blur_generator.cpp index 7f411d7e8fef..146e15fb24b5 100644 --- a/apps/iir_blur/iir_blur_generator.cpp +++ b/apps/iir_blur/iir_blur_generator.cpp @@ -48,11 +48,6 @@ Func blur_cols_transpose(Func input, Expr height, Expr alpha, bool skip_schedule .fuse(yo, c, t) .parallel(t); - blur.in(transpose) - .compute_at(transpose, y) - .vectorize(x) - .unroll(y); - // Run the filter on each row of tiles (which corresponds to a strip of // columns in the input). 
blur.compute_at(transpose, t); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8e4e4f1e7afd..7558e589fe75 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -160,6 +160,7 @@ target_sources( Module.h ModulusRemainder.h Monotonic.h + MultiRamp.h ObjectInstanceRegistry.h OffloadGPULoops.h OptimizeShuffles.h @@ -337,6 +338,7 @@ target_sources( Module.cpp ModulusRemainder.cpp Monotonic.cpp + MultiRamp.cpp ObjectInstanceRegistry.cpp OffloadGPULoops.cpp OptimizeShuffles.cpp diff --git a/src/FlattenNestedRamps.cpp b/src/FlattenNestedRamps.cpp index efa373f6970a..7ee062ac5850 100644 --- a/src/FlattenNestedRamps.cpp +++ b/src/FlattenNestedRamps.cpp @@ -4,6 +4,7 @@ #include "Deinterleave.h" #include "IRMutator.h" #include "IROperator.h" +#include "MultiRamp.h" #include "Simplify.h" using std::vector; @@ -17,15 +18,25 @@ class FlattenRamps : public IRMutator { Expr visit(const Ramp *op) override { if (op->base.type().is_vector()) { - Expr base = mutate(op->base); - Expr stride = mutate(op->stride); - std::vector<Expr> ramp_elems; - ramp_elems.reserve(op->lanes); - for (int ix = 0; ix < op->lanes; ix++) { - ramp_elems.push_back(base + ix * stride); - } + if (MultiRamp mr; + is_multiramp(op, Scope<Expr>::empty_scope(), &mr)) { + // Flatten multiramps entirely in one go, instead of recursively + // with the general case below, so that we get one big concat + // instead of a concat-of-concats. The innermost dimension is + // left as a Ramp.
+ mr.mutate(this); + return Shuffle::make_concat(mr.flatten()); + } else { + Expr base = mutate(op->base); + Expr stride = mutate(op->stride); + std::vector<Expr> ramp_elems; + ramp_elems.reserve(op->lanes); + for (int ix = 0; ix < op->lanes; ix++) { + ramp_elems.push_back(base + ix * stride); + } - return Shuffle::make_concat(ramp_elems); + return Shuffle::make_concat(ramp_elems); + } } return IRMutator::visit(op); @@ -40,6 +51,18 @@ class FlattenRamps : public IRMutator { return IRMutator::visit(op); } + // Return the sub-vector of `v` corresponding to the n-th sub-ramp of a + // flattened multiramp of width `inner_lanes`. Scalar broadcasts get + // rebroadcast to `inner_lanes`; everything else is a slice. + static Expr slice_per_inner_ramp(const Expr &v, int n, int inner_lanes) { + if (const Broadcast *b = v.as<Broadcast>()) { + if (b->value.type().is_scalar()) { + return Broadcast::make(b->value, inner_lanes); + } + } + return Shuffle::make_slice(v, n * inner_lanes, 1, inner_lanes); + } + Expr visit(const Load *op) override { // Convert a load of a bounded span of indices into a shuffle // of a dense or strided load if possible. @@ -124,6 +147,92 @@ class FlattenRamps : public IRMutator { } } } + + // If the index is a multiramp, emit a concat of per-inner-ramp + // dense/strided loads. This handles the case where the bounded-span + // conversion above didn't fire (e.g. symbolic strides, or the + // access range is too large for a single dense load). Doing the + // concat directly (rather than letting the Ramp visitor flatten + // the nested ramp into a big scalar-index load + a subtracted + // broadcast offset) makes the structure visible to downstream + // shuffle simplification rules.
+ if (op->type.is_vector()) { + if (MultiRamp mr; + is_multiramp(op->index, Scope<Expr>::empty_scope(), &mr) && + mr.dimensions() >= 2) { + + Expr predicate = mutate(op->predicate); + mr.mutate(this); + std::vector<Expr> sub_indices = mr.flatten(); + int inner_lanes = mr.lanes[0]; + Type elem_type = op->type.with_lanes(inner_lanes); + std::vector<Expr> loads; + loads.reserve(sub_indices.size()); + for (size_t n = 0; n < sub_indices.size(); n++) { + Expr p = slice_per_inner_ramp(predicate, (int)n, inner_lanes); + ModulusRemainder align = (n == 0) ? op->alignment : ModulusRemainder{}; + loads.push_back(Load::make(elem_type, op->name, sub_indices[n], + op->image, op->param, p, align)); + } + return Shuffle::make_concat(loads); + } + } + + return IRMutator::visit(op); + } + + Stmt visit(const Store *op) override { + // If the index is a multiramp, unroll into a sequence of per-inner-ramp + // stores, for the same reason as the Load visitor above. + if (op->index.type().is_vector()) { + if (MultiRamp mr; + is_multiramp(op->index, Scope<Expr>::empty_scope(), &mr) && + mr.dimensions() >= 2) { + + Expr predicate = mutate(op->predicate); + Expr value = mutate(op->value); + mr.mutate(this); + std::vector<Expr> sub_indices = mr.flatten(); + int inner_lanes = mr.lanes[0]; + + // The value and/or predicate may load from the buffer being + // stored to, so they must be fully evaluated before any of + // the stores run. Hoist non-trivial ones into LetStmts that + // wrap the block of stores. Skip the hoisting if the expr + // is already a Variable or a constant.
+ auto needs_hoist = [](const Expr &e) { + return !is_const(e) && !e.as<Variable>(); + }; + std::string value_name, predicate_name; + Expr value_ref = value, predicate_ref = predicate; + if (needs_hoist(value)) { + value_name = unique_name('t'); + value_ref = Variable::make(value.type(), value_name); + } + if (needs_hoist(predicate)) { + predicate_name = unique_name('t'); + predicate_ref = Variable::make(predicate.type(), predicate_name); + } + + std::vector<Stmt> stores; + stores.reserve(sub_indices.size()); + for (size_t n = 0; n < sub_indices.size(); n++) { + Expr p = slice_per_inner_ramp(predicate_ref, (int)n, inner_lanes); + Expr v = slice_per_inner_ramp(value_ref, (int)n, inner_lanes); + ModulusRemainder align = (n == 0) ? op->alignment : ModulusRemainder{}; + stores.push_back(Store::make(op->name, v, sub_indices[n], + op->param, p, align)); + } + Stmt result = Block::make(stores); + if (!predicate_name.empty()) { + result = LetStmt::make(predicate_name, predicate, result); + } + if (!value_name.empty()) { + result = LetStmt::make(value_name, value, result); + } + return result; + } + } return IRMutator::visit(op); } }; diff --git a/src/MultiRamp.cpp b/src/MultiRamp.cpp new file mode 100644 index 000000000000..6ecbb4d3fcd0 --- /dev/null +++ b/src/MultiRamp.cpp @@ -0,0 +1,670 @@ +#include "MultiRamp.h" + +#include "IR.h" +#include "IREquality.h" +#include "IRMutator.h" +#include "IROperator.h" +#include "IRVisitor.h" +#include "ModulusRemainder.h" +#include "Simplify.h" +#include "Util.h" + +#include <algorithm> +#include <numeric> + +namespace Halide { +namespace Internal { + +namespace { + +// Collapse adjacent dims whose strides align: if the outer stride equals +// inner_stride · inner_lanes, the two dims describe a single flat dim and +// can be merged. Keeps the output tidy; doesn't affect what values the +// MultiRamp represents.
+void collapse_adjacent_dims(MultiRamp *m) { + for (size_t i = 1; i < m->lanes.size();) { + Expr want_outer = simplify(m->strides[i - 1] * m->lanes[i - 1]); + if (equal(m->strides[i], want_outer)) { + m->lanes[i - 1] *= m->lanes[i]; + m->strides.erase(m->strides.begin() + i); + m->lanes.erase(m->lanes.begin() + i); + } else { + i++; + } + } +} + +} // namespace + +// Multiramps with compatible lanes form a vector space. Here is scalar multiplication. +void MultiRamp::mul(const Expr &e) { + internal_assert(e.type().is_scalar()); + base *= e; + for (Expr &s : strides) { + s *= e; + } +} + +// And here is vector addition. Returns false when the two shapes have no +// common refinement (the sum is not a multiramp). Adding multiramps with +// different total lane counts is a caller error and triggers an assertion. +bool MultiRamp::add(const MultiRamp &other) { + // We walk through both ramps' dimensions innermost-to-outermost, consuming + // gcd(a_lanes, b_lanes) of lanes at a time. When a dimension is only + // partially consumed, the remaining part of that dimension corresponds to + // an "outer" sub-dim in the refined shape and its stride must be scaled + // by the factor just consumed. + internal_assert(total_lanes() == other.total_lanes()) + << "MultiRamp::add: total lane counts must match (" << total_lanes() + << " vs " << other.total_lanes() << ")"; + if (lanes.empty()) { + // Both are 0-dim scalars. + base = simplify(base + other.base); + return true; + } + MultiRamp result; + result.base = simplify(base + other.base); + size_t ai = 0, bi = 0; + int a_lanes = lanes[0], b_lanes = other.lanes[0]; + Expr a_stride = strides[0], b_stride = other.strides[0]; + while (true) { + int next_lanes = gcd(a_lanes, b_lanes); + if (next_lanes == 1) { + // The two next lanes are coprime, e.g: + // [0, 1, 2, 100, 101, 102] + [0, 1, 100, 101, 200, 201] + // which has no common refinement. 
+ return false; + } + result.strides.emplace_back(simplify(a_stride + b_stride)); + result.lanes.push_back(next_lanes); + a_lanes /= next_lanes; + b_lanes /= next_lanes; + bool a_done = false, b_done = false; + if (a_lanes == 1) { + ai++; + if (ai >= lanes.size()) { + a_done = true; + } else { + a_lanes = lanes[ai]; + a_stride = strides[ai]; + } + } else { + // Remaining portion of current A-dim has a scaled stride. + a_stride = simplify(a_stride * next_lanes); + } + if (b_lanes == 1) { + bi++; + if (bi >= other.lanes.size()) { + b_done = true; + } else { + b_lanes = other.lanes[bi]; + b_stride = other.strides[bi]; + } + } else { + b_stride = simplify(b_stride * next_lanes); + } + if (a_done && b_done) { + collapse_adjacent_dims(&result); + *this = std::move(result); + return true; + } + // The up-front lane-count check ensures both sides always exhaust + // together, so neither side should be done here. + } +} + +namespace { + +// Divide (or mod) a MultiRamp by a positive integer k. Returns a new +// MultiRamp, or false if the quotient/remainder isn't itself a multiramp. +// Shared core of div_by and mod_by. +// +// Precondition: the base is a known multiple of k. Otherwise we return false. +// +// Mental model +// ------------ +// Picture the integers laid out in buckets of size k: [0, k), [k, 2k), .... +// Dividing by k asks "which bucket?", modding by k asks "where inside the +// bucket?". The base sits at the left edge of some bucket. We want every +// lane of the result to remain an affine function of the multi-index — i.e. +// a multiramp. Whether that's possible depends on how the input dims move +// the lanes around relative to those buckets. +// +// Two kinds of input dim +// ---------------------- +// "Pure-carry" dim: stride s is itself a multiple of k. Every step crosses +// a whole number of buckets, so the output stride for this dim is just s/k. +// These are boring in a good way. +// +// "Flex" dim: stride s isn't a multiple of k. 
Write s = k·q + r with +// r = s mod k in [0, k). A step advances the bucket index by q and shifts +// the position-inside-the-bucket by r. If every lane along this dim still +// lives in the same bucket, the output stride is q and the intra-bucket +// wiggle washes out under /k. The danger is that the position eventually +// exceeds k-1 — at which point the floor jumps and the result isn't a +// multiramp. +// +// Worked example +// -------------- +// base 0, stride 2, lanes 6, k = 4. Values 0, 2, 4, 6, 8, 10. +// +// Treat it as one flat 6-lane dim and it's doomed: the positions inside +// the bucket would be 0, 2, 4, 6, ... — already past k-1 = 3 at lane 2. +// +// But we can reshape the 6 lanes as (inner 2 × outer 3). The inner stride +// stays 2, and the outer stride becomes 2·2 = 4 — a whole bucket. Now the +// outer dim is pure-carry, and the inner dim only shows positions 0 and 2, +// safely inside [0, 4). The result is base 0, strides [0, 1], lanes [2, 3], +// which expands to 0, 0, 1, 1, 2, 2. That matches the per-lane division. +// +// The budget +// ---------- +// Because the base is a bucket boundary, every lane starts at position 0. +// At the far corner of the iteration box each flex dim contributes r·(n-1) +// to the position, and the positions have to stay ≤ k-1 everywhere. So the +// flex dims share a single budget of k-1; each one spends r·(n-1) of it. +// If they all fit, we're done. +// +// Joint fit: base 0, strides [2, 3], lanes [2, 2], k = 6. +// Input values: 0, 2, 3, 5 (all in bucket [0, 6)) +// / 6: 0, 0, 0, 0 (a multiramp with strides [0, 0]) +// Inner spends 2·(2-1) = 2 of the budget, outer spends 3·(2-1) = 3. Total +// 5 = k-1, just fitting. +// +// Joint failure: base 0, strides [2, 5], lanes [2, 2], k = 6. +// Input values: 0, 2, 5, 7 (7 is in the next bucket) +// / 6: 0, 0, 0, 1 (not a multiramp of any shape) +// Inner spends 2, outer spends 5. Each alone would fit (≤ 5), but +// together they want 7 > 5. Return false. 
+// +// The split trick +// --------------- +// When a single dim's r·(n-1) blows the budget by itself, here's the +// escape. Let p = k / gcd(k, r) — the smallest number of stride-s steps +// that reach a bucket boundary (since p·s ≡ p·r (mod k), and we want that +// to be 0). We re-view the dim of lanes n as (inner p × outer n/p) with +// strides (s, s·p). The outer stride s·p is a whole number of buckets, so +// the outer dim is pure-carry. Only the inner still spends budget, and +// only r·(p-1) of it. If even that doesn't fit, we give up. +// +// Algorithm +// --------- +// Walk input dims innermost-first, with budget = k-1. For each dim we only +// need to know r = s mod k (not s itself) — so a symbolic stride is fine as +// long as we can pin down its residue modulo k. If we can't, fail. For the +// first case that applies, emit its output; if none, fail. +// +// (a) r = 0 (pure carry) → emit (s/k, n). +// (b) r·(n-1) ≤ budget → emit (s/k, n); +// budget -= r·(n-1). +// (c) p = k/gcd(k,r); 1 < p < n, +// p divides n, r·(p-1) ≤ budget → emit inner (s/k, p) and +// outer (s·p/k, n/p); +// budget -= r·(p-1). +// (d) otherwise → return false. +// +// Output base is base/k for div, 0 for mod. For mod, emit r in place of +// s/k and 0 in place of s·p/k; the shape is the same. +// +// Finally, collapse any adjacent output dims where the outer stride is +// inner_stride · inner_lanes — e.g. pure-carry dims with matching strides +// from two consecutive inputs, or a split's outer half lining up with the +// next input's contribution. This just keeps the output tidy; it doesn't +// affect which inputs we accept. +// +// Rejection examples +// ------------------ +// base 0, stride 1, lanes 5, / 2: +// Input values: 0, 1, 2, 3, 4 +// / 2: 0, 0, 1, 1, 2 (not a multiramp of any shape) +// p = 2 doesn't divide 5, and the flat dim would spill immediately +// (r·(n-1) = 4 > 1 = budget). Return false. 
+// +// base 3, stride 2, lanes 2, / 4: +// Input values: 3, 5 +// / 4: 0, 1 (does happen to be a multiramp, but +// our algorithm requires an aligned +// base and skips this case) +// Return false before even looking at the dims. +bool div_or_mod_impl(MultiRamp *self, const Expr &k_expr, bool is_div) { + auto ck = as_const_int(k_expr); + if (!ck || *ck <= 0) { + return false; + } + int64_t k = *ck; + Type t = self->base.type(); + + // Aligned-base assumption: require base to be a known multiple of k. + int64_t b_mod = 0; + if (!reduce_expr_modulo(self->base, k, &b_mod) || b_mod != 0) { + return false; + } + + MultiRamp result; + result.base = is_div ? simplify(self->base / (int)k) : make_zero(t); + + // Residual budget: how much room is left inside the single k-bucket + // starting at the base. Starts at k-1 and shrinks as each non-pure-carry + // dim spends r·(lanes-1) of it. + int64_t budget = k - 1; + + for (size_t j = 0; j < self->strides.size(); j++) { + const Expr &s = self->strides[j]; + int n = self->lanes[j]; + + // Everything below only needs s mod k, never s itself. So it's fine + // for s to be symbolic, as long as we can pin down its residue. + int64_t r = 0; + if (!reduce_expr_modulo(s, k, &r)) { + return false; + } + + // Case (a): pure carry. + if (r == 0) { + result.strides.push_back(is_div ? simplify(s / (int)k) : make_zero(t)); + result.lanes.push_back(n); + continue; + } + + // Case (b): whole dim fits in the remaining budget. Note that (b) + // and (c) below are mutually exclusive — if the whole dim fits, n + // is necessarily ≤ p, which means case (c) couldn't apply anyway. + // So their order here doesn't matter for which inputs we accept. + if (r * (n - 1) <= budget) { + result.strides.push_back(is_div ? simplify(s / (int)k) : make_const(t, r)); + result.lanes.push_back(n); + budget -= r * (n - 1); + continue; + } + + // Case (c): split into (inner = p, outer = n/p). 
The smallest p with + // p·s ≡ 0 (mod k) only depends on r, since p·s ≡ p·r (mod k). So + // p = k / gcd(k, r). + int64_t p = k / gcd(k, r); + + // For r ∈ (0, k), gcd(k, r) ≤ r < k, so p ≥ 2. + internal_assert(p > 1); + + if (p >= (int64_t)n) { + // The smallest split that would work is >= the number of lanes + // we have in this dimension. + return false; + } + + if (n % p) { + // p must divide n to split n by p. Any larger + // split size would also be a multiple of p, so + // if p does not divide n, no valid split size + // divides n. + return false; + } + + if (r * (p - 1) > budget) { + // We ran out of budget. + return false; + } + + // Inner half: residual fits after shrinking to size p. + result.strides.push_back(is_div ? simplify(s / (int)k) : make_const(t, r)); + result.lanes.push_back((int)p); + budget -= r * (p - 1); + + // Outer half: s·p is a multiple of k by construction, so this divides + // exactly (though Halide's simplifier may or may not fold it). + result.strides.push_back(is_div ? simplify((s * (int)p) / (int)k) : make_zero(t)); + result.lanes.push_back((int)(n / p)); + } + + collapse_adjacent_dims(&result); + *self = std::move(result); + return true; +} + +} // namespace + +bool MultiRamp::div(const Expr &k) { + return div_or_mod_impl(this, k, /*is_div=*/true); +} + +bool MultiRamp::mod(const Expr &k) { + return div_or_mod_impl(this, k, /*is_div=*/false); +} + +namespace { +std::optional<Expr> unbroadcast(const Expr &e) { + if (e.type().is_scalar()) { + return e; + } else if (const Broadcast *b = e.as<Broadcast>()) { + return unbroadcast(b->value); + } else { + return std::nullopt; + } +} + +// Internal is_multiramp. May leave *result in a partial state on failure; +// the public is_multiramp below protects callers by only committing on +// success. Recursive calls go through the public wrapper, so each branch +// here can assume *result is either freshly initialized (on entry) or +// freshly filled by a successful recursion.
+bool is_multiramp_impl(const Expr &e, const Scope<Expr> &scope, MultiRamp *result) { + Type elem_t = e.type().element_of(); + if (e.type().is_scalar()) { + result->base = e; + return true; + } else if (const Variable *v = e.as<Variable>()) { + if (const Expr *e = scope.find(v->name)) { + return is_multiramp(*e, scope, result); + } + } else if (const Broadcast *b = e.as<Broadcast>(); + b && is_multiramp(b->value, scope, result)) { + result->strides.push_back(make_zero(elem_t)); + result->lanes.push_back(b->lanes); + return true; + } else if (const Ramp *r = e.as<Ramp>()) { + if (auto stride = unbroadcast(r->stride)) { + if (is_multiramp(r->base, scope, result)) { + result->strides.push_back(*stride); + result->lanes.push_back(r->lanes); + return true; + } + } + } else if (const Add *a = e.as<Add>()) { + MultiRamp rb; + if (is_multiramp(a->a, scope, result) && + is_multiramp(a->b, scope, &rb)) { + return result->add(rb); + } + } else if (const Sub *s = e.as<Sub>()) { + // Convert to Add to reuse logic above. + MultiRamp rb; + if (is_multiramp(s->a, scope, result) && + is_multiramp(s->b, scope, &rb)) { + rb.mul(make_const(elem_t, -1)); + return result->add(rb); + } + } else if (const Mul *m = e.as<Mul>()) { + // Try each side as the scalar factor. The public wrapper's + // untouched-on-failure guarantee means a failed first attempt + // leaves *result clean for the second. + if (auto b = unbroadcast(m->b); + b && is_multiramp(m->a, scope, result)) { + result->mul(*b); + return true; + } + if (auto a = unbroadcast(m->a); + a && is_multiramp(m->b, scope, result)) { + result->mul(*a); + return true; + } + } else if (const Div *d = e.as<Div>
()) { + if (auto denom = unbroadcast(d->b)) { + if (is_multiramp(d->a, scope, result)) { + return result->div(*denom); + } + } + } else if (const Mod *m = e.as<Mod>()) { + if (auto denom = unbroadcast(m->b)) { + if (is_multiramp(m->a, scope, result)) { + return result->mod(*denom); + } + } + } + return false; +} +} // namespace + +bool is_multiramp(const Expr &e, const Scope<Expr> &scope, MultiRamp *result) { + // Wrap the impl so that callers get a clean "untouched on failure" + // contract regardless of how the impl leaves its scratch space. + MultiRamp tmp; + if (is_multiramp_impl(e, scope, &tmp)) { + *result = std::move(tmp); + return true; + } + return false; +} + +Expr MultiRamp::operator==(const MultiRamp &other) const { + // Construct the difference, and check if all strides are zero. + MultiRamp diff = other; + diff.mul(-1); + if (!diff.add(*this)) { + return const_false(); + } + Expr c = diff.base == 0; + for (const Expr &s : diff.strides) { + c = c && s == 0; + } + return simplify(c); +} + +void MultiRamp::slice(int d, const Expr &v) { + internal_assert(d >= 0 && d < (int)strides.size()); + internal_assert(v.type() == base.type()); + base += v * strides[d]; + strides.erase(strides.begin() + d); + lanes.erase(lanes.begin() + d); + collapse_adjacent_dims(this); +} + +Expr MultiRamp::alias_free() const { + // A sufficient condition: there exists an ordering of dims such that + // each stride's absolute value is strictly greater than the sum of the + // spans of all earlier dims, where span(k) = |strides[k]| * (lanes[k] − + // 1). Under such an ordering the lanes enumerate distinct offsets in an + // interval-tree fashion. In principle we'd only need to test the + // ordering with increasing |strides|, but symbolic strides leave the + // ordering unknown, so we try all permutations and OR the conditions. + // (The permutation count is small in practice — one dim per nested + // loop.)
This ignores base, which is fine for uniqueness within the + // ramp (base is a uniform offset). + + if (lanes.empty()) { + return const_true(); + } + int d = (int)lanes.size(); + std::vector<int> perm(d); + std::iota(perm.begin(), perm.end(), 0); + Expr result = const_false(); + do { + Expr cond = (strides[perm[0]] != 0); + Expr accum = make_zero(base.type()); // running sum of |s_k|*(n_k − 1) + for (int j = 0; j < d; j++) { + Expr s = strides[perm[j]]; + Expr abs_s = abs(s); + if (j > 0) { + cond = cond && (abs_s > accum); + } + accum = accum + abs_s * (lanes[perm[j]] - 1); + } + result = result || cond; + } while (std::next_permutation(perm.begin(), perm.end())); + return simplify(result); +} + +std::vector<MultiRamp::PeeledDim> MultiRamp::alias_free_slice() { + // Greedy: starting from an empty MultiRamp (same base), try adding dims + // one by one from innermost to outermost. Any dim that would break the + // alias-free condition is peeled off instead. Stride-zero dims always + // break alias-freedom (except as the single dim of a 1-dim ramp, which + // is a scalar), so we fast-path them to skip the can_prove call.
+ std::vector<PeeledDim> peeled; + MultiRamp remaining; + remaining.base = base; + for (int i = 0; i < dimensions(); i++) { + bool must_peel = is_const_zero(strides[i]) && !remaining.lanes.empty(); + if (!must_peel) { + remaining.strides.push_back(strides[i]); + remaining.lanes.push_back(lanes[i]); + if (can_prove(remaining.alias_free())) { + continue; + } + remaining.strides.pop_back(); + remaining.lanes.pop_back(); + } + peeled.push_back(PeeledDim{strides[i], lanes[i], i}); + } + *this = std::move(remaining); + return peeled; +} + +int MultiRamp::rotate_stride_one_innermost() { + int k = -1; + for (int i = 0; i < dimensions(); i++) { + if (is_const_one(strides[i])) { + k = i; + break; + } + } + if (k <= 0) { + return 0; + } + int A = 1; + for (int i = 0; i < k; i++) { + A *= lanes[i]; + } + int d = dimensions(); + std::vector<int> perm(d); + std::iota(perm.begin(), perm.end(), 0); + std::rotate(perm.begin(), perm.begin() + k, perm.end()); + reorder(perm); + return A; +} + +int MultiRamp::dimensions() const { + return (int)strides.size(); +} + +int MultiRamp::total_lanes() const { + int prod = 1; + for (int l : lanes) { + prod *= l; + } + return prod; +} + +Expr MultiRamp::to_expr() const { + Expr e = base; + for (int i = 0; i < dimensions(); i++) { + if (is_const_zero(strides[i])) { + e = Broadcast::make(e, lanes[i]); + } else if (e.type().is_scalar()) { + e = Ramp::make(e, strides[i], lanes[i]); + } else { + e = Ramp::make(e, Broadcast::make(strides[i], e.type().lanes()), lanes[i]); + } + } + return e; +} + +void MultiRamp::reorder(const std::vector<int> &perm) { + int d = dimensions(); + internal_assert((int)perm.size() == d) << "perm size mismatch\n"; + std::vector<Expr> new_strides; + std::vector<int> new_lanes; + new_strides.reserve(d); + new_lanes.reserve(d); + for (int k = 0; k < d; k++) { + internal_assert(perm[k] >= 0 && perm[k] < d) << "perm out of range\n"; + new_strides.push_back(std::move(strides[perm[k]])); + new_lanes.push_back(lanes[perm[k]]); + } + strides =
std::move(new_strides); + lanes = std::move(new_lanes); +} + +void MultiRamp::accept(IRVisitor *visitor) const { + base.accept(visitor); + for (const Expr &s : strides) { + s.accept(visitor); + } +} + +void MultiRamp::mutate(IRMutator *mutator) { + base = (*mutator)(base); + for (Expr &s : strides) { + s = (*mutator)(s); + } +} + +std::vector<int> MultiRamp::shuffle_from_permuted(const std::vector<int> &perm) const { + // For each output lane n (in *this's lane order), we want the shuffle to + // pull from the input (permuted) vector's lane that represents the same + // multi-index. Decompose n into multi-index (i_0, ..., i_{d-1}) using + // this->lanes (innermost first); the matching multi-index in the permuted + // MultiRamp is (j_k) with j_k = i_{perm[k]}, flattened with + // this->lanes[perm[k]] as its innermost lane counts. + int d = dimensions(); + internal_assert((int)perm.size() == d); + std::vector<int> indices; + indices.reserve(total_lanes()); + for_each_coordinate(lanes, [&](const std::vector<int> &coord) { + int permuted_flat = 0, M = 1; + for (int k = 0; k < d; k++) { + permuted_flat += coord[perm[k]] * M; + M *= lanes[perm[k]]; + } + indices.push_back(permuted_flat); + }); + return indices; +} + +std::vector<Expr> MultiRamp::flatten() const { + int d = dimensions(); + if (d == 0) { + return {base}; + } + int inner_lanes = lanes[0]; + std::vector<int> outer_sizes(lanes.begin() + 1, lanes.end()); + std::vector<Expr> result; + result.reserve(total_lanes() / inner_lanes); + for_each_coordinate(outer_sizes, [&](const std::vector<int> &coord) { + Expr offset_base = base; + for (size_t k = 0; k < coord.size(); k++) { + offset_base = offset_base + coord[k] * strides[k + 1]; + } + result.push_back(Ramp::make(offset_base, strides[0], inner_lanes)); + }); + return result; +} + +std::vector<int> MultiRamp::shuffle_from_slice(const std::vector<int> &dims, + const std::vector<int> &pos) const { + // For each output lane n (in the sliced MultiRamp's lane order), we want + // the shuffle to pull from the lane of *this
whose multi-index matches + // n in the free (non-sliced) dims, and has the specified values in the + // sliced dims. + internal_assert(dims.size() == pos.size()); + int d = dimensions(); + std::vector<int> fixed(d, -1); + for (size_t j = 0; j < dims.size(); j++) { + int dd = dims[j]; + internal_assert(dd >= 0 && dd < d); + internal_assert(pos[j] >= 0 && pos[j] < lanes[dd]); + internal_assert(fixed[dd] == -1) << "duplicate dim in shuffle_from_slice\n"; + fixed[dd] = pos[j]; + } + // Sizes of the free (non-fixed) dims, in the same order as they + // appear in the full dim list. + std::vector<int> free_sizes; + for (int k = 0; k < d; k++) { + if (fixed[k] == -1) { + free_sizes.push_back(lanes[k]); + } + } + std::vector<int> indices; + for_each_coordinate(free_sizes, [&](const std::vector<int> &free_coord) { + int flat = 0, M = 1; + size_t fj = 0; + for (int k = 0; k < d; k++) { + int ik = (fixed[k] != -1) ? fixed[k] : free_coord[fj++]; + flat += ik * M; + M *= lanes[k]; + } + indices.push_back(flat); + }); + return indices; +} + +} // namespace Internal +} // namespace Halide diff --git a/src/MultiRamp.h b/src/MultiRamp.h new file mode 100644 index 000000000000..35daef494b8e --- /dev/null +++ b/src/MultiRamp.h @@ -0,0 +1,196 @@ +#ifndef HALIDE_MULTI_RAMP_H +#define HALIDE_MULTI_RAMP_H + +/** \file + * Defines the MultiRamp IR helper — a multi-dimensional ramp recognised and + * manipulated by the vectorization pass and its callers. + */ + +#include "Expr.h" +#include "Scope.h" + +namespace Halide { +namespace Internal { + +class IRMutator; +class IRVisitor; + +/** A multi-dimensional ramp. I.e. a ramp of ramps of ramps of ramps... + * + * Represents a vector whose lanes are produced by + * + * base + i_0 * strides[0] + i_1 * strides[1] + ... + * + * where i_k iterates over [0, lanes[k]) and the innermost dim is dim 0. + * For example, with base = 0, strides = [1, 100], lanes = [2, 3] the lane + * sequence is [0, 1, 100, 101, 200, 201].
+ * + * Invariants: + * - base is scalar; every entry of strides is scalar and has the same + * type as base. + * - strides.size() == lanes.size() (this value is dimensions()). + * - Each lanes[k] >= 1. An entry of 1 is legal but methods that flatten + * (reorder, add, etc.) will remove it. + * - dimensions() == 0 represents a scalar (total_lanes() == 1); + * to_expr() yields `base` unchanged, and the other methods handle + * this case trivially. + * + * mul, add, div, mod mutate in place. mul always succeeds; add/div/mod + * return false when the result isn't expressible as a multiramp (leaving + * *this undefined). */ +struct MultiRamp { + Expr base; + std::vector<Expr> strides; + std::vector<int> lanes; + + /** Multiply by a scalar. Always a multiramp. */ + void mul(const Expr &e); + + /** Add another MultiRamp elementwise. Returns false if the result isn't + * a multiramp (which happens when the two input shapes have no common + * refinement). */ + bool add(const MultiRamp &other); + + /** Floor-divide by a scalar. The main use case is recognizing + * downsampling reductions like `f(r/4) += g(r)` as multiramps, so that + * the vectorize pass can handle them as within-vector reductions. + * + * Returns false if the denominator isn't a positive integer constant, + * or if the quotient isn't a multiramp. The result may have one more + * dim than the input (a single split may be introduced per input dim, + * e.g. ramp(0,2,6)/4 requires splitting a dim of extent 6 into 2x3 + * because the quotient changes mid-dim). See div_or_mod_impl in + * MultiRamp.cpp for the derivation. O(d). */ + bool div(const Expr &k); + + /** Euclidean mod by a scalar. Returns false if the denominator isn't a + * positive integer constant, or if the remainder isn't a multiramp. + * Same shape transformations as div. Rare cases where the remainder is + * a multiramp but the quotient isn't are not recognized here. O(d).
*/ + bool mod(const Expr &k); + + /** Construct an Expr which gives whether one multiramp is equal to + * another in every lane. Assumes the total lane count matches. Returns + * a symbolic Expr (not a bool) matching operator== semantics on + * Exprs. */ + Expr operator==(const MultiRamp &other) const; + + /** Remove dim `d`, adding `v * strides[d]` to base. Pass v = 0 for the + * first slice along that dim, or a Variable to get a parameterized + * slice. */ + void slice(int d, const Expr &v); + + /** Construct an Expr that is a *sufficient* condition for the lanes to + * all be unique — i.e. if it evaluates to true the lanes don't alias, + * but if it evaluates to false the lanes may or may not alias. The + * implication only goes one way. The condition checked is: there + * exists an ordering of the dims along which each stride is greater + * than the sum of the spans of earlier dims (span of dim k = + * |strides[k]| * (lanes[k] - 1)). We OR that condition over all dim + * orderings (base is ignored since it's a uniform offset and doesn't + * affect within-ramp uniqueness). */ + Expr alias_free() const; + + /** Information about one peeled dim, produced by alias_free_slice. + * `dim` is the dim's position in the *pre-call* MultiRamp. */ + struct PeeledDim { + Expr stride; + int lanes; + int dim; + }; + + /** Build an alias-free slice of *this by walking the dims innermost to + * outermost and keeping each one only if the slice is still alias-free + * after adding it. The kept dims are a *subset* of the original dims + * (preserving their relative order), not necessarily a prefix — e.g. a + * middle dim may be dropped while both inner and outer dims are kept. + * Replace *this with the resulting slice, and return a description of + * the dims that weren't kept (innermost first). Always succeeds; *this + * may be reduced to a 0-dim scalar if no dim can be kept. 
The omitted + * dims' contributions are NOT folded into base — callers usually want + * to add back `var * omitted.stride` per omitted dim before using + * *this. + * + * All dimensions with stride zero or purely symbolic strides will be + * peeled, and some constant stride dimensions may also be peeled if + * they produce values that overlap other dimensions E.g. if there are + * two nested ramps that both have stride 1 the outer one will be + * peeled. */ + std::vector alias_free_slice(); + + /** No-op returning 0 if the stride-1 dim is already innermost (or + * there isn't one). Otherwise rotate the dims so the stride-1 dim + * moves to position 0, with the previously-inner dims moved to the + * outermost end, and return A = the product of those previously-inner + * dims' lane counts. After this call, + * Shuffle::make_transpose(new_to_expr(), total_lanes / A) reconstructs + * a vector in the old lane order from one in the new order. */ + int rotate_stride_one_innermost(); + + /** The dimensionality. May be lower than you expected, because this + * gets flattened when possible by the operations above. */ + int dimensions() const; + + /** The product of all the lane counts. */ + int total_lanes() const; + + /** The multiramp as a nested series of ramps. */ + Expr to_expr() const; + + /** Flatten the multiramp into a vector of 1D Ramps — one per outer + * multi-index, each with inner_lanes = lanes[0] and stride = + * strides[0]. Ramps are returned in this MultiRamp's lane order: + * concat'ing the returned Ramps reproduces the full lane sequence. The + * caller is responsible for any prior mutation/simplification of `base` + * and `strides` (the Ramps reference them directly). */ + std::vector flatten() const; + + /** Reorder this MultiRamp's dimensions in place. perm[k] is the index + * into this's current dims that becomes the k-th dim after reordering + * (innermost first, as always). perm must be a permutation of + * {0, ..., dimensions()-1}. E.g. 
with dims [s0, s1, s2] and + * perm = [2, 0, 1], after reorder the new dims are [s2, s0, s1]. */ + void reorder(const std::vector &perm); + + /** Pass an IRVisitor through all Exprs referenced (base and each + * stride). Note that base and strides are scalar but may nonetheless + * contain nested vector reductions. */ + void accept(IRVisitor *visitor) const; + + /** Pass an IRMutator through all Exprs referenced, replacing base and + * strides with the mutated results. Note that base and strides are + * scalar but may nonetheless contain nested vector reductions. */ + void mutate(IRMutator *mutator); + + /** Given a permutation `perm`, return shuffle indices `idx` such that + * if `p` is a copy of `*this` with `reorder(perm)` applied, then + * + * Shuffle::make({p.to_expr()}, idx) + * + * produces the same vector of lane values as `this->to_expr()`. I.e. + * given a vector in the permuted lane order, the returned indices put + * it back into this MultiRamp's original lane order. */ + std::vector shuffle_from_permuted(const std::vector &perm) const; + + /** Return shuffle indices `idx` such that + * + * Shuffle::make({this->to_expr()}, idx) + * + * produces the same vector of lane values as a copy of *this with + * slice(dims[j], pos[j]) applied for each j. Since slicing reduces + * the lane count, the shuffle selects the subset of *this's lanes + * whose coordinate along dim `dims[j]` equals `pos[j]` for all j. + * `dims` and `pos` must have the same length and `dims` must list + * distinct dim indices. */ + std::vector shuffle_from_slice(const std::vector &dims, + const std::vector &pos) const; +}; + +/** Check if a vector Expr is a multiramp, and assign to result if so. + * Returns false and leaves *result untouched if not. 
*/ +bool is_multiramp(const Expr &e, const Scope &scope, MultiRamp *result); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index c55648cc4c2d..e82eb82acdf5 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -1,5 +1,10 @@ #include "Simplify_Internal.h" +#include +#include + +#include "MultiRamp.h" + using std::string; namespace Halide { @@ -375,6 +380,7 @@ Expr Simplify::visit(const Load *op, ExprInfo *info) { base_info.alignment = ModulusRemainder::intersect(base_info.alignment, index_info.alignment); ModulusRemainder align = ModulusRemainder::intersect(op->alignment, base_info.alignment); + int A; const Broadcast *b_index = index.as(); const Shuffle *s_index = index.as(); @@ -408,23 +414,34 @@ Expr Simplify::visit(const Load *op, ExprInfo *info) { loaded_vecs.emplace_back(std::move(load)); } return Shuffle::make(loaded_vecs, s_index->indices); - } else if (const Ramp *inner_ramp = r_index ? r_index->base.as() : nullptr; - inner_ramp && - inner_ramp->base.type().is_scalar() && - !is_const_one(inner_ramp->stride) && - is_const_one(r_index->stride)) { - // If it's a nested ramp and the outer ramp has stride 1, swap the - // nesting order of the ramps to make dense loads and transpose the - // resulting vector instead. - Expr transposed_index = - Ramp::make(Ramp::make(inner_ramp->base, make_one(inner_ramp->base.type()), r_index->lanes), - Broadcast::make(inner_ramp->stride, r_index->lanes), inner_ramp->lanes); - Expr transposed_predicate = (predicate.as() ? 
- predicate : // common case optimization - Shuffle::make_transpose(predicate, inner_ramp->lanes)); - Expr transposed_load = - Load::make(op->type, op->name, transposed_index, op->image, op->param, transposed_predicate, align); - return mutate(Shuffle::make_transpose(transposed_load, r_index->lanes), info); + } else if (MultiRamp mr; + index.type().is_vector() && + // Don't do expensive analysis in the common case of a load of a ramp of scalars. + !(r_index && r_index->base.type().is_scalar()) && + // It's a multi-dimensional multiramp. + is_multiramp(index, Scope::empty_scope(), &mr) && + mr.dimensions() > 1 && + // The innermost stride isn't already one. + !is_const_one(mr.strides[0]) && + // We can successfully rotate a stride one dimension innermost. + (A = mr.rotate_stride_one_innermost()) > 0) { + // Rotating the stride one dimension innermost made the load dense, but + // we must now transpose the predicate to match the transposed index, + // and inverse-transpose the loaded value to restore the original lane + // ordering. + Expr permuted_predicate; + const Broadcast *b_pred = predicate.as(); + if (b_pred && b_pred->value.type().is_scalar()) { + permuted_predicate = predicate; + } else { + permuted_predicate = Shuffle::make_transpose(predicate, A); + } + + Expr permuted_load = + Load::make(op->type, op->name, mr.to_expr(), op->image, + op->param, permuted_predicate, align); + int B = op->type.lanes() / A; + return mutate(Shuffle::make_transpose(permuted_load, B), info); } else if (predicate.same_as(op->predicate) && index.same_as(op->index) && align == op->alignment) { return op; } else { diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index 2a614ac81744..9f339dd26614 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -295,8 +295,11 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { vector new_concat_vectors; for (const auto &v : inner_shuffle->vectors) { // Check if current concat vector overlaps with slice. 
- if ((concat_index >= slice_min && concat_index <= slice_max) || - ((concat_index + v.type().lanes() - 1) >= slice_min && (concat_index + v.type().lanes() - 1) <= slice_max)) { + // Standard interval overlap: [a, b] and [c, d] overlap + // iff a <= d && c <= b. + int v_start = concat_index; + int v_end = concat_index + v.type().lanes() - 1; + if (v_start <= slice_max && slice_min <= v_end) { if (new_slice_start < 0) { new_slice_start = concat_index; } diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index d8c0a9a7baea..c8f9d3effc6d 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -1,7 +1,11 @@ #include "Simplify_Internal.h" +#include +#include + #include "ExprUsesVar.h" #include "IRMutator.h" +#include "MultiRamp.h" #include "Substitute.h" namespace Halide { @@ -357,6 +361,7 @@ Stmt Simplify::visit(const Store *op) { } ModulusRemainder align = ModulusRemainder::intersect(op->alignment, base_info.alignment); + int A; if (is_const_zero(predicate)) { // Predicate is always false @@ -388,26 +393,38 @@ Stmt Simplify::visit(const Store *op) { Stmt s = Block::make(stores); s = LetStmt::make(var_name, value, s); return mutate(s); - } else if (const Ramp *inner_ramp = r_index ? r_index->base.as() : nullptr; - inner_ramp && - inner_ramp->base.type().is_scalar() && - !is_const_one(inner_ramp->stride) && - is_const_one(r_index->stride)) { - // If it's a nested ramp and the outer ramp has stride 1, swap the - // nesting order of the ramps to make dense stores and transpose the - // index and value instead. Later in lowering after flattening the - // nested ramps it will turn into a concat of dense ramps and hit the - // case above. - Expr transposed_index = - Ramp::make(Ramp::make(inner_ramp->base, make_one(inner_ramp->base.type()), r_index->lanes), - Broadcast::make(inner_ramp->stride, r_index->lanes), inner_ramp->lanes); - Expr transposed_value = Shuffle::make_transpose(value, inner_ramp->lanes); - Expr transposed_predicate = (predicate.as() ? 
- predicate : // common case optimization - Shuffle::make_transpose(predicate, inner_ramp->lanes)); - return mutate(Store::make(op->name, transposed_value, transposed_index, - op->param, transposed_predicate, align)); - } else if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index) && align == op->alignment) { + } else if (MultiRamp mr; + index.type().is_vector() && + // Don't do expensive analysis in the common case of a load of a ramp of scalars. + !(r_index && r_index->base.type().is_scalar()) && + // It's a multi-dimensional multiramp + is_multiramp(index, Scope::empty_scope(), &mr) && + mr.dimensions() > 1 && + // The innermost stride isn't already one + !is_const_one(mr.strides[0]) && + // We can successfully rotate a stride one dimension innermost + (A = mr.rotate_stride_one_innermost()) > 0) { + + // Rotating the stride one dimension innermost in the index made the + // resulting store dense. Now permute the value and predicate to match + // the new lane order using a single make_transpose. Later in lowering, + // after flattening the nested ramps, this turns into a concat of dense + // ramps and hits the case above. 
+ + Expr permuted_value = Shuffle::make_transpose(value, A); + Expr permuted_predicate; + const Broadcast *b_pred = predicate.as(); + if (b_pred && b_pred->value.type().is_scalar()) { + permuted_predicate = predicate; + } else { + permuted_predicate = Shuffle::make_transpose(predicate, A); + } + return mutate(Store::make(op->name, permuted_value, mr.to_expr(), + op->param, permuted_predicate, align)); + } else if (predicate.same_as(op->predicate) && + value.same_as(op->value) && + index.same_as(op->index) && + align == op->alignment) { return op; } else { return Store::make(op->name, value, index, op->param, predicate, align); diff --git a/src/Util.h b/src/Util.h index f29e0ad9b6f0..01d89eb3b6ff 100644 --- a/src/Util.h +++ b/src/Util.h @@ -183,6 +183,26 @@ bool ends_with(const std::string &str, const std::string &suffix); * this function to return the same string without any copies being made. */ std::string replace_all(std::string str, const std::string &find, const std::string &replace); +/** Invoke `f(coord)` for each integer coordinate in the box + * `[0, sizes[0]) x [0, sizes[1]) x ...`, in lex order with the first + * axis varying fastest. `coord` is a `const std::vector &` of the + * same length as `sizes`. The empty-sizes case invokes `f` once with an + * empty coord (a 0-dim box has one point). */ +template +void for_each_coordinate(const std::vector &sizes, F &&f) { + std::vector coord(sizes.size(), 0); + while (true) { + f(coord); + size_t k = 0; + while (k < sizes.size() && ++coord[k] == sizes[k]) { + coord[k++] = 0; + } + if (k == sizes.size()) { + return; + } + } +} + /** Split the source string using 'delim' as the divider. 
*/ std::vector split_string(const std::string &source, const std::string &delim); diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index ebfd63e860bd..f97099287003 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -1,4 +1,5 @@ #include +#include #include #include "CSE.h" @@ -9,10 +10,12 @@ #include "IRMutator.h" #include "IROperator.h" #include "IRPrinter.h" +#include "MultiRamp.h" #include "Scope.h" #include "Simplify.h" #include "Solve.h" #include "Substitute.h" +#include "Util.h" #include "VectorizeLoops.h" namespace Halide { @@ -41,7 +44,7 @@ const Broadcast *as_scalar_broadcast(const Expr &e) { } else { return nullptr; } -}; +} /** Find the exact scalar max and min lanes of a vector expression. Not * conservative like bounds_of_expr, but uses similar rules for some common node @@ -189,119 +192,6 @@ Interval bounds_of_lanes(const Expr &e) { Expr max_lane = VectorReduce::make(VectorReduce::Max, e, 1); return {min_lane, max_lane}; } -}; - -// A ramp with the lanes repeated inner_repetitions times, and then -// the whole vector repeated outer_repetitions times. -// E.g: <0 0 2 2 4 4 6 6 0 0 2 2 4 4 6 6>. 
-struct InterleavedRamp { - Expr base, stride; - int lanes, inner_repetitions, outer_repetitions; -}; - -bool equal_or_zero(int a, int b) { - return a == 0 || b == 0 || a == b; -} - -bool is_interleaved_ramp(const Expr &e, const Scope &scope, InterleavedRamp *result) { - if (const Ramp *r = e.as()) { - const Broadcast *b_base = r->base.as(); - const Broadcast *b_stride = r->stride.as(); - if (r->base.type().is_scalar()) { - result->base = r->base; - result->stride = r->stride; - result->lanes = r->lanes; - result->inner_repetitions = 1; - result->outer_repetitions = 1; - return true; - } else if (b_base && b_stride && b_base->lanes == b_stride->lanes) { - // Ramp of broadcast - result->base = b_base->value; - result->stride = b_stride->value; - result->lanes = r->lanes; - result->inner_repetitions = b_base->lanes; - result->outer_repetitions = 1; - return true; - } - } else if (const Broadcast *b = e.as()) { - if (b->value.type().is_scalar()) { - result->base = b->value; - result->stride = 0; - result->lanes = b->lanes; - result->inner_repetitions = 0; - result->outer_repetitions = 0; - return true; - } else if (is_interleaved_ramp(b->value, scope, result)) { - // Broadcast of interleaved ramp - result->outer_repetitions *= b->lanes; - return true; - } - } else if (const Add *add = e.as()) { - InterleavedRamp ra; - if (is_interleaved_ramp(add->a, scope, &ra) && - is_interleaved_ramp(add->b, scope, result) && - equal_or_zero(ra.inner_repetitions, result->inner_repetitions) && - equal_or_zero(ra.outer_repetitions, result->outer_repetitions)) { - result->base = simplify(result->base + ra.base); - result->stride = simplify(result->stride + ra.stride); - result->inner_repetitions = std::max(result->inner_repetitions, ra.inner_repetitions); - result->outer_repetitions = std::max(result->outer_repetitions, ra.outer_repetitions); - return true; - } - } else if (const Sub *sub = e.as()) { - InterleavedRamp ra; - if (is_interleaved_ramp(sub->a, scope, &ra) && - 
is_interleaved_ramp(sub->b, scope, result) && - equal_or_zero(ra.inner_repetitions, result->inner_repetitions) && - equal_or_zero(ra.outer_repetitions, result->outer_repetitions)) { - result->base = simplify(ra.base - result->base); - result->stride = simplify(ra.stride - result->stride); - result->inner_repetitions = std::max(result->inner_repetitions, ra.inner_repetitions); - result->outer_repetitions = std::max(result->outer_repetitions, ra.outer_repetitions); - return true; - } - } else if (const Mul *mul = e.as()) { - std::optional b; - if (is_interleaved_ramp(mul->a, scope, result) && - (b = as_const_int(mul->b))) { - result->base = simplify(result->base * (int)(*b)); - result->stride = simplify(result->stride * (int)(*b)); - return true; - } - } else if (const Div *div = e.as
()) { - std::optional b; - if (is_interleaved_ramp(div->a, scope, result) && - (b = as_const_int(div->b)) && - is_const_one(result->stride) && - (result->inner_repetitions == 1 || - result->inner_repetitions == 0) && - can_prove((result->base % (int)(*b)) == 0)) { - // TODO: Generalize this. Currently only matches - // ramp(base*b, 1, lanes) / b - // broadcast(base * b, lanes) / b - result->base = simplify(result->base / (int)(*b)); - result->inner_repetitions *= (int)(*b); - return true; - } - } else if (const Mod *mod = e.as()) { - std::optional b; - if (is_interleaved_ramp(mod->a, scope, result) && - (b = as_const_int(mod->b)) && - (result->outer_repetitions == 1 || - result->outer_repetitions == 0) && - can_prove(((int)(*b) % result->stride) == 0)) { - // ramp(base, 2, lanes) % 8 - result->base = simplify(result->base % (int)(*b)); - result->stride = simplify(result->stride % (int)(*b)); - result->outer_repetitions *= (int)(*b); - return true; - } - } else if (const Variable *var = e.as()) { - if (const Expr *e = scope.find(var->name)) { - return is_interleaved_ramp(*e, scope, result); - } - } - return false; } // Allocations inside vectorized loops grow an additional inner @@ -749,41 +639,8 @@ class VectorSubs : public IRMutator { op->call_type, op->func, op->value_index, op->image, op->param); } - Expr visit(const Let *op) override { - // Vectorize the let value and check to see if it was vectorized by - // this mutator. The type of the expression might already be vector - // width. - Expr mutated_value = simplify(mutate(op->value)); - bool was_vectorized = (!op->value.type().is_vector() && - mutated_value.type().is_vector()); - - // If the value was vectorized by this mutator, add a new name to - // the scope for the vectorized value expression. 
- string vectorized_name; - if (was_vectorized) { - vectorized_name = get_widened_var_name(op->name); - scope.push(op->name, op->value); - vector_scope.push(vectorized_name, mutated_value); - } - - Expr mutated_body = mutate(op->body); - - InterleavedRamp ir; - if (is_interleaved_ramp(mutated_value, vector_scope, &ir)) { - return substitute(vectorized_name, mutated_value, mutated_body); - } else if (mutated_value.same_as(op->value) && - mutated_body.same_as(op->body)) { - return op; - } else if (was_vectorized) { - scope.pop(op->name); - vector_scope.pop(vectorized_name); - return Let::make(vectorized_name, mutated_value, mutated_body); - } else { - return Let::make(op->name, mutated_value, mutated_body); - } - } - - Stmt visit(const LetStmt *op) override { + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { Expr mutated_value = simplify(mutate(op->value)); string vectorized_name = op->name; @@ -791,33 +648,42 @@ class VectorSubs : public IRMutator { bool was_vectorized = (!op->value.type().is_vector() && mutated_value.type().is_vector()); + decltype(op->body) mutated_body; if (was_vectorized) { vectorized_name = get_widened_var_name(op->name); - scope.push(op->name, op->value); - vector_scope.push(vectorized_name, mutated_value); // Also keep track of the original let, in case inner code scalarizes. 
containing_lets.emplace_back(op->name, op->value); - } - Stmt mutated_body = mutate(op->body); + ScopedBinding + bind(scope, op->name, op->value), + bind_vec(vector_scope, vectorized_name, mutated_value); - if (was_vectorized) { + mutated_body = mutate(op->body); containing_lets.pop_back(); - scope.pop(op->name); - vector_scope.pop(vectorized_name); + } else { + mutated_body = mutate(op->body); } - InterleavedRamp ir; - if (is_interleaved_ramp(mutated_value, vector_scope, &ir)) { + MultiRamp m; + if (mutated_value.type().is_vector() && + is_multiramp(mutated_value, vector_scope, &m)) { return substitute(vectorized_name, mutated_value, mutated_body); } else if (mutated_value.same_as(op->value) && mutated_body.same_as(op->body)) { return op; } else { - return LetStmt::make(vectorized_name, mutated_value, mutated_body); + return LetOrLetStmt::make(vectorized_name, mutated_value, mutated_body); } } + Expr visit(const Let *op) override { + return visit_let(op); + } + + Stmt visit(const LetStmt *op) override { + return visit_let(op); + } + Stmt visit(const Provide *op) override { internal_error << "Vectorizing a Provide node is unimplemented. " << "Vectorization usually runs after storage flattening.\n"; @@ -1017,8 +883,8 @@ class VectorSubs : public IRMutator { string vectorized_name = get_widened_var_name(var); Expr vectorized_value = vector_scope.get(vectorized_name); vector_scope.pop(vectorized_name); - InterleavedRamp ir; - if (is_interleaved_ramp(vectorized_value, vector_scope, &ir)) { + MultiRamp m; + if (is_multiramp(vectorized_value, vector_scope, &m)) { body = substitute(vectorized_name, vectorized_value, body); } else { body = LetStmt::make(vectorized_name, vectorized_value, body); @@ -1096,7 +962,21 @@ class VectorSubs : public IRMutator { } Stmt visit(const Atomic *op) override { - // Recognize a few special cases that we can handle as within-vector reduction trees. + // Recognize a few special cases that we can handle as within-vector + // reduction trees. 
+ + // We may partially succeed, in which case we'll have (unrolled) loops + // to rewrap. + struct UnrolledLoop { + std::string name; + int extent; + // Index of this loop's dim in the pre-alias-peel MultiRamp. Used + // by the unroll below to construct a shuffle mask selecting the + // corresponding slice of the reduced value vector. + int dim; + }; + std::vector unrolled_loops; + do { if (!op->mutex_name.empty()) { // We can't vectorize over a mutex @@ -1203,19 +1083,15 @@ class VectorSubs : public IRMutator { Expr store_index = mutate(store->index); Expr load_index = mutate(load_a->index); - // The load and store indices must be the same interleaved - // ramp (or the same scalar, in the total reduction case). - InterleavedRamp store_ir, load_ir; + // The load and store indices must be the same multiramp + // (or the same scalar, in the total reduction case). + MultiRamp store_mr, load_mr; Expr test; if (store_index.type().is_scalar()) { test = simplify(load_index == store_index); - } else if (is_interleaved_ramp(store_index, vector_scope, &store_ir) && - is_interleaved_ramp(load_index, vector_scope, &load_ir) && - store_ir.inner_repetitions == load_ir.inner_repetitions && - store_ir.outer_repetitions == load_ir.outer_repetitions && - store_ir.lanes == load_ir.lanes) { - test = simplify(store_ir.base == load_ir.base && - store_ir.stride == load_ir.stride); + } else if (is_multiramp(store_index, vector_scope, &store_mr) && + is_multiramp(load_index, vector_scope, &load_mr)) { + test = store_mr == load_mr; } if (!test.defined()) { @@ -1250,25 +1126,109 @@ class VectorSubs : public IRMutator { }; int output_lanes = 1; + MultiRamp b_shape_mr; if (store_index.type().is_scalar()) { // The index doesn't depend on the value being // vectorized, so it's a total reduction. 
- b = VectorReduce::make(reduce_op, b, 1); } else { - output_lanes = store_index.type().lanes() / (store_ir.inner_repetitions * store_ir.outer_repetitions); + // The output lanes is >1, so there must be at least one + // multiramp dimension with non-zero stride. Dims that + // can't be part of an alias-free store fall into two + // kinds, both discovered by one call to alias_free_slice: + // + // - Stride-zero dims: lanes duplicate a value across the + // store, so we fold the duplicates with the reduction + // op. The innermost-in-original stride-zero dim (if + // any) becomes a VectorReduce; others need a + // reduction tree over slices of b. + // - Non-zero-stride aliasing dims (symbolic strides, or + // strides that overlap such that we can't prove + // uniqueness): different lanes of the store go to + // different addresses, so we unroll a containing loop + // and do a slice-per-iteration. + // + // TODO: the innermost-VectorReduce fast-path is keyed on + // "dim 0 of the original". We could move other stride-zero + // dims inward via a transpose and VectorReduce them too; + // might be better on some targets. + + // b's current lane layout. Starts matching the full + // store_mr (before any peel); updated as reductions and + // shuffles reshape it. We use this for the shuffle masks + // that slice b per unrolled iteration below. + b_shape_mr = store_mr; + + std::vector peeled = + store_mr.alias_free_slice(); + + // Snapshot the original strides so we can identify which + // original dims had stride zero after b_shape_mr gets + // reordered below. + std::vector orig_strides = b_shape_mr.strides; + + // Partition peels by handling strategy. + int inner_dup = 1; // >1 if a VectorReduce applies. + int outer_dup = 1; // >1 if a reduction tree applies. 
+ std::vector loop_peels; + for (const auto &p : peeled) { + if (is_const_zero(p.stride)) { + if (p.dim == 0) { + // Stride-zero peel at the innermost position: + // its duplicates are contiguous in b, so we + // can use VectorReduce directly. + inner_dup = p.lanes; + } else { + outer_dup *= p.lanes; + } + } else { + loop_peels.push_back(p); + } + } - store_index = Ramp::make(store_ir.base, store_ir.stride, output_lanes / store_ir.base.type().lanes()); - if (store_ir.inner_repetitions > 1) { - b = VectorReduce::make(reduce_op, b, output_lanes * store_ir.outer_repetitions); + if (inner_dup > 1) { + int new_lanes = b_shape_mr.total_lanes() / inner_dup; + b = VectorReduce::make(reduce_op, b, new_lanes); + b_shape_mr.slice(0, make_zero(b_shape_mr.base.type())); } - // Handle outer repetitions by unrolling the reduction - // over slices. - if (store_ir.outer_repetitions > 1) { - // First remove all powers of two with a binary reduction tree. - int reps = store_ir.outer_repetitions; + // If any non-innermost stride-zero dims need combining, + // shuffle b so their duplicates become contiguous, then + // reduce them with a tree over contiguous sub-vectors. + if (outer_dup > 1) { + // Reorder the remaining zero-stride dims outermost, + // keeping the rest in their relative order. + int d = b_shape_mr.dimensions(); + std::vector perm(d); + std::iota(perm.begin(), perm.end(), 0); + auto stride_not_zero = [&](int i) { + return !is_const_zero(b_shape_mr.strides[i]); + }; + auto mid = std::stable_partition(perm.begin(), perm.end(), stride_not_zero); + + int n_kept = mid - perm.begin(); + // shuffle_from_permuted gives us idx such that + // Shuffle(, idx) == . Here we have + // b in original lane order and want it in permuted + // order, so we invert that as a permutation. 
+ std::vector idx = b_shape_mr.shuffle_from_permuted(perm); + std::vector inverted(idx.size()); + for (size_t i = 0; i < idx.size(); i++) { + inverted[idx[i]] = (int)i; + } + b = Shuffle::make({b}, inverted); + b_shape_mr.reorder(perm); + + // An inner reduction is a VectorReduce node. An outer + // reduction is cutting the vector into contiguous pieces, + // and adding those pieces together. Now that all the + // remaining stride-0 dims are outermost, we can do that in + // a binary tree. We slice the vector in half and add the + // halves for as long as possible, and then slice up what's + // left into pieces and add the pieces. For big power-of-two + // reductions this produces log(n) IR nodes. + int reps = outer_dup; while (reps % 2 == 0) { int l = b.type().lanes() / 2; Expr b0 = Shuffle::make_slice(b, 0, 1, l); @@ -1276,17 +1236,45 @@ class VectorSubs : public IRMutator { b = binop(b0, b1); reps /= 2; } - - // Then reduce linearly over slices for the rest. if (reps > 1) { - Expr v = Shuffle::make_slice(b, 0, 1, output_lanes); + int chunk = b.type().lanes() / reps; + Expr v = Shuffle::make_slice(b, 0, 1, chunk); for (int i = 1; i < reps; i++) { - Expr slice = simplify(Shuffle::make_slice(b, i * output_lanes, 1, output_lanes)); + Expr slice = simplify(Shuffle::make_slice(b, i * chunk, 1, chunk)); v = binop(v, slice); } b = v; } + + // Drop the outer-zero peeled dims from b_shape_mr (they + // are the trailing dims after the reorder above). + b_shape_mr.strides.resize(n_kept); + b_shape_mr.lanes.resize(n_kept); + } + + // We still have peeled dims without zero stride to handle. + // Emit the unrolled containing loops for non-zero aliasing + // peels. Their shuffle indices below select the right slice of + // b per iteration. The loop.dim field is the dim's position in + // b_shape_mr's current layout: the count of earlier original + // dims that survived both the inner-dim reduction and the + // outer-zero drop. 
+ // orig dim 0 was removed if we VectorReduce'd it away. + const int dim_offset = inner_dup > 1 ? 1 : 0; + for (const auto &p : loop_peels) { + int pos = 0; + for (int i = dim_offset; i < p.dim; i++) { + if (!is_const_zero(orig_strides[i])) { + pos++; + } + } + std::string name = unique_name('t'); + unrolled_loops.emplace_back( + UnrolledLoop{name, p.lanes, pos}); + store_mr.base += Variable::make(Int(32), name) * p.stride; } + output_lanes = store_mr.total_lanes(); + store_index = store_mr.to_expr(); } Expr new_load = Load::make(load_a->type.with_lanes(output_lanes), @@ -1294,12 +1282,58 @@ class VectorSubs : public IRMutator { load_a->param, const_true(output_lanes), ModulusRemainder{}); - Expr lhs = cast(b.type(), new_load); - b = binop(lhs, b); - b = cast(new_load.type(), b); + Expr lhs = cast(b.type().with_lanes(output_lanes), new_load); - Stmt s = Store::make(store->name, b, store_index, store->param, - const_true(b.type().lanes()), store->alignment); + Stmt s; + if (unrolled_loops.empty()) { + b = binop(lhs, b); + b = cast(new_load.type(), b); + s = Store::make(store->name, b, store_index, store->param, + const_true(b.type().lanes()), store->alignment); + } else { + // Wrap any containing loops we still need (unrolled). We + // enumerate the cartesian product of loop iteration values + // directly, so that each store's b-slice can be computed + // from the full multi-index. 
+ std::string b_var_name = unique_name('b'); + Expr b_var = Variable::make(b.type().with_lanes(output_lanes), b_var_name); + Stmt store_template = + Store::make(store->name, cast(new_load.type(), binop(lhs, b_var)), + store_index, store->param, + const_true(output_lanes), ModulusRemainder{}); + std::string full_b_var_name = unique_name('b'); + Expr full_b_var = Variable::make(b.type(), full_b_var_name); + + std::vector peeled_dims, loop_extents; + peeled_dims.reserve(unrolled_loops.size()); + loop_extents.reserve(unrolled_loops.size()); + for (const auto &loop : unrolled_loops) { + peeled_dims.push_back(loop.dim); + loop_extents.push_back(loop.extent); + } + // Fully unroll the peeled dims into a flat Block of stores: + // we enumerate every multi-index v in the cartesian product of + // loop_extents and emit one store per iteration, substituting + // the loop variable with the corresponding constant. There is + // no runtime loop nest — UnrolledLoop describes the dims we + // peeled off of b, not loops that survive in the output. + std::vector block; + block.reserve(b_shape_mr.total_lanes() / output_lanes); + for_each_coordinate(loop_extents, [&](const std::vector &v) { + // v is the iteration multi-index (innermost first, + // matching the order in unrolled_loops). + Expr b_slice = Shuffle::make({full_b_var}, + b_shape_mr.shuffle_from_slice(peeled_dims, v)); + Stmt this_store = store_template; + for (size_t j = 0; j < unrolled_loops.size(); j++) { + this_store = substitute(unrolled_loops[j].name, v[j], this_store); + } + this_store = substitute(b_var_name, b_slice, this_store); + block.push_back(this_store); + }); + s = Block::make(block); + s = LetStmt::make(full_b_var_name, b, s); + } // We may still need the atomic node, if there was more // parallelism than just the vectorization. 
diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index efa0bddeb0e8..1df954dfb0ce 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -91,6 +91,7 @@ tests(GROUPS correctness dilate3x3.cpp div_by_zero.cpp div_round_to_zero.cpp + downsampling_reduce.cpp dynamic_allocation_in_gpu_kernel.cpp dynamic_reduction_bounds.cpp early_out.cpp @@ -230,6 +231,7 @@ tests(GROUPS correctness multi_way_select.cpp multipass_constraints.cpp multiple_outputs.cpp + multiramp.cpp mux.cpp narrow_predicates.cpp negative_split_factors.cpp @@ -324,6 +326,7 @@ tests(GROUPS correctness tracing_stack.cpp transitive_bounds.cpp transpose_idioms.cpp + transposed_vector_reduce.cpp trim_no_ops.cpp tuple_partial_update.cpp tuple_reduction.cpp diff --git a/test/correctness/downsampling_reduce.cpp b/test/correctness/downsampling_reduce.cpp new file mode 100644 index 000000000000..ed84345e3a4e --- /dev/null +++ b/test/correctness/downsampling_reduce.cpp @@ -0,0 +1,80 @@ +#include "Halide.h" + +using namespace Halide; +using namespace Halide::Internal; + +// Test that an atomic vectorized reduction with a downsampling write pattern +// (f(r/4) += g(r)) lowers to within-vector reductions rather than +// scalarizing. This exercises MultiRamp::div in the vectorize path. + +int main(int argc, char **argv) { + const int vec = 16; + const int factor = 4; + const int reduction_extent = vec; // r has `vec` lanes; f's output has vec/factor + + Func g{"g"}; + Var x{"x"}; + RDom r(0, reduction_extent); + + ImageParam input(Int(32), 1); + Buffer input_buf(reduction_extent); + input_buf.for_each_element([&](int i) { input_buf(i) = i * 3 + 7; }); + input.set(input_buf); + + // f(r/4) += g(r): four consecutive lanes of the reduction contribute to + // one output lane. Within one vector of r, the output multiramp has a + // stride-zero innermost dim of extent `factor` and a stride-1 outer dim + // of extent vec/factor. 
+ g(x) = 0; + g(r / factor) += input(r); + + Buffer correct = g.realize({reduction_extent / factor}); + + g.bound(x, 0, reduction_extent / factor) + .update() + .atomic() + .vectorize(r); + + // Check that the reduction over r was vectorized away: after vectorize, + // there should be no inner for-loop over r, and the lowered IR should + // contain a VectorReduce node. + int inner_for_loops = 0; + int vector_reduces = 0; + auto checker = LambdaMutator{ + [&](auto *self, const For *op) { + if (op->name.find("r") != std::string::npos) { + inner_for_loops++; + } + return self->visit_base(op); + }, + [&](auto *self, const VectorReduce *op) { + vector_reduces++; + return self->visit_base(op); + }}; + g.add_custom_lowering_pass(&checker, nullptr); + + Buffer out = g.realize({reduction_extent / factor}); + + for (int i = 0; i < reduction_extent / factor; i++) { + if (out(i) != correct(i)) { + printf("out(%d) = %d instead of %d\n", i, out(i), correct(i)); + return 1; + } + } + + if (inner_for_loops > 0) { + printf("Atomic vectorization of downsampling reduction failed: " + "lowered code contained %d for loop(s) over r\n", + inner_for_loops); + return 1; + } + + if (vector_reduces == 0) { + printf("Expected a VectorReduce node in the lowered IR, but " + "didn't find one\n"); + return 1; + } + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/interleave.cpp b/test/correctness/interleave.cpp index cbee263f5487..6b437c0651d2 100644 --- a/test/correctness/interleave.cpp +++ b/test/correctness/interleave.cpp @@ -16,7 +16,7 @@ class CountInterleaves : public IRVisitor { using IRVisitor::visit; void visit(const Shuffle *op) override { - if (op->is_interleave()) { + if (op->is_interleave() || op->is_transpose()) { result++; } IRVisitor::visit(op); diff --git a/test/correctness/multiramp.cpp b/test/correctness/multiramp.cpp new file mode 100644 index 000000000000..c7b8aae195db --- /dev/null +++ b/test/correctness/multiramp.cpp @@ -0,0 +1,640 @@ +#include 
"Halide.h" + +#include +#include +#include + +using namespace Halide; +using namespace Halide::Internal; + +namespace { + +int failures = 0; + +// Expand a MultiRamp (with const base and const strides) to a flat vector +// using the same innermost-fastest enumeration the IR uses. +std::vector expand(const MultiRamp &m) { + auto cb = as_const_int(simplify(m.base)); + internal_assert(cb) << "expand() only supports const bases, got " << m.base << "\n"; + int64_t b = *cb; + std::vector strides; + for (const Expr &s : m.strides) { + auto cs = as_const_int(simplify(s)); + internal_assert(cs) << "expand() only supports const strides, got " << s << "\n"; + strides.push_back(*cs); + } + int total = 1; + for (int n : m.lanes) { + total *= n; + } + std::vector result; + result.reserve(total); + for (int flat = 0; flat < total; flat++) { + int rem = flat; + int64_t v = b; + for (size_t i = 0; i < m.lanes.size(); i++) { + int idx = rem % m.lanes[i]; + rem /= m.lanes[i]; + v += strides[i] * idx; + } + result.push_back((int)v); + } + return result; +} + +void print_vec(const std::vector &v) { + printf("["); + for (size_t i = 0; i < v.size(); i++) { + printf("%s%d", i ? ", " : "", v[i]); + } + printf("]"); +} + +void check_seq(const std::vector &got, const std::vector &want, + const char *msg, int line) { + if (got != want) { + printf("FAIL at %d: %s\n got ", line, msg); + print_vec(got); + printf("\n want "); + print_vec(want); + printf("\n"); + failures++; + } +} + +#define CHECK(cond, msg) \ + do { \ + if (!(cond)) { \ + printf("FAIL at %d: %s\n", __LINE__, msg); \ + failures++; \ + } \ + } while (0) + +#define CHECK_SEQ_LIT(got, msg, ...) 
check_seq((got), std::vector{__VA_ARGS__}, (msg), __LINE__) +#define CHECK_SEQ(got, want, msg) check_seq((got), (want), (msg), __LINE__) + +// ---- MultiRamp::add ------------------------------------------------------ + +void check_add_refinable_shapes() { + // From the math problem: A = ramp(0,1,6) = [0,1,2,3,4,5], + // B = ramp(ramp(0,2,2),100,3) = [0,2,100,102,200,202], + // A + B = [0,3,102,105,204,207]. + // Shapes (6,) and (2,3) (innermost first) must refine to (2,3). + MultiRamp A{0, {1}, {6}}; + MultiRamp B{0, {2, 100}, {2, 3}}; + CHECK(A.add(B), "add with refinable shapes"); + CHECK_SEQ_LIT(expand(A), "refinable-shape add values", 0, 3, 102, 105, 204, 207); +} + +void check_add_same_shape() { + MultiRamp A{10, {3, 100}, {4, 2}}; + MultiRamp B{5, {-1, 50}, {4, 2}}; + auto a_seq = expand(A), b_seq = expand(B); + CHECK(A.add(B), "same-shape add"); + std::vector want(8); + for (size_t i = 0; i < a_seq.size(); i++) { + want[i] = a_seq[i] + b_seq[i]; + } + CHECK_SEQ(expand(A), want, "same-shape add values"); +} + +void check_add_incompatible_shapes() { + // Shapes with innermost sizes 3 vs 2 and outer sizes 2 vs 3 can't refine. + MultiRamp A{0, {1, 100}, {3, 2}}; + MultiRamp B{0, {1, 100}, {2, 3}}; + CHECK(!A.add(B), "incompatible shapes rejected"); +} + +void check_add_cancels_to_zero() { + // 2·A + (-2)·A should simplify to a single zero-stride dim (one flat dim + // of the total lane count). 
+ MultiRamp A{7, {3, 100}, {4, 2}}; + MultiRamp B = A; + A.mul(2); + B.mul(-2); + CHECK(A.add(B), "add of cancelling multiramps"); + CHECK(A.lanes.size() == 1, "cancelled add should collapse to 1 dim"); + if (A.lanes.size() == 1) { + CHECK(A.lanes[0] == 8, "cancelled add lanes = 8"); + auto s = as_const_int(simplify(A.strides[0])); + CHECK(s && *s == 0, "cancelled add stride = 0"); + auto b = as_const_int(simplify(A.base)); + CHECK(b && *b == 0, "cancelled add base = 0"); + } +} + +void check_add_scaled_outer() { + // Regression test for the stride-scaling bug: adding a 1D ramp of length 6 + // to a 2D ramp with shape (2,3) must scale the 1D's stride by 2 when + // producing the outer dim of the result. + // A = ramp(0,1,6) -> [0,1,2,3,4,5] + // B = ramp(ramp(0,0,2),100,3) = broadcast(0,2) then + ramp-of-100s + // -> [0,0,100,100,200,200] + // A+B = [0,1,102,103,204,205] + MultiRamp A{0, {1}, {6}}; + MultiRamp B{0, {0, 100}, {2, 3}}; + CHECK(A.add(B), "scaled-outer add"); + CHECK_SEQ_LIT(expand(A), "scaled-outer values", 0, 1, 102, 103, 204, 205); +} + +// ---- MultiRamp::div ----------------------------------------------------- + +void check_div_pure_carry_const() { + MultiRamp A{8, {4, 12}, {2, 3}}; + auto a_seq = expand(A); + CHECK(A.div(4), "pure-carry div (const k)"); + std::vector want(a_seq.size()); + for (size_t i = 0; i < a_seq.size(); i++) { + want[i] = a_seq[i] / 4; + } + CHECK_SEQ(expand(A), want, "pure-carry div values"); +} + +void check_div_symbolic_strides() { + // Symbolic base and strides, all provably multiples of the denominator — + // every dim is pure carry. + Var v("v"); + MultiRamp A{2 * v, {2 * v, 8 * v}, {4, 5}}; + CHECK(A.div(2), "pure-carry div with symbolic strides"); + if (A.strides.size() == 2) { + // Strides become (2*v/2, 8*v/2) = (v, 4*v). 
+ Expr want0 = simplify(A.strides[0] - v); + Expr want1 = simplify(A.strides[1] - 4 * v); + CHECK(is_const_zero(want0), "sym-stride div inner"); + CHECK(is_const_zero(want1), "sym-stride div outer"); + } +} + +void check_div_merges_adjacent_pure_carry() { + // Two pure-carry input dims whose output strides line up should collapse + // into a single output dim. + // Input values: 0, 4, 8, 12, 16, 20 (strides [4, 12], lanes [3, 2]). + // Divided by 4: 0, 1, 2, 3, 4, 5 — a flat 1D ramp of length 6. + MultiRamp A{0, {4, 12}, {3, 2}}; + CHECK(A.div(4), "div of two pure-carry dims"); + CHECK(A.lanes.size() == 1, "adjacent dims should merge into one"); + if (A.lanes.size() == 1) { + CHECK(A.lanes[0] == 6, "merged lane count"); + } + CHECK_SEQ_LIT(expand(A), "merged values", 0, 1, 2, 3, 4, 5); +} + +void check_div_with_split() { + // ramp(0,2,6) / 4 = [0,0,1,1,2,2], needs a split of dim 6 -> (2,3). + MultiRamp A{0, {2}, {6}}; + CHECK(A.div(4), "div with split"); + CHECK_SEQ_LIT(expand(A), "split div values", 0, 0, 1, 1, 2, 2); +} + +void check_div_split_with_symbolic_stride() { + // Non-constant stride whose residue mod k is still pinned down: stride + // is 4*v + 2, which is always ≡ 2 (mod 4). The split needs p=2, which + // divides 6. The budget check uses r = 2 only. + Var v("v"); + MultiRamp A{0, {4 * v + 2}, {6}}; + CHECK(A.div(4), "div split with symbolic stride"); + // Expected shape after split: lanes (2, 3); inner stride = (4v+2)/4 + // (symbolic), outer stride = (4v+2)*2/4 = 2v+1. + CHECK(A.lanes.size() == 2, "split produced two output dims"); + if (A.lanes.size() == 2) { + CHECK(A.lanes[0] == 2 && A.lanes[1] == 3, "split lanes (2, 3)"); + // Outer stride should simplify to 2v + 1. + Expr outer = simplify(A.strides[1]); + Expr want = simplify(2 * v + 1); + CHECK(equal(outer, want), "outer stride is 2v+1"); + } +} + +void check_div_rejects_non_multiramp() { + // ramp(0,1,5)/2 = [0,0,1,1,2], not a multiramp (5 has no usable factor). 
+ MultiRamp A{0, {1}, {5}}; + CHECK(!A.div(2), "should reject ramp(0,1,5)/2"); +} + +void check_div_rejects_unaligned_base() { + // ramp(2,2,6)/4 = [0,1,1,2,2,3] would be a multiramp, but our algorithm + // requires the base to be a known multiple of the denominator, and 2 is + // not a multiple of 4. + MultiRamp A{2, {2}, {6}}; + CHECK(!A.div(4), "should reject div when base isn't aligned"); +} + +void check_div_rejects_symbolic_denominator() { + // A symbolic (non-constant) denominator should fail cleanly. The code + // needs k as a known positive integer to reason about bucket sizes. + Var k("k"); + MultiRamp A{0, {1}, {4}}; + CHECK(!A.div(k), "should reject div with symbolic denominator"); + CHECK(!A.mod(k), "should reject mod with symbolic denominator"); +} + +// ---- MultiRamp::mod ----------------------------------------------------- + +void check_mod_basic() { + MultiRamp A{0, {1}, {6}}; + CHECK(A.mod(2), "mod basic"); + CHECK_SEQ_LIT(expand(A), "mod basic values", 0, 1, 0, 1, 0, 1); +} + +void check_mod_with_split() { + MultiRamp A{0, {2}, {6}}; + CHECK(A.mod(4), "mod with split"); + CHECK_SEQ_LIT(expand(A), "mod split values", 0, 2, 0, 2, 0, 2); +} + +void check_mod_symbolic_strides() { + // Symbolic base and strides, all provably multiples of the denominator: + // mod result is entirely zero. + Var v("v"); + MultiRamp A{2 * v, {6 * v, 10 * v}, {3, 2}}; + CHECK(A.mod(2), "mod pure-carry symbolic strides"); + for (const Expr &s : A.strides) { + auto c = as_const_int(simplify(s)); + CHECK(c && *c == 0, "sym-stride mod stride = 0"); + } + auto b = as_const_int(simplify(A.base)); + CHECK(b && *b == 0, "sym-stride mod base = 0"); +} + +void check_mod_rejects_non_multiramp() { + // ramp(0,1,5)%2 = [0,1,0,1,0], not a multiramp. 
+ MultiRamp A{0, {1}, {5}}; + CHECK(!A.mod(2), "should reject ramp(0,1,5)%2"); +} + +// ---- End-to-end is_multiramp tests -------------------------------------- + +void check_recognize_1d_ramp() { + Expr e = Ramp::make(Expr(0), Expr(2), 4); + Scope scope; + MultiRamp m; + CHECK(is_multiramp(e, scope, &m), "recognize 1D ramp"); + if (m.lanes.size() == 1) { + CHECK(m.lanes[0] == 4, "1D lanes"); + } +} + +void check_recognize_nested_ramp() { + // ramp(ramp(0,1,2), broadcast(100,2), 3) -> strides [1,100], lanes [2,3]. + Expr inner = Ramp::make(Expr(0), Expr(1), 2); + Expr e = Ramp::make(inner, Broadcast::make(Expr(100), 2), 3); + Scope scope; + MultiRamp m; + CHECK(is_multiramp(e, scope, &m), "recognize nested ramp"); + if (m.lanes.size() == 2) { + CHECK(m.lanes[0] == 2 && m.lanes[1] == 3, "nested ramp lanes"); + } +} + +void check_recognize_add() { + Expr a = Ramp::make(Expr(0), Expr(1), 6); + Expr inner = Ramp::make(Expr(0), Expr(2), 2); + Expr b = Ramp::make(inner, Broadcast::make(Expr(100), 2), 3); + Expr sum = Add::make(a, b); + Scope scope; + MultiRamp m; + CHECK(is_multiramp(sum, scope, &m), "recognize add of two multiramps"); +} + +void check_recognize_div_const() { + Expr e = Div::make(Ramp::make(Expr(0), Expr(2), 6), + Broadcast::make(Expr(4), 6)); + Scope scope; + MultiRamp m; + CHECK(is_multiramp(e, scope, &m), "recognize ramp/const"); + CHECK_SEQ_LIT(expand(m), "recognized div values", 0, 0, 1, 1, 2, 2); +} + +void check_recognize_mod_const() { + Expr e = Mod::make(Ramp::make(Expr(0), Expr(1), 6), + Broadcast::make(Expr(2), 6)); + Scope scope; + MultiRamp m; + CHECK(is_multiramp(e, scope, &m), "recognize ramp%const"); + CHECK_SEQ_LIT(expand(m), "recognized mod values", 0, 1, 0, 1, 0, 1); +} + +void check_recognize_div_symbolic_strides() { + // (2*x) + ramp(0, 4, 4), divided by 2. Numerator has symbolic base, const + // strides that are multiples of 2. 
+ Var x("x"); + Expr num = Broadcast::make(2 * x, 4) + Ramp::make(Expr(0), Expr(4), 4); + Expr e = Div::make(num, Broadcast::make(Expr(2), 4)); + Scope scope; + MultiRamp m; + CHECK(is_multiramp(e, scope, &m), "recognize symbolic-strides div"); + if (m.strides.size() == 1) { + auto s = as_const_int(simplify(m.strides[0])); + CHECK(s && *s == 2, "symbolic-strides div stride = 2"); + } +} + +// ---- Reordering and shuffle_from_permuted ------------------------------- + +void check_reorder() { + // Swap the two dims of a 2D multiramp. + // base 0, strides [1, 10], lanes [2, 3]: 0, 1, 10, 11, 20, 21 + // reordered [1, 0] -> strides [10, 1], lanes [3, 2]: 0, 10, 20, 1, 11, 21 + MultiRamp A{0, {1, 10}, {2, 3}}; + MultiRamp R = A; + R.reorder({1, 0}); + CHECK(R.lanes.size() == 2, "reordered dims"); + if (R.lanes.size() == 2) { + CHECK(R.lanes[0] == 3 && R.lanes[1] == 2, "reordered lane counts"); + auto s0 = as_const_int(simplify(R.strides[0])); + auto s1 = as_const_int(simplify(R.strides[1])); + CHECK(s0 && *s0 == 10, "reordered stride 0"); + CHECK(s1 && *s1 == 1, "reordered stride 1"); + } + CHECK_SEQ_LIT(expand(R), "reordered values", 0, 10, 20, 1, 11, 21); +} + +void check_shuffle_from_permuted_2d() { + // A has 2 dims; perm = [1, 0] swaps them. The shuffle takes the + // permuted lane order back to the original lane order. + MultiRamp A{0, {1, 10}, {2, 3}}; + MultiRamp P = A; + P.reorder({1, 0}); + std::vector idx = A.shuffle_from_permuted({1, 0}); + // For each output lane n (A's order), idx[n] is the input lane in P's + // order that carries the same value. + auto a_seq = expand(A); // 0, 1, 10, 11, 20, 21 + auto p_seq = expand(P); // 0, 10, 20, 1, 11, 21 + CHECK(idx.size() == a_seq.size(), "shuffle indices size"); + for (size_t n = 0; n < a_seq.size(); n++) { + CHECK(p_seq[idx[n]] == a_seq[n], "shuffle restores original lane"); + } + // And as a vector: [0, 3, 1, 4, 2, 5]. 
+ std::vector want = {0, 3, 1, 4, 2, 5}; + CHECK(idx == want, "shuffle indices match expected"); +} + +void check_shuffle_from_permuted_identity() { + // perm = identity => indices = [0, 1, 2, ..., total_lanes-1]. + MultiRamp A{0, {1, 10, 100}, {2, 3, 4}}; + std::vector idx = A.shuffle_from_permuted({0, 1, 2}); + for (size_t n = 0; n < idx.size(); n++) { + CHECK((int)n == idx[n], "identity permutation indices"); + } +} + +void check_shuffle_from_permuted_3d() { + // 3D with cyclic permutation. Check by comparing expanded sequences. + // base 0, strides [1, 4, 20], lanes [2, 3, 2]. Values: + // i_0 + 4*i_1 + 20*i_2 for (i_0, i_1, i_2) in [2)x[3)x[2). + MultiRamp A{0, {1, 4, 20}, {2, 3, 2}}; + std::vector perm = {2, 0, 1}; // outermost becomes innermost + MultiRamp P = A; + P.reorder(perm); + std::vector idx = A.shuffle_from_permuted(perm); + auto a_seq = expand(A); + auto p_seq = expand(P); + CHECK(idx.size() == a_seq.size(), "3D shuffle size"); + for (size_t n = 0; n < a_seq.size(); n++) { + CHECK(p_seq[idx[n]] == a_seq[n], "3D shuffle restores original"); + } +} + +void check_shuffle_from_slice_2d() { + // A has 2 dims, lanes [2, 3]. Slice dim 1 at pos 2 should yield lanes + // [2]; the shuffle indices pick those lanes of A. + MultiRamp A{0, {1, 10}, {2, 3}}; + MultiRamp S = A; + S.slice(1, Expr(2)); + std::vector idx = A.shuffle_from_slice(std::vector{1}, std::vector{2}); + auto a_seq = expand(A); // 0, 1, 10, 11, 20, 21 + auto s_seq = expand(S); // 20, 21 + CHECK(idx.size() == s_seq.size(), "slice shuffle size"); + for (size_t n = 0; n < s_seq.size(); n++) { + CHECK(a_seq[idx[n]] == s_seq[n], "slice shuffle picks right lanes"); + } + std::vector want = {4, 5}; + CHECK(idx == want, "slice shuffle indices match expected"); +} + +void check_shuffle_from_slice_inner() { + // Slice the innermost dim. 
+ MultiRamp A{0, {1, 10}, {2, 3}}; + MultiRamp S = A; + S.slice(0, Expr(1)); + std::vector idx = A.shuffle_from_slice(std::vector{0}, std::vector{1}); + auto a_seq = expand(A); // 0, 1, 10, 11, 20, 21 + auto s_seq = expand(S); // 1, 11, 21 + CHECK(idx.size() == s_seq.size(), "inner slice shuffle size"); + for (size_t n = 0; n < s_seq.size(); n++) { + CHECK(a_seq[idx[n]] == s_seq[n], "inner slice picks right lanes"); + } + std::vector want = {1, 3, 5}; + CHECK(idx == want, "inner slice indices match expected"); +} + +void check_shuffle_from_slice_3d() { + // 3D: strides [1, 4, 20], lanes [2, 3, 2]. Slice middle dim at pos 1. + MultiRamp A{0, {1, 4, 20}, {2, 3, 2}}; + MultiRamp S = A; + S.slice(1, Expr(1)); + std::vector idx = A.shuffle_from_slice(std::vector{1}, std::vector{1}); + auto a_seq = expand(A); + auto s_seq = expand(S); + CHECK(idx.size() == s_seq.size(), "3D slice shuffle size"); + for (size_t n = 0; n < s_seq.size(); n++) { + CHECK(a_seq[idx[n]] == s_seq[n], "3D slice picks right lanes"); + } +} + +// ---- MultiRamp::mul ------------------------------------------------------ + +void check_mul_basic() { + MultiRamp A{3, {1, 10}, {2, 3}}; // 3, 4, 13, 14, 23, 24 + A.mul(5); + CHECK_SEQ_LIT(expand(A), "mul values", 15, 20, 65, 70, 115, 120); +} + +// ---- MultiRamp::operator== ---------------------------------------------- + +void check_equality_same() { + MultiRamp A{0, {1, 10}, {2, 3}}; + Expr e = simplify(A == A); + CHECK(is_const_one(e), "multiramp equals itself"); +} + +void check_equality_different() { + MultiRamp A{0, {1, 10}, {2, 3}}; + MultiRamp B{0, {1, 10}, {3, 2}}; // same total lanes, different shape + // A.to_expr() == [0,1,10,11,20,21], B.to_expr() = [0,1,2,10,11,12]; + // so they are not equal in every lane. 
+ Expr e = simplify(A == B); + CHECK(is_const_zero(e), "different multiramps compare false"); +} + +void check_equality_scalar() { + MultiRamp A{42, {}, {}}; + MultiRamp B{42, {}, {}}; + MultiRamp C{7, {}, {}}; + CHECK(is_const_one(simplify(A == B)), "scalar multiramp equality"); + CHECK(is_const_zero(simplify(A == C)), "scalar multiramp inequality"); +} + +// ---- MultiRamp::alias_free_slice ---------------------------------------- + +void check_alias_free_slice_all_unique() { + // All lanes of the returned slice should be unique. + MultiRamp A{5, {1, 16}, {4, 3}}; // clearly alias-free + auto peeled = A.alias_free_slice(); + CHECK(peeled.empty(), "fully alias-free: nothing peeled"); + auto seq = expand(A); + std::set unique(seq.begin(), seq.end()); + CHECK(unique.size() == seq.size(), "slice lanes are unique"); +} + +void check_alias_free_slice_peels_zero_stride() { + // Stride-zero inner dim must be peeled. + MultiRamp A{0, {0, 1}, {4, 5}}; + auto peeled = A.alias_free_slice(); + CHECK(peeled.size() == 1, "peeled the stride-zero dim"); + if (peeled.size() == 1) { + CHECK(peeled[0].dim == 0 && peeled[0].lanes == 4, + "peeled the right dim"); + CHECK(is_const_zero(peeled[0].stride), "peeled dim had stride zero"); + } + // Remaining is {base=0, strides=[1], lanes=[5]} — unique. + auto seq = expand(A); + std::set unique(seq.begin(), seq.end()); + CHECK(unique.size() == seq.size(), "remaining slice is unique"); +} + +void check_alias_free_slice_degenerate() { + // A 1-dim ramp with stride zero: only dim is a duplication. It should + // be peeled, leaving *this as a 0-dim scalar. 
+ MultiRamp A{7, {0}, {4}}; + auto peeled = A.alias_free_slice(); + CHECK(peeled.size() == 1, "peeled the only dim"); + CHECK(A.dimensions() == 0, "remaining is scalar"); + auto seq = expand(A); + CHECK(seq.size() == 1 && seq[0] == 7, "scalar lane is base"); +} + +// ---- MultiRamp::rotate_stride_one_innermost ----------------------------- + +void check_rotate_stride_one_innermost() { + // Stride-1 dim not innermost: rotating should produce a MultiRamp + // whose expand, when transposed with cols = total / A, matches the + // original expand. + MultiRamp A{0, {10, 1}, {3, 4}}; // [0,10,20,1,11,21,2,12,22,3,13,23] + auto orig = expand(A); + int a = A.rotate_stride_one_innermost(); + CHECK(a > 0, "rotated (stride-1 was not innermost)"); + auto rotated = expand(A); + // Per the header: make_transpose(rotated, total/a) recovers orig. + // make_transpose(v, cols): output[j*rows + i] = v[i*cols + j], with + // rows = v.size()/cols. + int cols = (int)rotated.size() / a; + int rows = a; + std::vector roundtrip(rotated.size()); + for (int j = 0; j < cols; j++) { + for (int i = 0; i < rows; i++) { + roundtrip[j * rows + i] = rotated[i * cols + j]; + } + } + CHECK_SEQ(roundtrip, orig, "rotate + transpose = identity"); +} + +void check_rotate_stride_one_innermost_noop() { + // Stride-1 already innermost: no-op. 
+ MultiRamp A{0, {1, 10}, {3, 4}}; + auto before = expand(A); + int a = A.rotate_stride_one_innermost(); + CHECK(a == 0, "no-op when stride-1 already innermost"); + CHECK_SEQ(expand(A), before, "unchanged"); +} + +// ---- is_multiramp round-trip -------------------------------------------- + +void check_roundtrip(const MultiRamp &mr, const char *msg) { + Expr e = mr.to_expr(); + MultiRamp parsed; + Scope scope; + if (e.type().is_vector()) { + CHECK(is_multiramp(e, scope, &parsed), msg); + if (parsed.dimensions() > 0 || mr.dimensions() > 0) { + auto got = expand(parsed); + auto want = expand(mr); + CHECK_SEQ(got, want, msg); + } + } +} + +void check_roundtrips() { + check_roundtrip(MultiRamp{0, {1}, {4}}, "1D ramp roundtrip"); + check_roundtrip(MultiRamp{7, {1, 10}, {2, 3}}, "2D ramp roundtrip"); + check_roundtrip(MultiRamp{0, {1, 10, 100}, {2, 3, 2}}, "3D ramp roundtrip"); + check_roundtrip(MultiRamp{0, {0, 1}, {4, 3}}, "stride-zero dim roundtrip"); +} + +void check_reject_non_multiramp_sum() { + // [0,1,2,100,101,102] + [0,2,100,102,200,202] = sum with shape conflict. 
+ Expr a_inner = Ramp::make(Expr(0), Expr(1), 3); + Expr a = Ramp::make(a_inner, Broadcast::make(Expr(100), 3), 2); + Expr b_inner = Ramp::make(Expr(0), Expr(2), 2); + Expr b = Ramp::make(b_inner, Broadcast::make(Expr(100), 2), 3); + Expr sum = Add::make(a, b); + Scope scope; + MultiRamp m; + CHECK(!is_multiramp(sum, scope, &m), "reject coprime-shape add"); +} + +} // namespace + +int main(int argc, char **argv) { + check_add_refinable_shapes(); + check_add_same_shape(); + check_add_incompatible_shapes(); + check_add_cancels_to_zero(); + check_add_scaled_outer(); + + check_div_pure_carry_const(); + check_div_symbolic_strides(); + check_div_merges_adjacent_pure_carry(); + check_div_with_split(); + check_div_split_with_symbolic_stride(); + check_div_rejects_non_multiramp(); + check_div_rejects_unaligned_base(); + check_div_rejects_symbolic_denominator(); + + check_mod_basic(); + check_mod_with_split(); + check_mod_symbolic_strides(); + check_mod_rejects_non_multiramp(); + + check_recognize_1d_ramp(); + check_recognize_nested_ramp(); + check_recognize_add(); + check_recognize_div_const(); + check_recognize_mod_const(); + check_recognize_div_symbolic_strides(); + check_reorder(); + check_shuffle_from_permuted_2d(); + check_shuffle_from_permuted_identity(); + check_shuffle_from_permuted_3d(); + check_shuffle_from_slice_2d(); + check_shuffle_from_slice_inner(); + check_shuffle_from_slice_3d(); + check_mul_basic(); + check_equality_same(); + check_equality_different(); + check_equality_scalar(); + check_alias_free_slice_all_unique(); + check_alias_free_slice_peels_zero_stride(); + check_alias_free_slice_degenerate(); + check_rotate_stride_one_innermost(); + check_rotate_stride_one_innermost_noop(); + check_roundtrips(); + check_reject_non_multiramp_sum(); + + if (failures) { + printf("%d failures\n", failures); + return 1; + } + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/transposed_vector_reduce.cpp b/test/correctness/transposed_vector_reduce.cpp 
new file mode 100644 index 000000000000..ae349c70719b --- /dev/null +++ b/test/correctness/transposed_vector_reduce.cpp @@ -0,0 +1,256 @@ +#include "Halide.h" + +#include +#include +#include + +using namespace Halide; +using namespace Halide::Internal; + +constexpr int all = -1, success = 0, bad_output = 1, failed_vectorization = 2; + +int test(int which_case = all) { + + constexpr int vec = 8; + + Func g{"g"}; + Var x{"x"}, y{"y"}, z{"z"}; + RDom r(0, vec); + + ImageParam input(Int(32), 3); + Buffer input_buf(vec, vec, vec); + input_buf.for_each_element([&](int x, int y, int z) { + input_buf(x, y, z) = x + y * 10 + z * 100; + }); + input.set(input_buf); + + for (int i = 0; i < 6; i++) { + for (int j = 0; j < 3; j++) { + int idx = i * 3 + j; + if (which_case == all || which_case == idx) { + switch (j) { + case 0: + g(x, y) += input(x, y, r); + break; + case 1: + g(x, y) += input(x, r, y); + break; + case 2: + g(x, y) += input(r, x, y); + break; + } + } + } + } + + std::vector orders[6] = + {{x, y, r}, + {x, r, y}, + {r, x, y}, + {y, x, r}, + {y, r, x}, + {r, y, x}}; + + Buffer correct = g.realize({vec, vec}); + + g.bound(x, 0, vec) + .bound(y, 0, vec) + .vectorize(x) + .vectorize(y); + + int u = 0; + for (int i = 0; i < 6; i++) { + for (int j = 0; j < 3; j++) { + int idx = i * 3 + j; + if (which_case == all || idx == which_case) { + g.update(u++) + .vectorize(x) + .vectorize(y) + .atomic() + .vectorize(r) + .reorder(orders[i]); + } + } + } + + // We need to know the stride on the output buffer is such that rows don't + // alias each other. That would be UB, but not UB that the vectorizer knows + // how to exploit. It's more interesting if the stride is not vec - it's a genuine 2D store. 
+ // g.output_buffer().dim(1).set_stride(vec + 7); + + int for_loops = 0, gathers = 0; + auto checker = LambdaMutator{ + [&](auto *self, const For *op) { + for_loops++; + return self->visit_base(op); + }, + [&](auto *self, const Load *op) { + const Ramp *r = op->index.as(); + gathers += !r || !is_const_one(r->stride); + return self->visit_base(op); + }}; + + g.add_custom_lowering_pass(&checker, nullptr); + + Buffer out = g.realize({vec, vec}); + + for (int y = 0; y < vec; y++) { + for (int x = 0; x < vec; x++) { + if (out(x, y) != correct(x, y)) { + printf("out(%d, %d) = %d instead of %d\n", x, y, out(x, y), correct(x, y)); + return bad_output; + } + } + } + + if (which_case == all && for_loops) { + printf("Atomic vectorization failed. Lowered code contained %d for loops\n", for_loops); + return failed_vectorization; + } + + if (which_case == all && gathers) { + printf("Atomic vectorization produced %d vector gathers\n", gathers); + return failed_vectorization; + } + + if (which_case != all) { + g.compile_to_lowered_stmt(std::string("test_") + std::to_string(which_case) + ".stmt", {input}, StmtOutputFormat::Text, Target{"host-no_asserts-no_runtime-no_bounds_query"}); + g.compile_to_assembly(std::string("test_") + std::to_string(which_case) + ".s", {input}, Target{"host-no_asserts-no_runtime-no_bounds_query"}); + } + + return success; +} + +// Generate a random quasi-affine expression in the given RVars: an affine +// combination of terms where each term is one of v, v/k, v%k, or recursively +// one of those of a nested term. All divisors are required to divide the +// corresponding RVar's extent, so the expression is representable as a +// multiramp of the RDom. Coefficients and constant terms are small ints. 
+struct RVarInfo { + RVar var; + int extent; +}; + +Expr random_term(std::mt19937 &rng, const RDom &rdom) { + const RVar &chosen = rdom[rng() % rdom.dimensions()]; + int extent = *as_const_int(chosen.extent()); + int op = (int)(rng() % 3); // 0: leaf, 1: /k, 2: %k + if (op == 0 || extent <= 1) { + return chosen; + } + std::vector divisors; + for (int d = 2; d <= extent; d++) { + if (extent % d == 0) { + divisors.push_back(d); + } + } + if (divisors.empty()) { + return chosen; + } + int k = divisors[rng() % divisors.size()]; + return (op == 1) ? chosen / k : chosen % k; +} + +Expr random_qa(std::mt19937 &rng, const RDom &rdom) { + int n_terms = 1 + (int)(rng() % 3); + Expr e; + for (int i = 0; i < n_terms; i++) { + int coeff = (int)(rng() % 5) - 2; // -2..2 + if (coeff == 0) continue; + Expr term = random_term(rng, rdom); + Expr part = (coeff == 1) ? term : coeff * term; + e = e.defined() ? e + part : part; + } + if (!e.defined()) e = 0; + int c0 = (int)(rng() % 7) - 3; // -3..3 + if (c0 != 0) e = e + c0; + return e; +} + +int test_random() { + std::mt19937 rng(0); + RDom r(0, 8, 0, 9, 0, 6); + + // Generous symmetric range for both the input and the output. Halide's + // bounds inference figures out what it actually needs within this. 
+ constexpr int half = 256; + constexpr int range = 2 * half; + + Buffer input_buf(range); + input_buf.set_min(-half); + + constexpr int num_cases = 200; + int tried = 0; + while (tried < num_cases) { + Expr A = random_qa(rng, r); + Expr B = random_qa(rng, r); + int t = tried++; + + for (int i = 0; i < range; i++) { + input_buf(i - half) = (i * 31 + t * 7) & 0xff; + } + ImageParam input(Int(32), 1); + input.set(input_buf); + + auto build = [&](bool vectorized) { + Func f{"f_rand"}; + Var x{"x"}; + f(x) = 0; + f(A) += input(B) + 0 * r.x; // Force a dependence on the RDom + if (vectorized) { + f.update().atomic().vectorize(r.x).vectorize(r.y).vectorize(r.z); + } + return f; + }; + + auto realize = [&](Func f) { + Buffer buf(range); + buf.set_min(-half); + f.realize(buf); + return buf; + }; + + Buffer correct = realize(build(false)); + Buffer out = realize(build(true)); + + for (int i = -half; i < half; i++) { + if (out(i) != correct(i)) { + std::cout << "Random case " << t << " failed:\n" + << " A = " << A << "\n" + << " B = " << B << "\n" + << " out(" << i << ") = " << out(i) + << " instead of " << correct(i) << "\n"; + return bad_output; + } + } + } + return success; +} + +int main(int argc, char **argv) { + if (get_jit_target_from_environment().has_feature(Target::SVE2)) { + printf("[SKIP] LLVM's SVE backend chokes on the vector shuffles in this test.\n"); + return 0; + } + + int result = test(all); + + if (result == bad_output) { + for (int i = 0; i < 18; i++) { + if (test(i) != success) { + printf("Test case %d failed\n", i); + return result; + } + } + } else if (result != success) { + return result; + } + + int rand_result = test_random(); + if (rand_result != success) { + return rand_result; + } + + printf("Success!\n"); + return 0; +}