Skip to content

Commit

Permalink
Don't allow mutating aliases of buffers used more than once (#513)
Browse files Browse the repository at this point in the history
  • Loading branch information
dsharlet authored Dec 18, 2024
1 parent 8b9fd64 commit 6fadef0
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 45 deletions.
109 changes: 64 additions & 45 deletions builder/optimizations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <map>
#include <numeric>
#include <optional>
#include <set>
Expand Down Expand Up @@ -205,6 +204,8 @@ class buffer_aliaser : public node_mutator {
node_context& ctx;

struct alias_info {
var target;

// Parameters for this alias's make_buffer call.
std::vector<dim_expr> dims;
expr elem_size;
Expand All @@ -224,6 +225,8 @@ class buffer_aliaser : public node_mutator {
bool is_copy = false;

bool is_contiguous_copy = false;

bool disabled = false;
};

class buffer_info {
Expand All @@ -236,21 +239,24 @@ class buffer_aliaser : public node_mutator {
bool is_output;

// Possible aliases of this allocation.
std::map<var, alias_info> aliases;
std::vector<alias_info> aliases;

// If we decided to alias this buffer, we might have grown the bounds. If so, we need to make a new allocation with
// this symbol, but make a crop of it for the original bounds.
var shared_alloc_sym;

public:
int uses = 0;

buffer_info(std::vector<dim_expr> dims, expr elem_size, bool is_input = false, bool is_output = false)
: dims(std::move(dims)), elem_size(std::move(elem_size)), is_input(is_input), is_output(is_output) {}

void maybe_alias(var s, alias_info a) {
assert(aliases.count(s) == 0);
aliases[s] = std::move(a);
void do_not_alias(var t) {
for (alias_info& i : aliases) {
if (i.target == t) {
i.disabled = true;
}
}
}
void do_not_alias(var s) { aliases.erase(s); }
};
symbol_map<buffer_info> buffers;

Expand All @@ -260,6 +266,13 @@ class buffer_aliaser : public node_mutator {
scoped_trace trace("alias_compatible");
assert(op->dims.size() == alias.dims.size());

if (target_info.uses > 1 && alias.may_mutate) {
// We can't use a mutating alias on a buffer that is used more than once.
// TODO: We could do better here: if the mutating alias is the *last* use, we can still use that alias.
// This is tricky to figure out especially when loops are involved.
return false;
}

if (alias.is_contiguous_copy) {
assert(alias.assume_in_bounds);
// We just assume flat copies are OK.
Expand Down Expand Up @@ -372,11 +385,12 @@ class buffer_aliaser : public node_mutator {
}

box_expr op_dims_bounds = dims_bounds(info.dims);
for (auto& target : info.aliases) {
var target_var = target.first;
alias_info& alias = target.second;
for (alias_info& alias : info.aliases) {
if (alias.disabled) {
continue;
}

var alloc_var = target_var;
var target_var = alias.target;
std::optional<buffer_info>& target_info = buffers[target_var];
assert(target_info);

Expand All @@ -393,6 +407,7 @@ class buffer_aliaser : public node_mutator {
i = substitute_bounds(i, op->sym, op_dims_bounds);
}

var alloc_var = target_var;
if (!alias.assume_in_bounds) {
assert(!target_info->is_output);
assert(!target_info->is_input); // We shouldn't be trying to write to an input anyways.
Expand Down Expand Up @@ -433,8 +448,8 @@ class buffer_aliaser : public node_mutator {
result = make_buffer::make(sym, expr(), elem_size, op->dims, result);

for (auto& i : target_info->aliases) {
i.second.may_mutate = i.second.may_mutate || alias.may_mutate;
i.second.assume_in_bounds = i.second.assume_in_bounds && alias.assume_in_bounds;
i.may_mutate = i.may_mutate || alias.may_mutate;
i.assume_in_bounds = i.assume_in_bounds && alias.assume_in_bounds;
}

if (elem_size.defined()) {
Expand All @@ -444,18 +459,6 @@ class buffer_aliaser : public node_mutator {
// If we aliased the source and destination of a copy with no padding, the copy can be removed.
result = remove_copy(result, op->sym, target_var);

if (!alias.is_copy) {
// This wasn't a copy, we actually did some computation in place. We can't alias another buffer to this target
// without understanding the lifetimes more carefully.
// TODO: I think this is a hack, but I'm not sure. I think maybe the proper thing to do is track a box_expr
// of the region that has been aliased so far, and allow another alias as long as it does not intersect that
// region. That will likely be very difficult to do symbolically.
for (std::optional<buffer_info>& i : buffers) {
if (!i) continue;
i->do_not_alias(target_var);
}
}

set_result(std::move(result));
return;
}
Expand Down Expand Up @@ -492,28 +495,34 @@ class buffer_aliaser : public node_mutator {
var in = op->inputs[0];
var out = op->outputs[0];
std::optional<buffer_info>& input_info = buffers[in];
if (input_info) input_info->uses++;
std::optional<buffer_info>& output_info = buffers[out];
if (input_info && output_info) {
alias_info fwd;
fwd.target = out;
fwd.dims = make_contiguous_dims(in, input_info->dims.size());
fwd.at = buffer_mins(out, output_info->dims.size());
fwd.is_contiguous_copy = true;
fwd.assume_in_bounds = true;
input_info->maybe_alias(out, std::move(fwd));
input_info->aliases.push_back(std::move(fwd));

alias_info back;
back.target = in;
back.dims = make_contiguous_dims(out, output_info->dims.size());
back.at = buffer_mins(in, input_info->dims.size());
back.is_contiguous_copy = true;
back.assume_in_bounds = true;
output_info->maybe_alias(in, std::move(back));
output_info->aliases.push_back(std::move(back));
}
} else if (op->attrs.allow_in_place) {
} else {
// If input is repeated, we don't want to add into the alias info again.
std::set<var> unique_inputs(op->inputs.begin(), op->inputs.end());
for (var i : unique_inputs) {
std::optional<buffer_info>& input_info = buffers[i];
if (!input_info || input_info->is_input) {
if (!input_info) continue;
input_info->uses++;

if (!op->attrs.allow_in_place || input_info->is_input) {
// We can't write to this buffer.
continue;
}
Expand All @@ -528,15 +537,17 @@ class buffer_aliaser : public node_mutator {
size_t rank = input_info->dims.size();

alias_info fwd;
fwd.target = o;
fwd.dims = buffer_dims(o, rank);
fwd.at = buffer_mins(i, rank);
fwd.assume_in_bounds = true;
fwd.may_mutate = false;
fwd.may_mutate = true;
fwd.permutation.resize(rank);
std::iota(fwd.permutation.begin(), fwd.permutation.end(), 0);
input_info->maybe_alias(o, std::move(fwd));
input_info->aliases.push_back(std::move(fwd));

alias_info back;
back.target = i;
// Use the bounds of the output, but the memory layout of the input.
back.dims.resize(rank);
for (int d = 0; d < static_cast<int>(rank); ++d) {
Expand All @@ -547,7 +558,7 @@ class buffer_aliaser : public node_mutator {
back.may_mutate = true;
back.permutation.resize(rank);
std::iota(back.permutation.begin(), back.permutation.end(), 0);
output_info->maybe_alias(i, std::move(back));
output_info->aliases.push_back(std::move(back));
}
}
}
Expand All @@ -565,6 +576,7 @@ class buffer_aliaser : public node_mutator {
// are the same dimensions we want the dst to be.

alias_info a;
a.target = op->src;
a.at.resize(op->src_x.size());
a.permutation.resize(op->dst_x.size());
a.dims = info->dims;
Expand All @@ -580,8 +592,8 @@ class buffer_aliaser : public node_mutator {
return;
}

// We want the bounds of the original dst dimension, but the memory layout of the src dimension. This may require
// the allocation to be expanded to accommodate this alias.
// We want the bounds of the original dst dimension, but the memory layout of the src dimension. This may
// require the allocation to be expanded to accommodate this alias.
a.dims[dst_d] = {buffer_bounds(op->dst, dst_d), src_dim.stride, src_dim.fold_factor};
a.permutation[dst_d] = src_d;
if (at.defined()) {
Expand All @@ -597,7 +609,7 @@ class buffer_aliaser : public node_mutator {

a.elem_size = buffer_elem_size(op->src);

info->maybe_alias(op->src, std::move(a));
info->aliases.push_back(std::move(a));
}

void alias_copy_src(const copy_stmt* op) {
Expand All @@ -613,6 +625,7 @@ class buffer_aliaser : public node_mutator {
// broadcasting).

alias_info a;
a.target = op->dst;
a.at.resize(op->dst_x.size());
a.dims.resize(op->src_x.size());
assert(op->src_x.size() == info->dims.size());
Expand Down Expand Up @@ -657,12 +670,15 @@ class buffer_aliaser : public node_mutator {
a.may_mutate = false;
a.elem_size = buffer_elem_size(op->dst);

info->maybe_alias(op->dst, std::move(a));
info->aliases.push_back(std::move(a));
}

void visit(const copy_stmt* op) override {
set_result(op);

std::optional<buffer_info>& src_info = buffers[op->src];
if (src_info) src_info->uses++;

alias_copy_dst(op);
alias_copy_src(op);
}
Expand All @@ -671,19 +687,18 @@ class buffer_aliaser : public node_mutator {
symbol_map<buffer_info>& old_buffers, var sym, var src, std::function<void(alias_info&)> handler) {
for (std::optional<buffer_info>& i : buffers) {
if (!i) continue;
auto j = i->aliases.find(sym);
if (j != i->aliases.end()) {
handler(j->second);
}
for (auto& a : i->aliases) {
if (a.target == sym) {
handler(a);
}
// We need to substitute uses of sym with uses of src in the aliases we added here.
for (dim_expr& d : a.second.dims) {
for (dim_expr& d : a.dims) {
d.bounds = substitute(d.bounds, sym, src);
d.stride = substitute(d.stride, sym, src);
d.fold_factor = substitute(d.fold_factor, sym, src);
}
a.second.elem_size = substitute(a.second.elem_size, sym, src);
for (expr& i : a.second.at) {
a.elem_size = substitute(a.elem_size, sym, src);
for (expr& i : a.at) {
i = substitute(i, sym, src);
}
}
Expand All @@ -705,8 +720,12 @@ class buffer_aliaser : public node_mutator {
assert(!old_info->shared_alloc_sym.defined() || old_info->shared_alloc_sym == info->shared_alloc_sym);
old_info->shared_alloc_sym = info->shared_alloc_sym;
}
for (auto& j : info->aliases) {
old_info->maybe_alias(j.first == sym ? src : j.first, std::move(j.second));
old_info->uses += info->uses;
for (alias_info& a : info->aliases) {
if (a.target == sym) {
a.target = src;
}
old_info->aliases.push_back(std::move(a));
}
}
std::swap(old_buffers, buffers);
Expand Down
62 changes: 62 additions & 0 deletions builder/test/cannot_alias.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,4 +309,66 @@ TEST_P(may_alias, unfolded) {
ASSERT_EQ(eval_ctx.heap.allocs.size(), may_alias ? 0 : 1);
}

// Fixture for tests where one intermediate buffer is consumed by more than one call.
// Parameters are (int, bool):
//  - the int selects which of the two consumers in the `cannot_alias` test is marked
//    `allow_in_place` (0 -> the call named "a_b", 1 -> the call named "a_c");
//  - the bool selects whether the final stage gets an explicit loop over y.
class multiple_uses : public testing::TestWithParam<std::tuple<int, bool>> {};

// Runs every combination of in-place choice {0, 1} and split {false, true}.
INSTANTIATE_TEST_SUITE_P(alias_split, multiple_uses,
testing::Combine(testing::Values(0, 1), testing::Values(false, true)),
test_params_to_string<multiple_uses::ParamType>);

// A buffer that feeds two consumers must not be overwritten in place by either of them,
// otherwise the other consumer would read modified data.
TEST_P(multiple_uses, cannot_alias) {
  // in_place: which consumer of `a` is allowed to compute in place (0 -> a_b, 1 -> a_c).
  // split:    whether the final stage is given an explicit loop over y.
  const auto [in_place, split] = GetParam();

  node_context ctx;

  // Pipeline topology:
  //   in -> a -> b --\
  //         a -> c ---+-> out
  auto in = buffer_expr::make(ctx, "in", 2, sizeof(short));
  auto out = buffer_expr::make(ctx, "out", 2, sizeof(short));

  auto a = buffer_expr::make(ctx, "a", 2, sizeof(short));
  auto b = buffer_expr::make(ctx, "b", 2, sizeof(short));
  auto c = buffer_expr::make(ctx, "c", 2, sizeof(short));

  var x(ctx, "x");
  var y(ctx, "y");

  // Producer of the shared intermediate buffer.
  func in_a = func::make(add_1<short>, {{in, {point(x), point(y)}}}, {{a, {x, y}}});

  // Two consumers of `a`; exactly one of them is flagged as in-place capable.
  func a_b = func::make(add_1<short>, {{a, {point(x), point(y)}}}, {{b, {x, y}}},
      call_stmt::attributes{.allow_in_place = in_place == 0, .name = "a_b"});
  func a_c = func::make(add_1<short>, {{a, {point(x), point(y)}}}, {{c, {x, y}}},
      call_stmt::attributes{.allow_in_place = in_place == 1, .name = "a_c"});

  // Combine both branches into the pipeline output.
  func diff = func::make(subtract<short>, {{b, {point(x), point(y)}}, {c, {point(x), point(y)}}}, {{out, {x, y}}});

  if (split) {
    diff.loops({{y}});
  }

  pipeline p = build_pipeline(ctx, {in}, {out});

  // Set up concrete input/output buffers.
  const int width = 20;
  const int height = 10;

  buffer<short, 2> in_buf({width, height});
  init_random(in_buf);

  buffer<short, 2> out_buf({width, height});
  out_buf.allocate();

  // Plain arrays are used here because span has no initializer_list constructor.
  const raw_buffer* inputs[] = {&in_buf};
  const raw_buffer* outputs[] = {&out_buf};
  test_context eval_ctx;
  p.evaluate(inputs, outputs, eval_ctx);

  // b and c are produced from the same `a` by the same function, so the output is expected
  // to be 0 everywhere, provided neither consumer modified `a` before the other read it.
  for (int row = 0; row < height; ++row) {
    for (int col = 0; col < width; ++col) {
      ASSERT_EQ(out_buf(col, row), 0);
    }
  }
}

} // namespace slinky

0 comments on commit 6fadef0

Please sign in to comment.