Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions src/RegisterWorkerShell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ module RegisterWorkerShell
using SimpleTraits, ImageAxes, ImageMetadata, Distributed, SharedArrays
using AxisArrays: AxisArray, Axis

export AbstractWorker, AnyValue, ArrayDecl, close!, init!, maybe_sharedarray, monitor, monitor!, worker, workerpid, getindex_t
export AbstractWorker, AnyValue, ArrayDecl, close!, init!, maybe_sharedarray, monitor
export monitor_thread, monitor!, worker, workertid, workerpid, getindex_t
export load_mm_package

"""
Expand Down Expand Up @@ -36,7 +37,7 @@ subtypes. The exported operations are:
- `init!` and `close!`: functions you may specialize if your algorithm
needs to initialize or clean up resources
- `worker`: perform registration on an image
- `workerpid`: extract the process-id for a given worker
- `workertid`: extract the thread-id
"""
RegisterWorkerShell

Expand All @@ -58,26 +59,20 @@ The worker algorithm should call `monitor!(mon, algorithm)` to copy
the values into `mon`, and `monitor!(mon, :var3, var3)` for an
internal variable `var3` that is not taken from `algorithm`. See
`monitor!` for more detail.

An important detail is that if `workerpid(algorithm) ≠ myid()`, then any
requested `AbstractArray` fields in `algorithm` will be turned into
`SharedArray`s for `mon`. This reduces the cost of communication
between the worker and driver processes.
"""
function monitor(algorithm::AbstractWorker, fields::Union{NTuple{N,Symbol},Vector{Symbol}}, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where N
pid = workerpid(algorithm)
mon = Dict{Symbol,Any}()
for f in fields
isdefined(algorithm, f) || continue
mon[f] = maybe_sharedarray(getfield(algorithm, f), pid)
mon[f] = getfield(algorithm, f)
end
for (k,v) in morevars
mon[k] = maybe_sharedarray(v, pid)
mon[k] = v
end
mon
end

monitor(algorithm::Vector{W}, fields, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where {W<:AbstractWorker} = map(alg->monitor(alg, fields, morevars), algorithm)
monitor(algorithms::Vector{W}, fields, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where {W<:AbstractWorker} =
map(alg->monitor(alg, fields, morevars), algorithms) # for multi-thread

"""
`monitor!(mon, algorithm)` updates `mon` with the current values of
Expand Down Expand Up @@ -150,12 +145,12 @@ worker(algorithm::AbstractWorker, img, tindex, mon) = error("Worker modules must
worker(rr::RemoteChannel, img, tindex, mon) = worker(fetch(rr), img, tindex, mon)

"""
`workerpid(algorithm)` extracts the `pid` associated with the worker
`workertid(algorithm)` extracts the `workertid` associated with the thread
that will be assigned tasks for `algorithm`. All `AbstractWorker`
subtypes should include a `workerpid` field, or overload this function
subtypes should include a `workertid` field, or overload this function
to return myid().
"""
workerpid(w::AbstractWorker) = w.workerpid
workertid(w::AbstractWorker) = w.workertid

"""
`load_mm_package(dev)` loads appropriate mismatch module conditioned on
Expand Down
148 changes: 148 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
using RegisterWorkerShell, Test
using ImageAxes, ImageMetadata, AxisArrays
using SharedArrays

# Concrete AbstractWorker subtype for testing
mutable struct TestWorker <: AbstractWorker
workertid::Int
param_scalar::Float64
param_array::Vector{Int}
param_string::String
end

@testset "RegisterWorkerShell" begin

@testset "ArrayDecl" begin
adcl = ArrayDecl(Array{Float64,2}, (3, 4))
@test adcl.arraysize == (3, 4)
@test eltype(adcl) == Float64

adcl3 = ArrayDecl(Array{Int32,3}, (2, 3, 4))
@test adcl3.arraysize == (2, 3, 4)
@test eltype(adcl3) == Int32
end

@testset "monitor" begin
w = TestWorker(1, 3.14, [1, 2, 3], "hello")

mon = monitor(w, (:param_scalar, :param_array))
@test mon[:param_scalar] == 3.14
@test mon[:param_array] == [1, 2, 3]
@test !haskey(mon, :param_string)

# monitor with extra variables
mon2 = monitor(w, (:param_scalar,), Dict{Symbol,Any}(:extra => 42))
@test mon2[:param_scalar] == 3.14
@test mon2[:extra] == 42

# field that doesn't exist should be silently skipped
mon3 = monitor(w, (:nonexistent,))
@test isempty(mon3)

# Vector of workers
workers = [TestWorker(i, Float64(i), [i], "w$i") for i in 1:3]
mons = monitor(workers, (:param_scalar,))
@test length(mons) == 3
for i in 1:3
@test mons[i][:param_scalar] == Float64(i)
end
end

@testset "monitor!" begin
w = TestWorker(1, 2.71, [4, 5], "world")

# monitor! updates all monitored fields
mon = Dict{Symbol,Any}(:param_scalar => 0.0, :param_array => [0, 0])
monitor!(mon, w)
@test mon[:param_scalar] == 2.71
@test mon[:param_array] == [4, 5]

# monitor! with symbol + array of same size: should copyto! in place
original_ref = [0, 0]
mon[:param_array] = original_ref
monitor!(mon, :param_array, [10, 20])
@test mon[:param_array] == [10, 20]
@test mon[:param_array] === original_ref # same object, mutated

# monitor! with symbol + array of different size: should replace
monitor!(mon, :param_array, [1, 2, 3])
@test mon[:param_array] == [1, 2, 3]

# monitor! with scalar value
monitor!(mon, :param_scalar, 9.99)
@test mon[:param_scalar] == 9.99

# monitor! with key not in mon should NOT add it
monitor!(mon, :not_monitored, 100)
@test !haskey(mon, :not_monitored)
end

@testset "init! and close!" begin
w = TestWorker(1, 1.0, [1], "test")

@test init!(w) === nothing
@test close!(w) === nothing

# extra args are accepted
@test init!(w, 1, :a) === nothing
@test close!(w, 1, :a) === nothing
end

@testset "worker (unimplemented)" begin
w = TestWorker(1, 1.0, [1], "test")
@test_throws ErrorException worker(w, nothing, 1, Dict{Symbol,Any}())
end

@testset "workertid" begin
w = TestWorker(7, 1.0, [1], "test")
@test workertid(w) == 7
end

@testset "load_mm_package" begin
@test load_mm_package(:cpu) === nothing
@test load_mm_package(:gpu, 1, 2) === nothing
end

@testset "maybe_sharedarray" begin
# Array on same process: returned as-is
A = [1.0, 2.0, 3.0]
@test maybe_sharedarray(A) === A

# Bits type + size: creates SharedArray
S = maybe_sharedarray(Float64, (3, 4))
@test S isa SharedArray{Float64}
@test size(S) == (3, 4)

# Non-bits type + size: creates regular Array
R = maybe_sharedarray(String, (2, 3))
@test R isa Array{String}
@test size(R) == (2, 3)

# ArrayDecl with bits eltype: creates SharedArray
adcl = ArrayDecl(Array{Float32,2}, (5, 6))
Sa = maybe_sharedarray(adcl)
@test Sa isa SharedArray{Float32}
@test size(Sa) == (5, 6)

# Non-array scalar passthrough
@test maybe_sharedarray(42) === 42
@test maybe_sharedarray("hello") === "hello"
end

@testset "getindex_t" begin
# Image without time axis: return img itself
img_no_t = rand(3, 4)
@test getindex_t(img_no_t, 1) === img_no_t

# AxisArray with time axis: return view at tindex
data = rand(3, 4, 5)
img_t = AxisArray(data,
Axis{:x}(1:3),
Axis{:y}(1:4),
Axis{:time}(1:5))
slice = getindex_t(img_t, 3)
@test size(slice) == (3, 4)
@test slice == view(data, :, :, 3)
end

end
Loading