Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wma/simpler ptr #602

Closed
wants to merge 13 commits into from
Prev Previous commit
Next Next commit
Merge remote-tracking branch 'origin/main' into wma/fix-sparse-dict
  • Loading branch information
willow-ahrens committed Jun 14, 2024
commit da7e2687f0de6cb4217c5e68a973fcf2d906637d
51 changes: 41 additions & 10 deletions src/tensors/levels/sparse_dict_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,19 +239,50 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseLevel, pos_stop)
i = freshen(ctx, :i)
q = freshen(ctx, :q)
v = freshen(ctx, :v)
max_pos = freshen(ctx, :max_pos)
idx_temp = freshen(ctx, :idx_temp)
val_temp = freshen(ctx, :val_temp)
perm_vec = freshen(ctx, :perm_vec)
pos_pts = freshen(ctx, :pos_pts)
start = freshen(ctx, :start)
stop = freshen(ctx, :stop)
push_preamble!(ctx, quote
srt = sort(collect(pairs($(lvl.tbl))))
resize!($(lvl.idx), length(srt))
resize!($(lvl.val), length(srt))
for q in 1:length(srt)
((p, i), v) = srt[q]
$(lvl.val)[q] = v
$(lvl.idx)[q] = i
end
$max_pos = maximum($(lvl.ptr))
resize!($(lvl.ptr), $(ctx(pos_stop)) + 1)
$(lvl.ptr)[1] = 1
for p = 2:$(ctx(pos_stop)) + 1
$(lvl.ptr)[p] += $(lvl.ptr)[p - 1]
for $p = 2:$(ctx(pos_stop)) + 1
$(lvl.ptr)[$p] += $(lvl.ptr)[$p - 1]
end

resize!($(lvl.idx), length($(lvl.tbl)))
resize!($(lvl.val), length($(lvl.tbl)))
$pos_pts = copy($(lvl.ptr))
for entry in pairs($(lvl.tbl))
(($p, $i), $v) = entry
pos = $pos_pts[$p]
$(lvl.idx)[pos] = $i
$(lvl.val)[pos] = $v
$pos_pts[$p] += 1
end

# To reduce allocations, we pre-allocate the workspaces for perm, idx, and val
$perm_vec = Vector{Int64}(undef, $max_pos)
$idx_temp = typeof($(lvl.idx))(undef, $max_pos)
$val_temp = typeof($(lvl.val))(undef, $max_pos)
for $p = 1:$(ctx(pos_stop))
$start = $(lvl.ptr)[$p]
$stop = $(lvl.ptr)[$p+1] - 1
sortperm!((@view $perm_vec[1:$stop-$start+1]), $(lvl.idx)[$start:$stop])
# Store the correctly permuted version of the idxs and vals in a temporary
for $i in 1:($stop-$start+1)
$idx_temp[$i] = $(lvl.idx)[$start + $perm_vec[$i] - 1]
$val_temp[$i] = $(lvl.val)[$start + $perm_vec[$i] - 1]
end
# Overwrite the segment of the idx and vals array with the correct order
for $i in 1:($stop-$start+1)
$(lvl.idx)[$start + $i - 1] = $idx_temp[$i]
$(lvl.val)[$start + $i - 1] = $val_temp[$i]
end
end
$qos_stop = $(lvl.ptr)[$(ctx(pos_stop)) + 1] - 1
end)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,40 @@ quote
end
end
end
srt = sort(collect(pairs(tmp_lvl_tbl)))
resize!(tmp_lvl_idx, length(srt))
resize!(tmp_lvl_val, length(srt))
for q = 1:length(srt)
sugar_1 = srt[q]
sugar_2 = sugar_1[1]
p = sugar_2[1]
i = sugar_2[2]
v = sugar_1[2]
tmp_lvl_val[q] = v
tmp_lvl_idx[q] = i
end
max_pos = maximum(tmp_lvl_ptr)
resize!(tmp_lvl_ptr, 1 + 1)
tmp_lvl_ptr[1] = 1
for p = 2:1 + 1
tmp_lvl_ptr[p] += tmp_lvl_ptr[p - 1]
for p_2 = 2:1 + 1
tmp_lvl_ptr[p_2] += tmp_lvl_ptr[p_2 - 1]
end
resize!(tmp_lvl_idx, length(tmp_lvl_tbl))
resize!(tmp_lvl_val, length(tmp_lvl_tbl))
pos_pts = copy(tmp_lvl_ptr)
for entry = pairs(tmp_lvl_tbl)
sugar_2 = entry[1]
p_2 = sugar_2[1]
i_9 = sugar_2[2]
v = entry[2]
pos = pos_pts[p_2]
tmp_lvl_idx[pos] = i_9
tmp_lvl_val[pos] = v
pos_pts[p_2] += 1
end
perm_vec = Vector{Int64}(undef, max_pos)
idx_temp = (typeof(tmp_lvl_idx))(undef, max_pos)
val_temp = (typeof(tmp_lvl_val))(undef, max_pos)
for p_2 = 1:1
start = tmp_lvl_ptr[p_2]
stop = tmp_lvl_ptr[p_2 + 1] - 1
sortperm!(@view(perm_vec[1:(stop - start) + 1]), tmp_lvl_idx[start:stop])
for i_9 = 1:(stop - start) + 1
idx_temp[i_9] = tmp_lvl_idx[(start + perm_vec[i_9]) - 1]
val_temp[i_9] = tmp_lvl_val[(start + perm_vec[i_9]) - 1]
end
for i_9 = 1:(stop - start) + 1
tmp_lvl_idx[(start + i_9) - 1] = idx_temp[i_9]
tmp_lvl_val[(start + i_9) - 1] = val_temp[i_9]
end
end
qos_stop = tmp_lvl_ptr[1 + 1] - 1
resize!(tmp_lvl_val_2, qos_stop)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,40 @@ quote
end
end
end
srt = sort(collect(pairs(tmp_lvl_tbl)))
resize!(tmp_lvl_idx, length(srt))
resize!(tmp_lvl_val, length(srt))
for q = 1:length(srt)
sugar_1 = srt[q]
sugar_2 = sugar_1[1]
p = sugar_2[1]
i = sugar_2[2]
v = sugar_1[2]
tmp_lvl_val[q] = v
tmp_lvl_idx[q] = i
end
max_pos = maximum(tmp_lvl_ptr)
resize!(tmp_lvl_ptr, 1 + 1)
tmp_lvl_ptr[1] = 1
for p = 2:1 + 1
tmp_lvl_ptr[p] += tmp_lvl_ptr[p - 1]
for p_2 = 2:1 + 1
tmp_lvl_ptr[p_2] += tmp_lvl_ptr[p_2 - 1]
end
resize!(tmp_lvl_idx, length(tmp_lvl_tbl))
resize!(tmp_lvl_val, length(tmp_lvl_tbl))
pos_pts = copy(tmp_lvl_ptr)
for entry = pairs(tmp_lvl_tbl)
sugar_2 = entry[1]
p_2 = sugar_2[1]
i_9 = sugar_2[2]
v = entry[2]
pos = pos_pts[p_2]
tmp_lvl_idx[pos] = i_9
tmp_lvl_val[pos] = v
pos_pts[p_2] += 1
end
perm_vec = Vector{Int64}(undef, max_pos)
idx_temp = (typeof(tmp_lvl_idx))(undef, max_pos)
val_temp = (typeof(tmp_lvl_val))(undef, max_pos)
for p_2 = 1:1
start = tmp_lvl_ptr[p_2]
stop = tmp_lvl_ptr[p_2 + 1] - 1
sortperm!(@view(perm_vec[1:(stop - start) + 1]), tmp_lvl_idx[start:stop])
for i_9 = 1:(stop - start) + 1
idx_temp[i_9] = tmp_lvl_idx[(start + perm_vec[i_9]) - 1]
val_temp[i_9] = tmp_lvl_val[(start + perm_vec[i_9]) - 1]
end
for i_9 = 1:(stop - start) + 1
tmp_lvl_idx[(start + i_9) - 1] = idx_temp[i_9]
tmp_lvl_val[(start + i_9) - 1] = val_temp[i_9]
end
end
qos_stop = tmp_lvl_ptr[1 + 1] - 1
resize!(tmp_lvl_val_2, qos_stop)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,40 @@ begin
end
end
result = ()
srt = sort(collect(pairs(fmt_lvl_tbl)))
resize!(fmt_lvl_idx, length(srt))
resize!(fmt_lvl_val, length(srt))
for q = 1:length(srt)
sugar_1 = srt[q]
sugar_2 = sugar_1[1]
p = sugar_2[1]
i = sugar_2[2]
v = sugar_1[2]
fmt_lvl_val[q] = v
fmt_lvl_idx[q] = i
end
max_pos = maximum(fmt_lvl_ptr)
resize!(fmt_lvl_ptr, fmt_lvl.shape + 1)
fmt_lvl_ptr[1] = 1
for p = 2:fmt_lvl.shape + 1
fmt_lvl_ptr[p] += fmt_lvl_ptr[p - 1]
for p_3 = 2:fmt_lvl.shape + 1
fmt_lvl_ptr[p_3] += fmt_lvl_ptr[p_3 - 1]
end
resize!(fmt_lvl_idx, length(fmt_lvl_tbl))
resize!(fmt_lvl_val, length(fmt_lvl_tbl))
pos_pts = copy(fmt_lvl_ptr)
for entry = pairs(fmt_lvl_tbl)
sugar_2 = entry[1]
p_3 = sugar_2[1]
i_14 = sugar_2[2]
v = entry[2]
pos = pos_pts[p_3]
fmt_lvl_idx[pos] = i_14
fmt_lvl_val[pos] = v
pos_pts[p_3] += 1
end
perm_vec = Vector{Int64}(undef, max_pos)
idx_temp = (typeof(fmt_lvl_idx))(undef, max_pos)
val_temp = (typeof(fmt_lvl_val))(undef, max_pos)
for p_3 = 1:fmt_lvl.shape
start = fmt_lvl_ptr[p_3]
stop = fmt_lvl_ptr[p_3 + 1] - 1
sortperm!(@view(perm_vec[1:(stop - start) + 1]), fmt_lvl_idx[start:stop])
for i_14 = 1:(stop - start) + 1
idx_temp[i_14] = fmt_lvl_idx[(start + perm_vec[i_14]) - 1]
val_temp[i_14] = fmt_lvl_val[(start + perm_vec[i_14]) - 1]
end
for i_14 = 1:(stop - start) + 1
fmt_lvl_idx[(start + i_14) - 1] = idx_temp[i_14]
fmt_lvl_val[(start + i_14) - 1] = val_temp[i_14]
end
end
qos_stop = fmt_lvl_ptr[fmt_lvl.shape + 1] - 1
resize!(fmt_lvl_2_val, qos_stop)
Expand Down
92 changes: 64 additions & 28 deletions test/reference64/typical/typical_transpose_csc_to_coo.txt
Original file line number Diff line number Diff line change
Expand Up @@ -130,40 +130,76 @@ quote
end
end
end
srt = sort(collect(pairs(B_lvl_tbl)))
resize!(B_lvl_idx, length(srt))
resize!(B_lvl_val, length(srt))
for q = 1:length(srt)
sugar_1 = srt[q]
sugar_2 = sugar_1[1]
p = sugar_2[1]
i = sugar_2[2]
v = sugar_1[2]
B_lvl_val[q] = v
B_lvl_idx[q] = i
end
max_pos = maximum(B_lvl_ptr)
resize!(B_lvl_ptr, 1 + 1)
B_lvl_ptr[1] = 1
for p = 2:1 + 1
B_lvl_ptr[p] += B_lvl_ptr[p - 1]
for p_2 = 2:1 + 1
B_lvl_ptr[p_2] += B_lvl_ptr[p_2 - 1]
end
qos_stop = B_lvl_ptr[1 + 1] - 1
srt = sort(collect(pairs(B_lvl_tbl_2)))
resize!(B_lvl_idx_2, length(srt))
resize!(B_lvl_val_2, length(srt))
for q = 1:length(srt)
sugar_3 = srt[q]
sugar_4 = sugar_3[1]
p = sugar_4[1]
i = sugar_4[2]
v = sugar_3[2]
B_lvl_val_2[q] = v
B_lvl_idx_2[q] = i
resize!(B_lvl_idx, length(B_lvl_tbl))
resize!(B_lvl_val, length(B_lvl_tbl))
pos_pts = copy(B_lvl_ptr)
for entry = pairs(B_lvl_tbl)
sugar_2 = entry[1]
p_2 = sugar_2[1]
i_13 = sugar_2[2]
v_8 = entry[2]
pos = pos_pts[p_2]
B_lvl_idx[pos] = i_13
B_lvl_val[pos] = v_8
pos_pts[p_2] += 1
end
perm_vec = Vector{Int64}(undef, max_pos)
idx_temp = (typeof(B_lvl_idx))(undef, max_pos)
val_temp = (typeof(B_lvl_val))(undef, max_pos)
for p_2 = 1:1
start = B_lvl_ptr[p_2]
stop = B_lvl_ptr[p_2 + 1] - 1
sortperm!(@view(perm_vec[1:(stop - start) + 1]), B_lvl_idx[start:stop])
for i_13 = 1:(stop - start) + 1
idx_temp[i_13] = B_lvl_idx[(start + perm_vec[i_13]) - 1]
val_temp[i_13] = B_lvl_val[(start + perm_vec[i_13]) - 1]
end
for i_13 = 1:(stop - start) + 1
B_lvl_idx[(start + i_13) - 1] = idx_temp[i_13]
B_lvl_val[(start + i_13) - 1] = val_temp[i_13]
end
end
qos_stop = B_lvl_ptr[1 + 1] - 1
max_pos_2 = maximum(B_lvl_ptr_2)
resize!(B_lvl_ptr_2, qos_stop + 1)
B_lvl_ptr_2[1] = 1
for p = 2:qos_stop + 1
B_lvl_ptr_2[p] += B_lvl_ptr_2[p - 1]
for p_4 = 2:qos_stop + 1
B_lvl_ptr_2[p_4] += B_lvl_ptr_2[p_4 - 1]
end
resize!(B_lvl_idx_2, length(B_lvl_tbl_2))
resize!(B_lvl_val_2, length(B_lvl_tbl_2))
pos_pts_2 = copy(B_lvl_ptr_2)
for entry = pairs(B_lvl_tbl_2)
sugar_4 = entry[1]
p_4 = sugar_4[1]
i_14 = sugar_4[2]
v_9 = entry[2]
pos = pos_pts_2[p_4]
B_lvl_idx_2[pos] = i_14
B_lvl_val_2[pos] = v_9
pos_pts_2[p_4] += 1
end
perm_vec_2 = Vector{Int64}(undef, max_pos_2)
idx_temp_2 = (typeof(B_lvl_idx_2))(undef, max_pos_2)
val_temp_2 = (typeof(B_lvl_val_2))(undef, max_pos_2)
for p_4 = 1:qos_stop
start_2 = B_lvl_ptr_2[p_4]
stop_2 = B_lvl_ptr_2[p_4 + 1] - 1
sortperm!(@view(perm_vec_2[1:(stop_2 - start_2) + 1]), B_lvl_idx_2[start_2:stop_2])
for i_14 = 1:(stop_2 - start_2) + 1
idx_temp_2[i_14] = B_lvl_idx_2[(start_2 + perm_vec_2[i_14]) - 1]
val_temp_2[i_14] = B_lvl_val_2[(start_2 + perm_vec_2[i_14]) - 1]
end
for i_14 = 1:(stop_2 - start_2) + 1
B_lvl_idx_2[(start_2 + i_14) - 1] = idx_temp_2[i_14]
B_lvl_val_2[(start_2 + i_14) - 1] = val_temp_2[i_14]
end
end
qos_stop_2 = B_lvl_ptr_2[qos_stop + 1] - 1
resize!(B_lvl_2_val, qos_stop_2)
Expand Down
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.