diff --git a/example/submitqueue/orchestrator/server/BUILD.bazel b/example/submitqueue/orchestrator/server/BUILD.bazel index d3a16bee..3dcb998a 100644 --- a/example/submitqueue/orchestrator/server/BUILD.bazel +++ b/example/submitqueue/orchestrator/server/BUILD.bazel @@ -19,6 +19,7 @@ go_library( "//extension/counter/mysql", "//extension/messagequeue", "//extension/messagequeue/mysql", + "//submitqueue/core/changeset", "//submitqueue/core/consumer", "//submitqueue/entity", "//submitqueue/extension/buildrunner", diff --git a/example/submitqueue/orchestrator/server/main.go b/example/submitqueue/orchestrator/server/main.go index 91e0d619..2c5411cb 100644 --- a/example/submitqueue/orchestrator/server/main.go +++ b/example/submitqueue/orchestrator/server/main.go @@ -38,6 +38,7 @@ import ( mysqlcounter "github.com/uber/submitqueue/extension/counter/mysql" extqueue "github.com/uber/submitqueue/extension/messagequeue" queueMySQL "github.com/uber/submitqueue/extension/messagequeue/mysql" + "github.com/uber/submitqueue/submitqueue/core/changeset" "github.com/uber/submitqueue/submitqueue/core/consumer" "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/buildrunner" @@ -230,7 +231,7 @@ func run() error { // back to a baseline profile for queues without an explicit entry. This is // the single place queue topology is known; the extension packages stay // queue-agnostic. - queues, err := newQueueRegistry(logger, scope) + queues, err := newQueueRegistry(logger, scope, changeset.New(store.GetRequestStore(), store.GetChangeStore())) if err != nil { return fmt.Errorf("failed to build queue registry: %w", err) } @@ -796,7 +797,7 @@ func newPusher(logger *zap.Logger, scope tally.Scope) (pusher.Pusher, error) { // conflict analyzer. Queues without an explicit profile fall back to the // baseline. This is the one place queue topology lives; extension packages stay // queue-agnostic. -func newQueueRegistry(logger *zap.Logger, scope tally.Scope) (queueRegistry, error) { +func newQueueRegistry(logger *zap.Logger, scope tally.Scope, resolver changeset.Resolver) (queueRegistry, error) { mc, err := newMergeChecker(logger, scope) if err != nil { return queueRegistry{}, fmt.Errorf("failed to create merge checker: %w", err) @@ -833,7 +834,8 @@ func newQueueRegistry(logger *zap.Logger, scope tally.Scope) (queueRegistry, err changeProvider: cp, pusher: psh, buildRunner: buildfake.New(), - scorer: scorerfake.New(heuristic.New( + scorer: scorerfake.New(resolver, heuristic.New( + resolver, []heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: 0.5}}, batchLines, scope.SubScope("scorer.default"), )), @@ -845,7 +847,8 @@ func newQueueRegistry(logger *zap.Logger, scope tally.Scope) (queueRegistry, err // test-queue: bucketed heuristic scorer; conservative (serialized) conflicts // inherited from the baseline. testQueue := base - testQueue.scorer = scorerfake.New(heuristic.New( + testQueue.scorer = scorerfake.New(resolver, heuristic.New( + resolver, []heuristic.Bucket{ {Min: 0, Max: 1, Score: 0.95}, {Min: 2, Max: 5, Score: 0.80}, @@ -858,10 +861,10 @@ func newQueueRegistry(logger *zap.Logger, scope tally.Scope) (queueRegistry, err // e2e-test-queue: composite scorer; no conflicts (maximum parallelism). e2eQueue := base e2eQueue.analyzer = conflictfake.New(none.New(), nil) - e2eQueue.scorer = scorerfake.New(composite.New( + e2eQueue.scorer = scorerfake.New(resolver, composite.New( map[string]scorer.Scorer{ - "size": heuristic.New([]heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: 0.8}}, batchLines, scope), - "flat": heuristic.New([]heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: 0.6}}, batchLines, scope), + "size": heuristic.New(resolver, []heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: 0.8}}, batchLines, scope), + "flat": heuristic.New(resolver, []heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: 0.6}}, batchLines, scope), }, composite.Avg, scope.SubScope("scorer.e2e-test-queue"), )) diff --git a/submitqueue/extension/scorer/composite/scorer.go b/submitqueue/extension/scorer/composite/scorer.go index 2acce898..8d90c832 100644 --- a/submitqueue/extension/scorer/composite/scorer.go +++ b/submitqueue/extension/scorer/composite/scorer.go @@ -88,15 +88,15 @@ func New(scorers map[string]scorer.Scorer, reduce ReduceFunc, scope tally.Scope) } } -// Score evaluates all child scorers and combines their results using the reduce function. -// If any child scorer returns an error, that error is returned immediately. -func (c *compositeScorer) Score(ctx context.Context, changes entity.BatchChanges) (ret float64, retErr error) { +// Score evaluates all child scorers on the batch and combines their results using the +// reduce function. If any child scorer returns an error, that error is returned immediately. +func (c *compositeScorer) Score(ctx context.Context, batch entity.Batch) (ret float64, retErr error) { op := metrics.Begin(c.scope, "score") defer func() { op.Complete(retErr) }() scores := make(map[string]float64, len(c.scorers)) for name, s := range c.scorers { - score, err := s.Score(ctx, changes) + score, err := s.Score(ctx, batch) if err != nil { return 0, err } diff --git a/submitqueue/extension/scorer/composite/scorer_test.go b/submitqueue/extension/scorer/composite/scorer_test.go index 09e7266e..9052e547 100644 --- a/submitqueue/extension/scorer/composite/scorer_test.go +++ b/submitqueue/extension/scorer/composite/scorer_test.go @@ -31,14 +31,14 @@ type fixedScorer struct { score float64 } -func (f *fixedScorer) Score(_ context.Context, _ entity.BatchChanges) (float64, error) { +func (f *fixedScorer) Score(_ context.Context, _ entity.Batch) (float64, error) { return f.score, nil } // errorScorer always returns an error. type errorScorer struct{} -func (e *errorScorer) Score(_ context.Context, _ entity.BatchChanges) (float64, error) { +func (e *errorScorer) Score(_ context.Context, _ entity.Batch) (float64, error) { return 0, fmt.Errorf("scorer failed") } @@ -99,7 +99,7 @@ func TestScorer_Score(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := New(tt.scorers, tt.reduce, tally.NoopScope) - got, err := s.Score(context.Background(), entity.BatchChanges{}) + got, err := s.Score(context.Background(), entity.Batch{}) require.NoError(t, err) assert.InDelta(t, tt.want, got, 1e-9) }) @@ -111,7 +111,7 @@ func TestScorer_Score_ChildError(t *testing.T) { "error": &errorScorer{}, "files": &fixedScorer{0.9}, }, Min, tally.NoopScope) - _, err := s.Score(context.Background(), entity.BatchChanges{}) + _, err := s.Score(context.Background(), entity.Batch{}) require.Error(t, err) } @@ -140,7 +140,7 @@ func TestReduceFunc_ReceivesNames(t *testing.T) { "files": &fixedScorer{0.9}, "deps": &fixedScorer{0.95}, }, custom, tally.NoopScope) - got, err := s.Score(context.Background(), entity.BatchChanges{}) + got, err := s.Score(context.Background(), entity.Batch{}) require.NoError(t, err) assert.Equal(t, 0.9, got) assert.ElementsMatch(t, []string{"files", "deps"}, receivedNames) diff --git a/submitqueue/extension/scorer/fake/BUILD.bazel b/submitqueue/extension/scorer/fake/BUILD.bazel index f1335904..f5c857f2 100644 --- a/submitqueue/extension/scorer/fake/BUILD.bazel +++ b/submitqueue/extension/scorer/fake/BUILD.bazel @@ -6,6 +6,7 @@ go_library( importpath = "github.com/uber/submitqueue/submitqueue/extension/scorer/fake", visibility = ["//visibility:public"], deps = [ + "//submitqueue/core/changeset", "//submitqueue/core/fakemarker", "//submitqueue/entity", "//submitqueue/extension/scorer", @@ -17,6 +18,8 @@ go_test( srcs = ["fake_test.go"], embed = [":fake"], deps = [ + "//submitqueue/core/changeset", + "//submitqueue/core/changeset/fake", "//submitqueue/entity", "//submitqueue/extension/scorer", "//submitqueue/extension/scorer/heuristic", diff --git a/submitqueue/extension/scorer/fake/fake.go b/submitqueue/extension/scorer/fake/fake.go index fcd6219e..b5a76c78 100644 --- a/submitqueue/extension/scorer/fake/fake.go +++ b/submitqueue/extension/scorer/fake/fake.go @@ -27,6 +27,7 @@ import ( "context" "fmt" + "github.com/uber/submitqueue/submitqueue/core/changeset" "github.com/uber/submitqueue/submitqueue/core/fakemarker" "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/scorer" @@ -36,25 +37,31 @@ import ( const tokenError = "score-error" // scorerFake decorates a delegate Scorer, injecting an error when a change URI -// carries the failure marker. +// carries the failure marker. It resolves the batch itself to inspect URIs. type scorerFake struct { + resolver changeset.Resolver delegate scorer.Scorer } // New returns a scorer.Scorer that delegates to the given scorer but returns an -// error when a change URI carries the "sq-fake=score-error" marker. The delegate -// is the existing scorer implementation to wrap (e.g. heuristic or composite). -func New(delegate scorer.Scorer) scorer.Scorer { - return scorerFake{delegate: delegate} +// error when a change URI carries the "sq-fake=score-error" marker. The resolver +// resolves the batch's changes so the marker can be inspected; the delegate is the +// existing scorer implementation to wrap (e.g. heuristic or composite). +func New(resolver changeset.Resolver, delegate scorer.Scorer) scorer.Scorer { + return scorerFake{resolver: resolver, delegate: delegate} } // Score returns an error when a change URI carries the failure marker; otherwise // it delegates to the wrapped scorer. -func (s scorerFake) Score(ctx context.Context, changes entity.BatchChanges) (float64, error) { +func (s scorerFake) Score(ctx context.Context, batch entity.Batch) (float64, error) { + changes, err := s.resolver.DetailedForBatch(ctx, batch) + if err != nil { + return 0, err + } if markerToken(changes) == tokenError { return 0, fmt.Errorf("fake: marked score error") } - return s.delegate.Score(ctx, changes) + return s.delegate.Score(ctx, batch) } // markerToken returns the marker token embedded in the first change URI that diff --git a/submitqueue/extension/scorer/fake/fake_test.go b/submitqueue/extension/scorer/fake/fake_test.go index 5a0529ef..7c09ed51 100644 --- a/submitqueue/extension/scorer/fake/fake_test.go +++ b/submitqueue/extension/scorer/fake/fake_test.go @@ -21,42 +21,50 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber-go/tally" + "github.com/uber/submitqueue/submitqueue/core/changeset" + changesetfake "github.com/uber/submitqueue/submitqueue/core/changeset/fake" "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/scorer" "github.com/uber/submitqueue/submitqueue/extension/scorer/heuristic" ) +const batchID = "q/batch/1" + func TestNew_ImplementsInterface(t *testing.T) { - var _ scorer.Scorer = New(nil) + var _ scorer.Scorer = New(nil, nil) +} + +// resolverFor returns a changeset resolver seeded so that batchID's detailed +// changes carry the given URIs. +func resolverFor(uris ...string) changeset.Resolver { + changes := make([]entity.ChangeInfo, 0, len(uris)) + for _, u := range uris { + changes = append(changes, entity.ChangeInfo{URI: u}) + } + return changesetfake.New().SetDetailed(batchID, entity.BatchChanges{BatchID: batchID, Queue: "q", Changes: changes}) } -// delegate returns a heuristic scorer that scores every batch at want. -func delegate(t *testing.T, want float64) scorer.Scorer { - t.Helper() +// delegate returns a heuristic scorer (backed by resolver) that scores every batch at want. +func delegate(resolver changeset.Resolver, want float64) scorer.Scorer { return heuristic.New( + resolver, []heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: want}}, func(_ context.Context, c entity.BatchChanges) (int, error) { return len(c.Changes), nil }, tally.NoopScope, ) } -func batch(uris ...string) entity.BatchChanges { - changes := make([]entity.ChangeInfo, 0, len(uris)) - for _, u := range uris { - changes = append(changes, entity.ChangeInfo{URI: u}) - } - return entity.BatchChanges{BatchID: "q/batch/1", Queue: "q", Changes: changes} -} - func TestScore_DelegatesWhenUnmarked(t *testing.T) { - s := New(delegate(t, 0.7)) - got, err := s.Score(context.Background(), batch("github://o/r/pull/1/a")) + r := resolverFor("github://o/r/pull/1/a") + s := New(r, delegate(r, 0.7)) + got, err := s.Score(context.Background(), entity.Batch{ID: batchID}) require.NoError(t, err) assert.Equal(t, 0.7, got) } func TestScore_ErrorMarker(t *testing.T) { - s := New(delegate(t, 0.7)) - _, err := s.Score(context.Background(), batch("github://o/r/pull/1/a?sq-fake=score-error")) + r := resolverFor("github://o/r/pull/1/a?sq-fake=score-error") + s := New(r, delegate(r, 0.7)) + _, err := s.Score(context.Background(), entity.Batch{ID: batchID}) require.Error(t, err) } diff --git a/submitqueue/extension/scorer/heuristic/BUILD.bazel b/submitqueue/extension/scorer/heuristic/BUILD.bazel index 7b66d777..3f45785e 100644 --- a/submitqueue/extension/scorer/heuristic/BUILD.bazel +++ b/submitqueue/extension/scorer/heuristic/BUILD.bazel @@ -7,6 +7,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//core/metrics", + "//submitqueue/core/changeset", "//submitqueue/entity", "//submitqueue/extension/scorer", "@com_github_uber_go_tally//:tally", @@ -18,6 +19,7 @@ go_test( srcs = ["scorer_test.go"], embed = [":heuristic"], deps = [ + "//submitqueue/core/changeset/fake", "//submitqueue/entity", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/submitqueue/extension/scorer/heuristic/scorer.go b/submitqueue/extension/scorer/heuristic/scorer.go index b1afea1c..ac20d559 100644 --- a/submitqueue/extension/scorer/heuristic/scorer.go +++ b/submitqueue/extension/scorer/heuristic/scorer.go @@ -20,6 +20,7 @@ import ( "github.com/uber-go/tally" "github.com/uber/submitqueue/core/metrics" + "github.com/uber/submitqueue/submitqueue/core/changeset" "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/scorer" ) @@ -40,6 +41,8 @@ type Bucket struct { // heuristicScorer computes a success probability by bucketing a metric extracted from a batch of changes. // It follows the Java HeuristicsBasedSuccessPredictor pattern. type heuristicScorer struct { + // resolver resolves the batch identity into its detailed changes. + resolver changeset.Resolver // buckets is the list of ranges to match against. buckets []Bucket // valueFunc extracts the numeric value from a batch of changes. @@ -48,24 +51,30 @@ type heuristicScorer struct { scope tally.Scope } -// New creates a new heuristic Scorer with the given buckets and value function. +// New creates a new heuristic Scorer with the given resolver, buckets and value function. // Panics if valueFunc is nil. -func New(buckets []Bucket, valueFunc ValueFunc, scope tally.Scope) scorer.Scorer { +func New(resolver changeset.Resolver, buckets []Bucket, valueFunc ValueFunc, scope tally.Scope) scorer.Scorer { if valueFunc == nil { panic("heuristic.New: valueFunc must not be nil") } return &heuristicScorer{ + resolver: resolver, buckets: buckets, valueFunc: valueFunc, scope: scope, } } -// Score extracts the value from the batch of changes, then returns the probability score for the -// first bucket whose range [Min, Max] contains the value. Returns an error if no bucket matches. -func (s *heuristicScorer) Score(ctx context.Context, changes entity.BatchChanges) (ret float64, retErr error) { +// Score resolves the batch's changes, extracts the metric, then returns the probability +// score for the first bucket whose range [Min, Max] contains the value. Returns an error +// if no bucket matches. +func (s *heuristicScorer) Score(ctx context.Context, batch entity.Batch) (ret float64, retErr error) { op := metrics.Begin(s.scope, "score") defer func() { op.Complete(retErr) }() + changes, err := s.resolver.DetailedForBatch(ctx, batch) + if err != nil { + return 0, err + } value, err := s.valueFunc(ctx, changes) if err != nil { return 0, err diff --git a/submitqueue/extension/scorer/heuristic/scorer_test.go b/submitqueue/extension/scorer/heuristic/scorer_test.go index d65a9bae..5255de64 100644 --- a/submitqueue/extension/scorer/heuristic/scorer_test.go +++ b/submitqueue/extension/scorer/heuristic/scorer_test.go @@ -21,6 +21,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber-go/tally" + changesetfake "github.com/uber/submitqueue/submitqueue/core/changeset/fake" "github.com/uber/submitqueue/submitqueue/entity" ) @@ -106,8 +107,8 @@ func TestScorer_Score(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := New(tt.buckets, tt.valueFunc, tally.NoopScope) - got, err := s.Score(context.Background(), entity.BatchChanges{}) + s := New(changesetfake.New(), tt.buckets, tt.valueFunc, tally.NoopScope) + got, err := s.Score(context.Background(), entity.Batch{}) if tt.wantErr { require.Error(t, err) return @@ -122,13 +123,13 @@ func TestScorer_Score_ValueFuncError(t *testing.T) { failing := func(_ context.Context, _ entity.BatchChanges) (int, error) { return 0, assert.AnError } - s := New([]Bucket{{Min: 0, Max: 10, Score: 0.9}}, failing, tally.NoopScope) - _, err := s.Score(context.Background(), entity.BatchChanges{}) + s := New(changesetfake.New(), []Bucket{{Min: 0, Max: 10, Score: 0.9}}, failing, tally.NoopScope) + _, err := s.Score(context.Background(), entity.Batch{}) require.Error(t, err) } func TestNew_NilValueFunc(t *testing.T) { assert.Panics(t, func() { - New([]Bucket{{Min: 0, Max: 10, Score: 0.85}}, nil, tally.NoopScope) + New(changesetfake.New(), []Bucket{{Min: 0, Max: 10, Score: 0.85}}, nil, tally.NoopScope) }) } diff --git a/submitqueue/extension/scorer/mock/scorer_mock.go b/submitqueue/extension/scorer/mock/scorer_mock.go index 72edc280..9b64b754 100644 --- a/submitqueue/extension/scorer/mock/scorer_mock.go +++ b/submitqueue/extension/scorer/mock/scorer_mock.go @@ -43,18 +43,18 @@ func (m *MockScorer) EXPECT() *MockScorerMockRecorder { } // Score mocks base method. -func (m *MockScorer) Score(ctx context.Context, changes entity.BatchChanges) (float64, error) { +func (m *MockScorer) Score(ctx context.Context, batch entity.Batch) (float64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Score", ctx, changes) + ret := m.ctrl.Call(m, "Score", ctx, batch) ret0, _ := ret[0].(float64) ret1, _ := ret[1].(error) return ret0, ret1 } // Score indicates an expected call of Score. -func (mr *MockScorerMockRecorder) Score(ctx, changes any) *gomock.Call { +func (mr *MockScorerMockRecorder) Score(ctx, batch any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Score", reflect.TypeOf((*MockScorer)(nil).Score), ctx, changes) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Score", reflect.TypeOf((*MockScorer)(nil).Score), ctx, batch) } // MockFactory is a mock of Factory interface. diff --git a/submitqueue/extension/scorer/scorer.go b/submitqueue/extension/scorer/scorer.go index 6837448e..b3af1b09 100644 --- a/submitqueue/extension/scorer/scorer.go +++ b/submitqueue/extension/scorer/scorer.go @@ -22,11 +22,12 @@ import ( "github.com/uber/submitqueue/submitqueue/entity" ) -// Scorer computes a success probability score for a batch of changes based on their characteristics. +// Scorer computes a success probability score for a batch based on its changes. type Scorer interface { // Score returns a probability between 0.0 and 1.0 indicating the likelihood - // of a successful land for the given batch of changes. - Score(ctx context.Context, changes entity.BatchChanges) (float64, error) + // of a successful land for the given batch. It is handed the batch identity + // and resolves the batch's changes itself through an injected changeset.Resolver. + Score(ctx context.Context, batch entity.Batch) (float64, error) } // Config carries the per-queue identity handed to a Factory. The system knows diff --git a/submitqueue/orchestrator/controller/score/score.go b/submitqueue/orchestrator/controller/score/score.go index 836c4449..6eb6e1f1 100644 --- a/submitqueue/orchestrator/controller/score/score.go +++ b/submitqueue/orchestrator/controller/score/score.go @@ -130,7 +130,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (r return nil } - // Score each request's change and take the minimum (worst-case) as the batch score + // Score the batch. The scorer resolves the batch's changes itself. batchScore, err := c.scoreBatch(ctx, batch) if err != nil { metrics.NamedCounter(c.metricsScope, opName, "scorer_errors", 1) @@ -173,56 +173,22 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (r return nil // Success - message will be acked } -// scoreBatch normalizes the batch's changes and scores them as a whole. It resolves -// each request in the batch, reads that request's change records (one per URI), and -// flattens their provider-supplied details into a single entity.BatchChanges, which -// the scorer turns into one probability for the batch. +// scoreBatch builds the queue's scorer and scores the batch. The scorer is handed +// the batch identity and resolves the batch's changes itself (via the shared +// changeset resolver injected at its factory), turning them into one probability. func (c *Controller) scoreBatch(ctx context.Context, batch entity.Batch) (float64, error) { sc, err := c.scorers.For(scorer.Config{QueueName: batch.Queue}) if err != nil { return 0, fmt.Errorf("failed to build scorer for batch %s: %w", batch.ID, err) } - changes, err := c.collectBatchChanges(ctx, batch) - if err != nil { - return 0, err - } - - score, err := sc.Score(ctx, changes) + score, err := sc.Score(ctx, batch) if err != nil { return 0, fmt.Errorf("failed to score batch %s: %w", batch.ID, err) } return score, nil } -// collectBatchChanges assembles the normalized entity.BatchChanges for a batch by -// resolving each request and reading its change records per URI. For each URI it -// selects the record owned by the request (GetByURI returns rows for all requests -// that ever claimed the URI) and appends its URI + details. -func (c *Controller) collectBatchChanges(ctx context.Context, batch entity.Batch) (entity.BatchChanges, error) { - changes := entity.BatchChanges{BatchID: batch.ID, Queue: batch.Queue} - for _, requestID := range batch.Contains { - request, err := c.store.GetRequestStore().Get(ctx, requestID) - if err != nil { - return entity.BatchChanges{}, fmt.Errorf("failed to get request %s: %w", requestID, err) - } - for _, uri := range request.Change.URIs { - records, err := c.store.GetChangeStore().GetByURI(ctx, batch.Queue, uri) - if err != nil { - return entity.BatchChanges{}, fmt.Errorf("failed to read change record for request %s uri=%s: %w", requestID, uri, err) - } - for _, rec := range records { - if rec.RequestID != requestID { - continue - } - changes.Changes = append(changes.Changes, entity.ChangeInfo{URI: rec.URI, Details: rec.Details}) - break - } - } - } - return changes, nil -} - // publish publishes a batch ID to the specified topic key. func (c *Controller) publish(ctx context.Context, key consumer.TopicKey, batchID string, partitionKey string) error { bid := entity.BatchID{ID: batchID} diff --git a/submitqueue/orchestrator/controller/score/score_test.go b/submitqueue/orchestrator/controller/score/score_test.go index a716815f..9a0ca183 100644 --- a/submitqueue/orchestrator/controller/score/score_test.go +++ b/submitqueue/orchestrator/controller/score/score_test.go @@ -166,8 +166,10 @@ func TestController_Process_Success(t *testing.T) { require.NoError(t, err) } -// TestController_Process_BatchLevelScore verifies the controller assembles all of the -// batch's changes into one BatchChanges and persists the single score the scorer returns. +// TestController_Process_BatchLevelScore verifies the controller hands the batch +// identity to the scorer and persists the single score it returns. Resolving the +// batch's changes is the scorer's concern (via the changeset resolver), not the +// controller's. func TestController_Process_BatchLevelScore(t *testing.T) { ctrl := gomock.NewController(t) @@ -179,41 +181,19 @@ func TestController_Process_BatchLevelScore(t *testing.T) { Version: 1, } - request1 := entity.Request{ - ID: "test-queue/1", - Queue: "test-queue", - Change: entity.Change{URIs: []string{"github://uber/repo/pull/1/aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}, - State: entity.RequestStateStarted, - Version: 1, - } - request2 := entity.Request{ - ID: "test-queue/2", - Queue: "test-queue", - Change: entity.Change{URIs: []string{"github://uber/repo/pull/2/bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}}, - State: entity.RequestStateStarted, - Version: 1, - } - mockBatchStore := storagemock.NewMockBatchStore(ctrl) mockBatchStore.EXPECT().Get(gomock.Any(), batch.ID).Return(batch, nil) // The single batch-level score is persisted. mockBatchStore.EXPECT().UpdateScoreAndState(gomock.Any(), batch.ID, batch.Version, batch.Version+1, 0.7, entity.BatchStateScored).Return(nil) - mockRequestStore := storagemock.NewMockRequestStore(ctrl) - mockRequestStore.EXPECT().Get(gomock.Any(), "test-queue/1").Return(request1, nil) - mockRequestStore.EXPECT().Get(gomock.Any(), "test-queue/2").Return(request2, nil) - store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() - store.EXPECT().GetRequestStore().Return(mockRequestStore).AnyTimes() - store.EXPECT().GetChangeStore().Return(mockChangeStore(ctrl, request1, request2)).AnyTimes() - // Capture the BatchChanges to assert both requests' changes were gathered. + // The controller passes the batch identity to the scorer and persists its score. mockScorer := scorermock.NewMockScorer(ctrl) mockScorer.EXPECT().Score(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, changes entity.BatchChanges) (float64, error) { - assert.Equal(t, batch.ID, changes.BatchID) - assert.Len(t, changes.Changes, 2) + func(_ context.Context, b entity.Batch) (float64, error) { + assert.Equal(t, batch.ID, b.ID) return 0.7, nil }, ) @@ -260,7 +240,7 @@ func TestController_Process_ScorerFailure(t *testing.T) { mockBatchStore.EXPECT().Get(gomock.Any(), batch.ID).Return(batch, nil) mockRequestStore := storagemock.NewMockRequestStore(ctrl) - mockRequestStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil) + mockRequestStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil).AnyTimes() store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() @@ -292,7 +272,7 @@ func TestController_Process_UpdateScoreFailure(t *testing.T) { mockBatchStore.EXPECT().UpdateScoreAndState(gomock.Any(), batch.ID, batch.Version, batch.Version+1, gomock.Any(), entity.BatchStateScored).Return(fmt.Errorf("version mismatch")) mockRequestStore := storagemock.NewMockRequestStore(ctrl) - mockRequestStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil) + mockRequestStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil).AnyTimes() store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes()