diff --git a/cmd/labs/github/github.go b/cmd/labs/github/github.go index 4251bc3e1e7..c33eec99ca8 100644 --- a/cmd/labs/github/github.go +++ b/cmd/labs/github/github.go @@ -64,6 +64,7 @@ func getPagedBytes(ctx context.Context, method, url string, body io.Reader) (*pa if err != nil { return nil, err } + defer res.Body.Close() if res.StatusCode == http.StatusNotFound { return nil, ErrNotFound } @@ -71,7 +72,6 @@ func getPagedBytes(ctx context.Context, method, url string, body io.Reader) (*pa return nil, fmt.Errorf("github request failed: %s", res.Status) } nextLink := parseNextLink(res.Header.Get("link")) - defer res.Body.Close() bodyBytes, err := io.ReadAll(res.Body) if err != nil { return nil, err diff --git a/cmd/labs/github/github_test.go b/cmd/labs/github/github_test.go index bd8fd3a4251..b9f0395ef05 100644 --- a/cmd/labs/github/github_test.go +++ b/cmd/labs/github/github_test.go @@ -1,11 +1,65 @@ package github import ( + "fmt" + "io" + "net/http" + "strings" "testing" "github.com/stretchr/testify/assert" ) +type closeRecordingBody struct { + io.Reader + closed *bool +} + +func (b *closeRecordingBody) Close() error { + *b.closed = true + return nil +} + +type stubTransport struct { + status int + closed bool +} + +func (s *stubTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: s.status, + Status: fmt.Sprintf("%d %s", s.status, http.StatusText(s.status)), + Header: http.Header{}, + Body: &closeRecordingBody{Reader: strings.NewReader("{}"), closed: &s.closed}, + Request: req, + }, nil +} + +func TestGetPagedBytesClosesBodyOnHTTPError(t *testing.T) { + tests := []struct { + name string + status int + }{ + {"not found", http.StatusNotFound}, + {"server error", http.StatusInternalServerError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // getPagedBytes hardcodes http.DefaultClient, so swapping its + // transport is the only seam to observe body closure. + stub := &stubTransport{status: tt.status} + prev := http.DefaultClient.Transport + http.DefaultClient.Transport = stub + t.Cleanup(func() { http.DefaultClient.Transport = prev }) + + _, err := getPagedBytes(t.Context(), "GET", "https://api.github.test/x", nil) + assert.Error(t, err) + assert.True(t, stub.closed) + }) + } +} + func TestParseNextLink(t *testing.T) { tests := []struct { name string