Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed iterator task in pmap to a function. Fixes #4035 and #4034 #4054

Merged
merged 1 commit into from
Aug 14, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Changed iterator task in pmap to a function. Fixes #4035 and #4034
  • Loading branch information
amitmurthy committed Aug 14, 2013
commit 9420bbd23b693d987d6d5f729512fe4e9aa88224
62 changes: 29 additions & 33 deletions base/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1313,70 +1313,66 @@ pmap(f) = f()
function pmap(f, lsts...; err_retry=true, err_stop=false)
len = length(lsts)
np = nprocs()
retrycond = Condition()

results = Dict{Int,Any}()
function setresult(idx,v)
results[idx] = v
notify(retrycond)
end

retryqueue = {}
function retry(idx,v,ex)
push!(retryqueue, (idx,v,ex))
notify(retrycond)
end

task_in_err = false
is_task_in_error() = task_in_err
set_task_in_error() = (task_in_err = true)

nextidx = 0
getnextidx() = (nextidx += 1)
getcurridx() = nextidx

states = [start(lsts[idx]) for idx in 1:len]
function producer()
while true
if (is_task_in_error() && err_stop)
break
elseif !isempty(retryqueue)
produce(shift!(retryqueue)[1:2])
elseif all([!done(lsts[idx],states[idx]) for idx in 1:len])
nxts = [next(lsts[idx],states[idx]) for idx in 1:len]
map(idx->states[idx]=nxts[idx][2], 1:len)
nxtvals = [x[1] for x in nxts]
produce((getnextidx(), nxtvals))
elseif (length(results) == getcurridx())
break
else
wait(retrycond)
end
function getnext_tasklet()
if is_task_in_error() && err_stop
return nothing
elseif all([!done(lsts[idx],states[idx]) for idx in 1:len])
nxts = [next(lsts[idx],states[idx]) for idx in 1:len]
map(idx->states[idx]=nxts[idx][2], 1:len)
nxtvals = [x[1] for x in nxts]
return (getnextidx(), nxtvals)

elseif !isempty(retryqueue)
return shift!(retryqueue)
else
return nothing
end
end

pt = Task(producer)
@sync begin
for p=1:np
wpid = PGRP.workers[p].id
if wpid != myid() || np == 1
@async begin
for (idx,nxtvals) in pt
tasklet = getnext_tasklet()
while (tasklet != nothing)
(idx, fvals) = tasklet
try
result = remotecall_fetch(wpid, f, nxtvals...)
isa(result, Exception) ? ((wpid == myid()) ? rethrow(result) : throw(result)) : setresult(idx, result)
result = remotecall_fetch(wpid, f, fvals...)
if isa(result, Exception)
((wpid == myid()) ? rethrow(result) : throw(result))
else
results[idx] = result
end
catch ex
err_retry ? retry(idx,nxtvals,ex) : setresult(idx, ex)
if err_retry
push!(retryqueue, (idx,fvals, ex))
else
results[idx] = ex
end
set_task_in_error()
break # remove this worker from accepting any more tasks
end

tasklet = getnext_tasklet()
end
end
end
end
end

!istaskdone(pt) && throwto(pt, InterruptException())
for failure in retryqueue
results[failure[1]] = failure[3]
end
Expand Down
46 changes: 43 additions & 3 deletions test/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,46 @@ et=toq()
@test isready(rr1)
@test !isready(rr3)

# make sure exceptions propagate when waiting on Tasks
# TODO: should be enabled but the error is printed by the event loop
#@test_throws (@sync (@async error("oops")))

# TODO: The below block should be always enabled but the error is printed by the event loop

# Hence in the event of any relevant changes to the parallel codebase,
# please define an ENV variable PTEST_FULL and ensure that the below block is
# executed successfully before committing/merging

if haskey(ENV, "PTEST_FULL")
println("START of parallel tests that print errors")

# make sure exceptions propagate when waiting on Tasks
@test_throws (@sync (@async error("oops")))

# pmap tests
# needs at least 4 processors (which are being created above for the @parallel tests)
s = "a"*"bcdefghijklmnopqrstuvwxyz"^100;
ups = "A"*"BCDEFGHIJKLMNOPQRSTUVWXYZ"^100;
@test ups == bytestring(Uint8[uint8(c) for c in pmap(x->uppercase(x), s)])
@test ups == bytestring(Uint8[uint8(c) for c in pmap(x->uppercase(char(x)), s.data)])

# retry, on error exit
res = pmap(x->(x=='a') ? error("test error. don't panic.") : uppercase(x), s; err_retry=true, err_stop=true);
@test length(res) < length(ups)
@test isa(res[1], Exception)

# no retry, on error exit
res = pmap(x->(x=='a') ? error("test error. don't panic.") : uppercase(x), s; err_retry=false, err_stop=true);
@test length(res) < length(ups)
@test isa(res[1], Exception)

# retry, on error continue
res = pmap(x->iseven(myid()) ? error("test error. don't panic.") : uppercase(x), s; err_retry=true, err_stop=false);
@test length(res) == length(ups)
@test ups == bytestring(Uint8[uint8(c) for c in res])

# no retry, on error continue
res = pmap(x->(x=='a') ? error("test error. don't panic.") : uppercase(x), s; err_retry=false, err_stop=false);
@test length(res) == length(ups)
@test isa(res[1], Exception)

println("END of parallel tests that print errors")
end