flags_test.go 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. package viper
  2. import (
  3. "testing"
  4. "github.com/spf13/pflag"
  5. "github.com/stretchr/testify/assert"
  6. )
  7. func TestBindFlagValueSet(t *testing.T) {
  8. flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)
  9. var testValues = map[string]*string{
  10. "host": nil,
  11. "port": nil,
  12. "endpoint": nil,
  13. }
  14. var mutatedTestValues = map[string]string{
  15. "host": "localhost",
  16. "port": "6060",
  17. "endpoint": "/public",
  18. }
  19. for name := range testValues {
  20. testValues[name] = flagSet.String(name, "", "test")
  21. }
  22. flagValueSet := pflagValueSet{flagSet}
  23. err := BindFlagValues(flagValueSet)
  24. if err != nil {
  25. t.Fatalf("error binding flag set, %v", err)
  26. }
  27. flagSet.VisitAll(func(flag *pflag.Flag) {
  28. flag.Value.Set(mutatedTestValues[flag.Name])
  29. flag.Changed = true
  30. })
  31. for name, expected := range mutatedTestValues {
  32. assert.Equal(t, Get(name), expected)
  33. }
  34. }
  35. func TestBindFlagValue(t *testing.T) {
  36. var testString = "testing"
  37. var testValue = newStringValue(testString, &testString)
  38. flag := &pflag.Flag{
  39. Name: "testflag",
  40. Value: testValue,
  41. Changed: false,
  42. }
  43. flagValue := pflagValue{flag}
  44. BindFlagValue("testvalue", flagValue)
  45. assert.Equal(t, testString, Get("testvalue"))
  46. flag.Value.Set("testing_mutate")
  47. flag.Changed = true //hack for pflag usage
  48. assert.Equal(t, "testing_mutate", Get("testvalue"))
  49. }