From 74f9694386c253ecb2a4bcee8ce528cad667c422 Mon Sep 17 00:00:00 2001 From: Dae Woo Kim Date: Mon, 20 Oct 2025 14:14:03 -0500 Subject: [PATCH 1/2] for testing multi-thread --- src/RegisterWorkerShell.jl | 19 +++++++++++++++++-- test/runtests.jl | 3 +++ 2 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 test/runtests.jl diff --git a/src/RegisterWorkerShell.jl b/src/RegisterWorkerShell.jl index 6f834a3..d3667ec 100644 --- a/src/RegisterWorkerShell.jl +++ b/src/RegisterWorkerShell.jl @@ -3,7 +3,7 @@ module RegisterWorkerShell using SimpleTraits, ImageAxes, ImageMetadata, Distributed, SharedArrays using AxisArrays: AxisArray, Axis -export AbstractWorker, AnyValue, ArrayDecl, close!, init!, maybe_sharedarray, monitor, monitor!, worker, workerpid, getindex_t +export AbstractWorker, AnyValue, ArrayDecl, close!, init!, maybe_sharedarray, monitor, monitor_thread, monitor!, worker, workerpid, getindex_t export load_mm_package """ @@ -77,7 +77,22 @@ function monitor(algorithm::AbstractWorker, fields::Union{NTuple{N,Symbol},Vecto mon end -monitor(algorithm::Vector{W}, fields, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where {W<:AbstractWorker} = map(alg->monitor(alg, fields, morevars), algorithm) +function monitor_thread(algorithm::AbstractWorker, fields::Union{NTuple{N,Symbol},Vector{Symbol}}, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where N + mon = Dict{Symbol,Any}() + for f in fields + isdefined(algorithm, f) || continue + mon[f] = getfield(algorithm, f) + end + for (k,v) in morevars + mon[k] = v + end + mon +end + +monitor(algorithms::Vector{W}, fields, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where {W<:AbstractWorker} = + map(alg->monitor(alg, fields, morevars), algorithms) # for multi-process +monitor_thread(algorithms::Vector{W}, fields, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where {W<:AbstractWorker} = + map(alg->monitor_thread(alg, fields, morevars), algorithms) # for multi-thread """ `monitor!(mon, algorithm)` updates `mon` with the current values of diff --git a/test/runtests.jl b/test/runtests.jl new file mode 100644 index 0000000..94c9506 --- /dev/null +++ b/test/runtests.jl @@ -0,0 +1,3 @@ +using RegisterWorkerShell, Test + +@test 1+1 == 2 From 0e00780c653200458da6b131d8f3ba353a4f2292 Mon Sep 17 00:00:00 2001 From: Dae Woo Kim Date: Wed, 22 Oct 2025 14:25:38 -0500 Subject: [PATCH 2/2] multi-thread base Apertures --- src/RegisterWorkerShell.jl | 34 ++------- test/runtests.jl | 147 ++++++++++++++++++++++++++++++++++++- 2 files changed, 153 insertions(+), 28 deletions(-) diff --git a/src/RegisterWorkerShell.jl b/src/RegisterWorkerShell.jl index d3667ec..b30e0e8 100644 --- a/src/RegisterWorkerShell.jl +++ b/src/RegisterWorkerShell.jl @@ -3,7 +3,8 @@ module RegisterWorkerShell using SimpleTraits, ImageAxes, ImageMetadata, Distributed, SharedArrays using AxisArrays: AxisArray, Axis -export AbstractWorker, AnyValue, ArrayDecl, close!, init!, maybe_sharedarray, monitor, monitor_thread, monitor!, worker, workerpid, getindex_t +export AbstractWorker, AnyValue, ArrayDecl, close!, init!, maybe_sharedarray, monitor +export monitor_thread, monitor!, worker, workertid, workerpid, getindex_t export load_mm_package """ @@ -36,7 +37,7 @@ subtypes. The exported operations are: - `init!` and `close!`: functions you may specialize if your algorithm needs to initialize or clean up resources - `worker`: perform registration on an image - - `workerpid`: extract the process-id for a given worker + - `workertid`: extract the thread-id """ RegisterWorkerShell @@ -58,26 +59,8 @@ The worker algorithm should call `monitor!(mon, algorithm)` to copy the values into `mon`, and `monitor!(mon, :var3, var3)` for an internal variable `var3` that is not taken from `algorithm`. See `monitor!` for more detail. - -An important detail is that if `workerpid(algorithm) ≠ myid()`, then any -requested `AbstractArray` fields in `algorithm` will be turned into -`SharedArray`s for `mon`. This reduces the cost of communication -between the worker and driver processes. """ function monitor(algorithm::AbstractWorker, fields::Union{NTuple{N,Symbol},Vector{Symbol}}, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where N - pid = workerpid(algorithm) - mon = Dict{Symbol,Any}() - for f in fields - isdefined(algorithm, f) || continue - mon[f] = maybe_sharedarray(getfield(algorithm, f), pid) - end - for (k,v) in morevars - mon[k] = maybe_sharedarray(v, pid) - end - mon -end - -function monitor_thread(algorithm::AbstractWorker, fields::Union{NTuple{N,Symbol},Vector{Symbol}}, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where N mon = Dict{Symbol,Any}() for f in fields isdefined(algorithm, f) || continue @@ -88,11 +71,8 @@ function monitor_thread(algorithm::AbstractWorker, fields::Union{NTuple{N,Symbol end mon end - monitor(algorithms::Vector{W}, fields, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where {W<:AbstractWorker} = - map(alg->monitor(alg, fields, morevars), algorithms) # for multi-process -monitor_thread(algorithms::Vector{W}, fields, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where {W<:AbstractWorker} = - map(alg->monitor_thread(alg, fields, morevars), algorithms) # for multi-thread + map(alg->monitor(alg, fields, morevars), algorithms) # for multi-thread """ `monitor!(mon, algorithm)` updates `mon` with the current values of @@ -165,12 +145,12 @@ worker(algorithm::AbstractWorker, img, tindex, mon) = error("Worker modules must worker(rr::RemoteChannel, img, tindex, mon) = worker(fetch(rr), img, tindex, mon) """ -`workerpid(algorithm)` extracts the `pid` associated with the worker +`workertid(algorithm)` extracts the `workertid` associated with the thread that will be assigned tasks for `algorithm`. All `AbstractWorker` -subtypes should include a `workerpid` field, or overload this function +subtypes should include a `workertid` field, or overload this function to return myid(). """ -workerpid(w::AbstractWorker) = w.workerpid +workertid(w::AbstractWorker) = w.workertid """ `load_mm_package(dev)` loads appropriate mismatch module conditioned on diff --git a/test/runtests.jl b/test/runtests.jl index 94c9506..faa6a9f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,148 @@ using RegisterWorkerShell, Test +using ImageAxes, ImageMetadata, AxisArrays +using SharedArrays -@test 1+1 == 2 +# Concrete AbstractWorker subtype for testing +mutable struct TestWorker <: AbstractWorker + workertid::Int + param_scalar::Float64 + param_array::Vector{Int} + param_string::String +end + +@testset "RegisterWorkerShell" begin + + @testset "ArrayDecl" begin + adcl = ArrayDecl(Array{Float64,2}, (3, 4)) + @test adcl.arraysize == (3, 4) + @test eltype(adcl) == Float64 + + adcl3 = ArrayDecl(Array{Int32,3}, (2, 3, 4)) + @test adcl3.arraysize == (2, 3, 4) + @test eltype(adcl3) == Int32 + end + + @testset "monitor" begin + w = TestWorker(1, 3.14, [1, 2, 3], "hello") + + mon = monitor(w, (:param_scalar, :param_array)) + @test mon[:param_scalar] == 3.14 + @test mon[:param_array] == [1, 2, 3] + @test !haskey(mon, :param_string) + + # monitor with extra variables + mon2 = monitor(w, (:param_scalar,), Dict{Symbol,Any}(:extra => 42)) + @test mon2[:param_scalar] == 3.14 + @test mon2[:extra] == 42 + + # field that doesn't exist should be silently skipped + mon3 = monitor(w, (:nonexistent,)) + @test isempty(mon3) + + # Vector of workers + workers = [TestWorker(i, Float64(i), [i], "w$i") for i in 1:3] + mons = monitor(workers, (:param_scalar,)) + @test length(mons) == 3 + for i in 1:3 + @test mons[i][:param_scalar] == Float64(i) + end + end + + @testset "monitor!" begin + w = TestWorker(1, 2.71, [4, 5], "world") + + # monitor! updates all monitored fields + mon = Dict{Symbol,Any}(:param_scalar => 0.0, :param_array => [0, 0]) + monitor!(mon, w) + @test mon[:param_scalar] == 2.71 + @test mon[:param_array] == [4, 5] + + # monitor! with symbol + array of same size: should copyto! in place + original_ref = [0, 0] + mon[:param_array] = original_ref + monitor!(mon, :param_array, [10, 20]) + @test mon[:param_array] == [10, 20] + @test mon[:param_array] === original_ref # same object, mutated + + # monitor! with symbol + array of different size: should replace + monitor!(mon, :param_array, [1, 2, 3]) + @test mon[:param_array] == [1, 2, 3] + + # monitor! with scalar value + monitor!(mon, :param_scalar, 9.99) + @test mon[:param_scalar] == 9.99 + + # monitor! with key not in mon should NOT add it + monitor!(mon, :not_monitored, 100) + @test !haskey(mon, :not_monitored) + end + + @testset "init! and close!" begin + w = TestWorker(1, 1.0, [1], "test") + + @test init!(w) === nothing + @test close!(w) === nothing + + # extra args are accepted + @test init!(w, 1, :a) === nothing + @test close!(w, 1, :a) === nothing + end + + @testset "worker (unimplemented)" begin + w = TestWorker(1, 1.0, [1], "test") + @test_throws ErrorException worker(w, nothing, 1, Dict{Symbol,Any}()) + end + + @testset "workertid" begin + w = TestWorker(7, 1.0, [1], "test") + @test workertid(w) == 7 + end + + @testset "load_mm_package" begin + @test load_mm_package(:cpu) === nothing + @test load_mm_package(:gpu, 1, 2) === nothing + end + + @testset "maybe_sharedarray" begin + # Array on same process: returned as-is + A = [1.0, 2.0, 3.0] + @test maybe_sharedarray(A) === A + + # Bits type + size: creates SharedArray + S = maybe_sharedarray(Float64, (3, 4)) + @test S isa SharedArray{Float64} + @test size(S) == (3, 4) + + # Non-bits type + size: creates regular Array + R = maybe_sharedarray(String, (2, 3)) + @test R isa Array{String} + @test size(R) == (2, 3) + + # ArrayDecl with bits eltype: creates SharedArray + adcl = ArrayDecl(Array{Float32,2}, (5, 6)) + Sa = maybe_sharedarray(adcl) + @test Sa isa SharedArray{Float32} + @test size(Sa) == (5, 6) + + # Non-array scalar passthrough + @test maybe_sharedarray(42) === 42 + @test maybe_sharedarray("hello") === "hello" + end + + @testset "getindex_t" begin + # Image without time axis: return img itself + img_no_t = rand(3, 4) + @test getindex_t(img_no_t, 1) === img_no_t + + # AxisArray with time axis: return view at tindex + data = rand(3, 4, 5) + img_t = AxisArray(data, + Axis{:x}(1:3), + Axis{:y}(1:4), + Axis{:time}(1:5)) + slice = getindex_t(img_t, 3) + @test size(slice) == (3, 4) + @test slice == view(data, :, :, 3) + end + +end