diff --git a/urlpath.go b/urlpath.go index d40fdd1..3296751 100644 --- a/urlpath.go +++ b/urlpath.go @@ -42,10 +42,23 @@ func (p Parameter) Name() string { // Currently only int and string types are supported. type Schema map[Parameter]Parser +// MustParse will parse the URL path vars from r given the +// element names and parsers defined in schema. If there is +// a problem parsing the values, a panic is triggered. +// +// This function only works with requests being processed by +// handlers of a gorilla/mux. +func MustParse(r *http.Request, schema Schema) { + if err := Parse(r, schema); err != nil { + panic("urlpath: " + err.Error()) + } +} + // Parse will parse the URL path vars from r given the -// element names and parsers defined in schema. +// element names and parsers defined in schema. If there +// is a problem parsing the values, an error is returned. // -// This method only works with requests being processed by +// This function only works with requests being processed by // handlers of a gorilla/mux. func Parse(r *http.Request, schema Schema) error { return ParseValues(mux.Vars(r), schema) diff --git a/urlpath_test.go b/urlpath_test.go index 9cb34d0..5067ffc 100644 --- a/urlpath_test.go +++ b/urlpath_test.go @@ -13,6 +13,64 @@ import ( "github.com/shoenig/test/must" ) +func Test_MustParse(t *testing.T) { + t.Parallel() + + router := mux.NewRouter() + executed := false + + router.HandleFunc("/v1/{foo}/{bar}", func(_ http.ResponseWriter, r *http.Request) { + var foo string + var bar int + + MustParse(r, Schema{ + "foo": String(&foo), + "bar": Int(&bar), + }) + + must.Eq(t, "blah", foo) + must.Eq(t, 31, bar) + executed = true + }) + + w := httptest.NewRecorder() + ctx := context.Background() + request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/blah/31", nil) + must.NoError(t, err) + + router.ServeHTTP(w, request) + must.True(t, executed) +} + +func Test_MustParse_panic(t *testing.T) { + t.Parallel() + + router := mux.NewRouter() + executed := false + + router.HandleFunc("/v1/{foo}/{bar}", func(_ http.ResponseWriter, r *http.Request) { + var foo string + var bar int + + executed = true + + must.Panic(t, func() { + MustParse(r, Schema{ + "foo": String(&foo), + "bar": Int(&bar), + }) + }) + }) + + w := httptest.NewRecorder() + ctx := context.Background() + request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/blah/bad", nil) + must.NoError(t, err) + + router.ServeHTTP(w, request) + must.True(t, executed) +} + func Test_Parse(t *testing.T) { t.Parallel()