Skip to content

Commit 4fbca9f

Browse files
authored
Add Curry() (#37)
1 parent 41b9ba9 commit 4fbca9f

File tree

3 files changed

+303
-3
lines changed

3 files changed

+303
-3
lines changed

example_provider_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,25 @@ func ExampleSaveTo() {
121121
// Output: <nil> one 3
122122
}
123123

124+
func ExampleCurry() {
125+
lotsOfUnchangingArgs := func(s string, i int, u uint) string {
126+
return fmt.Sprintf("%s-%d-%d", s, i, u)
127+
}
128+
var shorthand func(i int) string
129+
fmt.Println(nject.Run("example",
130+
func() string { return "foo" },
131+
func() uint { return 33 },
132+
nject.MustCurry(lotsOfUnchangingArgs, &shorthand),
133+
func() {
134+
fmt.Println("actual injection goal")
135+
},
136+
))
137+
fmt.Println(shorthand(10))
138+
// Output: actual injection goal
139+
// <nil>
140+
// foo-10-33
141+
}
142+
124143
// This demonstrates how it to have a default that gets overridden by
125144
// by later inputs.
126145
func ExampleReorder() {

utils.go

Lines changed: 130 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,133 @@ import (
55
"reflect"
66
)
77

8+
// Curry generates a Requird Provider that prefills arguments to a function to create a
9+
// new function that needs fewer args.
10+
//
11+
// Only arguments with a unique (to the function) type can be curried.
12+
//
13+
// The original function and the curried function must have the same outputs.
14+
//
15+
// The first curried input may not be a function.
16+
//
17+
// EXPERIMENTAL: This is currently considered experimental and could be removed or
18+
// moved to another package. If you're using this, open a pull request to remove
19+
// this comment.
20+
func Curry(originalFunction interface{}, pointerToCurriedFunction interface{}) (Provider, error) {
21+
o := reflect.ValueOf(originalFunction)
22+
if !o.IsValid() {
23+
return nil, fmt.Errorf("original function is not a valid value")
24+
}
25+
if o.Type().Kind() != reflect.Func {
26+
return nil, fmt.Errorf("first argument to Curry must be a function")
27+
}
28+
n := reflect.ValueOf(pointerToCurriedFunction)
29+
if !n.IsValid() {
30+
return nil, fmt.Errorf("curried function is not a valid value")
31+
}
32+
if n.Type().Kind() != reflect.Ptr {
33+
return nil, fmt.Errorf("curried function must be a pointer (to a function)")
34+
}
35+
if n.Type().Elem().Kind() != reflect.Func {
36+
return nil, fmt.Errorf("curried function must be a pointer to a function")
37+
}
38+
if n.IsNil() {
39+
return nil, fmt.Errorf("pointer to curried function cannot be nil")
40+
}
41+
if o.Type().NumOut() != n.Type().Elem().NumOut() {
42+
return nil, fmt.Errorf("current function doesn't have the same number of outputs, %d, as curried function, %d",
43+
o.Type().NumOut(), n.Type().Elem().NumOut())
44+
}
45+
outputs := make([]reflect.Type, o.Type().NumOut())
46+
for i := 0; i < len(outputs); i++ {
47+
if o.Type().Out(i) != n.Type().Elem().Out(i) {
48+
return nil, fmt.Errorf("current function return value #%d has a different type, %s, than the curried functions return value, %s",
49+
i+1, o.Type().Out(i), n.Type().Elem().Out(i))
50+
}
51+
outputs[i] = o.Type().Out(i)
52+
}
53+
54+
// Figure out the set of input types for the curried function
55+
ntypes := make(map[reflect.Type][]int)
56+
for i := 0; i < n.Type().Elem().NumIn(); i++ {
57+
t := n.Type().Elem().In(i)
58+
ntypes[t] = append(ntypes[t], i)
59+
}
60+
61+
// Now, for each input in the original function, figure out where it
62+
// is coming from.
63+
originalNumIn := o.Type().NumIn()
64+
used := make(map[reflect.Type]int)
65+
curryCount := originalNumIn - n.Type().Elem().NumIn()
66+
if curryCount < 1 {
67+
return nil, fmt.Errorf("curried function must take fewer arguments than original function")
68+
}
69+
curried := make([]reflect.Type, 0, curryCount) // injected inputs
70+
alreadyCurried := make(map[reflect.Type]struct{}) // to prevent double-dipping
71+
curryMap := make([]int, 0, curryCount) // maps postion from injected inputs to to original
72+
passMap := make([]int, n.Type().Elem().NumIn()) // maps position from curried to original
73+
for i := 0; i < o.Type().NumIn(); i++ {
74+
t := o.Type().In(i)
75+
if plist, ok := ntypes[t]; ok {
76+
if used[t] < len(plist) {
77+
passMap[plist[used[t]]] = i
78+
used[t]++
79+
} else {
80+
return nil, fmt.Errorf("original function takes more arguments of type %s than the curried function", t)
81+
}
82+
} else {
83+
if _, ok := alreadyCurried[t]; ok {
84+
return nil, fmt.Errorf("cannot curry the same type (%s) more than once", t)
85+
}
86+
alreadyCurried[t] = struct{}{}
87+
curryMap = append(curryMap, i)
88+
curried = append(curried, t)
89+
}
90+
}
91+
for t, plist := range ntypes {
92+
if used[t] < len(plist) {
93+
return nil, fmt.Errorf("not all of the %s inputs to the curried function were used by the original", t)
94+
}
95+
}
96+
if len(curried) > 0 && curried[0].Kind() == reflect.Func {
97+
return nil, fmt.Errorf("the first curried input, %s, may not be a function", curried[0])
98+
}
99+
100+
var fillInjected func(oi []reflect.Value)
101+
102+
curryFunc := func(inputs []reflect.Value) []reflect.Value {
103+
oi := make([]reflect.Value, originalNumIn)
104+
for i, in := range inputs {
105+
oi[passMap[i]] = in
106+
}
107+
fillInjected(oi)
108+
return o.Call(oi)
109+
}
110+
111+
return Required(MakeReflective(curried, nil, func(inputs []reflect.Value) []reflect.Value {
112+
fillInjected = func(oi []reflect.Value) {
113+
for i, in := range inputs {
114+
oi[curryMap[i]] = in
115+
}
116+
}
117+
n.Elem().Set(reflect.MakeFunc(n.Type().Elem(), curryFunc))
118+
return nil
119+
})), nil
120+
}
121+
122+
// MustSaveTo calls FillVars and panics if FillVars returns an error
123+
//
124+
// EXPERIMENTAL: This is currently considered experimental and could be removed or
125+
// moved to another package. If you're using this, open a pull request to remove
126+
// this comment.
127+
func MustSaveTo(varPointers ...interface{}) Provider {
128+
p, err := SaveTo(varPointers...)
129+
if err != nil {
130+
panic(err)
131+
}
132+
return p
133+
}
134+
8135
// SaveTo generates a required provider. The input parameters to FillVars
9136
// must be pointers. The generated provider takes as inputs the types needed
10137
// to assign through the pointers.
@@ -43,9 +170,9 @@ func SaveTo(varPointers ...interface{}) (Provider, error) {
43170
})), nil
44171
}
45172

46-
// MustSaveTo calls FillVars and panics if FillVars returns an error
47-
func MustSaveTo(varPointers ...interface{}) Provider {
48-
p, err := SaveTo(varPointers...)
173+
// MustCurry calls Curry and panics if Curry returns error
174+
func MustCurry(originalFunction interface{}, pointerToCurriedFunction interface{}) Provider {
175+
p, err := Curry(originalFunction, pointerToCurriedFunction)
49176
if err != nil {
50177
panic(err)
51178
}

utils_test.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package nject
22

33
import (
4+
"fmt"
45
"testing"
56

67
"github.com/stretchr/testify/assert"
@@ -36,3 +37,156 @@ func TestSaveTo(t *testing.T) {
3637
}
3738
}
3839
}
40+
41+
func TestCurry(t *testing.T) {
42+
t.Parallel()
43+
seq := Sequence("available",
44+
func() string { return "foo" },
45+
func() int { return 3 },
46+
func() uint { return 7 },
47+
)
48+
var c1 func(string) string
49+
var c2 func(bool, bool, string, string) string
50+
cases := []struct {
51+
name string
52+
fail string
53+
curry interface{}
54+
check func(t *testing.T)
55+
original interface{}
56+
}{
57+
{
58+
curry: &c1,
59+
original: func(x int, s string) string {
60+
return fmt.Sprintf("%s-%d", s, x)
61+
},
62+
check: func(t *testing.T) {
63+
assert.Equal(t, "bar-3", c1("bar"))
64+
},
65+
},
66+
{
67+
curry: &c1,
68+
original: func(x int, s string, u uint) string {
69+
return fmt.Sprintf("%s-%d/%d", s, x, u)
70+
},
71+
check: func(t *testing.T) {
72+
assert.Equal(t, "bar-3/7", c1("bar"))
73+
},
74+
},
75+
{
76+
curry: &c2,
77+
original: func(b1 bool, x int, b2 bool, s1 string, s2 string, u uint) string {
78+
return fmt.Sprintf("%v-%d-%v %s %s-%d", b1, x, b2, s1, s2, u)
79+
},
80+
check: func(t *testing.T) {
81+
assert.Equal(t, "true-3-false bee boot-7", c2(true, false, "bee", "boot"))
82+
},
83+
},
84+
{
85+
curry: &c2,
86+
original: func(b1 bool, s1 string, s2 string) string { return "" },
87+
fail: "curried function must take fewer arguments",
88+
},
89+
{
90+
curry: &c2,
91+
original: func(b1 bool, b2 bool, b3 bool, s1 string, s2 string) string { return "" },
92+
fail: "original function takes more arguments of type bool",
93+
},
94+
{
95+
name: "no original",
96+
curry: &c2,
97+
fail: "original function is not a valid value",
98+
},
99+
{
100+
name: "no curry",
101+
original: func(b1 bool, b2 bool, b3 bool, s1 string, s2 string) string { return "" },
102+
fail: "curried function is not a valid value",
103+
},
104+
{
105+
name: "non-pointer",
106+
curry: 7,
107+
original: func(b1 bool, b2 bool, b3 bool, s1 string, s2 string) string { return "" },
108+
fail: "pointer (to a function)",
109+
},
110+
{
111+
name: "non-func",
112+
curry: seq,
113+
original: func(b1 bool, b2 bool, b3 bool, s1 string, s2 string) string { return "" },
114+
fail: "pointer to a function",
115+
},
116+
{
117+
curry: &c2,
118+
original: "original non-func",
119+
fail: "first argument to Curry must be a function",
120+
},
121+
{
122+
name: "nil",
123+
curry: (*func())(nil),
124+
original: func(string) {},
125+
fail: "pointer to curried function cannot be nil",
126+
},
127+
{
128+
curry: &c1,
129+
original: func(string) {},
130+
fail: "same number of outputs",
131+
},
132+
{
133+
curry: &c2,
134+
original: func(b1 bool, x int, b2 bool, s1 string, s2 string, u uint) int {
135+
return 22
136+
},
137+
fail: "return value #1 has a different type",
138+
},
139+
{
140+
curry: &c1,
141+
original: func(i1 int, i2 int, s string) string {
142+
return "foo"
143+
},
144+
fail: "cannot curry the same type (int) more than once",
145+
},
146+
{
147+
curry: &c1,
148+
original: func(uint, int) string {
149+
return "foo"
150+
},
151+
fail: "not all of the string inputs to the curried function were used",
152+
},
153+
{
154+
curry: &c1,
155+
original: func(s string, inner func(), i int) string {
156+
return fmt.Sprintf("%s-%d", s, i)
157+
},
158+
fail: "may not be a function",
159+
},
160+
}
161+
162+
for _, tc := range cases {
163+
name := tc.name
164+
if name == "" {
165+
name = fmt.Sprintf("%T", tc.original)
166+
}
167+
t.Run(name, func(t *testing.T) {
168+
var called bool
169+
p, err := Curry(tc.original, tc.curry)
170+
if tc.fail != "" && err != nil {
171+
assert.Contains(t, err.Error(), tc.fail, "curry")
172+
assert.Panics(t, func() {
173+
_ = MustCurry(tc.original, tc.curry)
174+
}, "curry")
175+
return
176+
} else {
177+
if !assert.NoError(t, err, "curry") {
178+
return
179+
}
180+
}
181+
err = Run(name, seq, p, func() { called = true })
182+
if tc.fail != "" {
183+
if assert.Error(t, err, "run") {
184+
assert.Contains(t, err.Error(), tc.fail, "run")
185+
}
186+
return
187+
}
188+
assert.True(t, called, "called")
189+
tc.check(t)
190+
})
191+
}
192+
}

0 commit comments

Comments
 (0)