@@ -5,6 +5,133 @@ import (
5
5
"reflect"
6
6
)
7
7
8
+ // Curry generates a Requird Provider that prefills arguments to a function to create a
9
+ // new function that needs fewer args.
10
+ //
11
+ // Only arguments with a unique (to the function) type can be curried.
12
+ //
13
+ // The original function and the curried function must have the same outputs.
14
+ //
15
+ // The first curried input may not be a function.
16
+ //
17
+ // EXPERIMENTAL: This is currently considered experimental and could be removed or
18
+ // moved to another package. If you're using this, open a pull request to remove
19
+ // this comment.
20
+ func Curry (originalFunction interface {}, pointerToCurriedFunction interface {}) (Provider , error ) {
21
+ o := reflect .ValueOf (originalFunction )
22
+ if ! o .IsValid () {
23
+ return nil , fmt .Errorf ("original function is not a valid value" )
24
+ }
25
+ if o .Type ().Kind () != reflect .Func {
26
+ return nil , fmt .Errorf ("first argument to Curry must be a function" )
27
+ }
28
+ n := reflect .ValueOf (pointerToCurriedFunction )
29
+ if ! n .IsValid () {
30
+ return nil , fmt .Errorf ("curried function is not a valid value" )
31
+ }
32
+ if n .Type ().Kind () != reflect .Ptr {
33
+ return nil , fmt .Errorf ("curried function must be a pointer (to a function)" )
34
+ }
35
+ if n .Type ().Elem ().Kind () != reflect .Func {
36
+ return nil , fmt .Errorf ("curried function must be a pointer to a function" )
37
+ }
38
+ if n .IsNil () {
39
+ return nil , fmt .Errorf ("pointer to curried function cannot be nil" )
40
+ }
41
+ if o .Type ().NumOut () != n .Type ().Elem ().NumOut () {
42
+ return nil , fmt .Errorf ("current function doesn't have the same number of outputs, %d, as curried function, %d" ,
43
+ o .Type ().NumOut (), n .Type ().Elem ().NumOut ())
44
+ }
45
+ outputs := make ([]reflect.Type , o .Type ().NumOut ())
46
+ for i := 0 ; i < len (outputs ); i ++ {
47
+ if o .Type ().Out (i ) != n .Type ().Elem ().Out (i ) {
48
+ return nil , fmt .Errorf ("current function return value #%d has a different type, %s, than the curried functions return value, %s" ,
49
+ i + 1 , o .Type ().Out (i ), n .Type ().Elem ().Out (i ))
50
+ }
51
+ outputs [i ] = o .Type ().Out (i )
52
+ }
53
+
54
+ // Figure out the set of input types for the curried function
55
+ ntypes := make (map [reflect.Type ][]int )
56
+ for i := 0 ; i < n .Type ().Elem ().NumIn (); i ++ {
57
+ t := n .Type ().Elem ().In (i )
58
+ ntypes [t ] = append (ntypes [t ], i )
59
+ }
60
+
61
+ // Now, for each input in the original function, figure out where it
62
+ // is coming from.
63
+ originalNumIn := o .Type ().NumIn ()
64
+ used := make (map [reflect.Type ]int )
65
+ curryCount := originalNumIn - n .Type ().Elem ().NumIn ()
66
+ if curryCount < 1 {
67
+ return nil , fmt .Errorf ("curried function must take fewer arguments than original function" )
68
+ }
69
+ curried := make ([]reflect.Type , 0 , curryCount ) // injected inputs
70
+ alreadyCurried := make (map [reflect.Type ]struct {}) // to prevent double-dipping
71
+ curryMap := make ([]int , 0 , curryCount ) // maps postion from injected inputs to to original
72
+ passMap := make ([]int , n .Type ().Elem ().NumIn ()) // maps position from curried to original
73
+ for i := 0 ; i < o .Type ().NumIn (); i ++ {
74
+ t := o .Type ().In (i )
75
+ if plist , ok := ntypes [t ]; ok {
76
+ if used [t ] < len (plist ) {
77
+ passMap [plist [used [t ]]] = i
78
+ used [t ]++
79
+ } else {
80
+ return nil , fmt .Errorf ("original function takes more arguments of type %s than the curried function" , t )
81
+ }
82
+ } else {
83
+ if _ , ok := alreadyCurried [t ]; ok {
84
+ return nil , fmt .Errorf ("cannot curry the same type (%s) more than once" , t )
85
+ }
86
+ alreadyCurried [t ] = struct {}{}
87
+ curryMap = append (curryMap , i )
88
+ curried = append (curried , t )
89
+ }
90
+ }
91
+ for t , plist := range ntypes {
92
+ if used [t ] < len (plist ) {
93
+ return nil , fmt .Errorf ("not all of the %s inputs to the curried function were used by the original" , t )
94
+ }
95
+ }
96
+ if len (curried ) > 0 && curried [0 ].Kind () == reflect .Func {
97
+ return nil , fmt .Errorf ("the first curried input, %s, may not be a function" , curried [0 ])
98
+ }
99
+
100
+ var fillInjected func (oi []reflect.Value )
101
+
102
+ curryFunc := func (inputs []reflect.Value ) []reflect.Value {
103
+ oi := make ([]reflect.Value , originalNumIn )
104
+ for i , in := range inputs {
105
+ oi [passMap [i ]] = in
106
+ }
107
+ fillInjected (oi )
108
+ return o .Call (oi )
109
+ }
110
+
111
+ return Required (MakeReflective (curried , nil , func (inputs []reflect.Value ) []reflect.Value {
112
+ fillInjected = func (oi []reflect.Value ) {
113
+ for i , in := range inputs {
114
+ oi [curryMap [i ]] = in
115
+ }
116
+ }
117
+ n .Elem ().Set (reflect .MakeFunc (n .Type ().Elem (), curryFunc ))
118
+ return nil
119
+ })), nil
120
+ }
121
+
122
+ // MustSaveTo calls FillVars and panics if FillVars returns an error
123
+ //
124
+ // EXPERIMENTAL: This is currently considered experimental and could be removed or
125
+ // moved to another package. If you're using this, open a pull request to remove
126
+ // this comment.
127
+ func MustSaveTo (varPointers ... interface {}) Provider {
128
+ p , err := SaveTo (varPointers ... )
129
+ if err != nil {
130
+ panic (err )
131
+ }
132
+ return p
133
+ }
134
+
8
135
// SaveTo generates a required provider. The input parameters to FillVars
9
136
// must be pointers. The generated provider takes as inputs the types needed
10
137
// to assign through the pointers.
@@ -43,9 +170,9 @@ func SaveTo(varPointers ...interface{}) (Provider, error) {
43
170
})), nil
44
171
}
45
172
46
- // MustSaveTo calls FillVars and panics if FillVars returns an error
47
- func MustSaveTo ( varPointers ... interface {}) Provider {
48
- p , err := SaveTo ( varPointers ... )
173
+ // MustCurry calls Curry and panics if Curry returns error
174
+ func MustCurry ( originalFunction interface {}, pointerToCurriedFunction interface {}) Provider {
175
+ p , err := Curry ( originalFunction , pointerToCurriedFunction )
49
176
if err != nil {
50
177
panic (err )
51
178
}
0 commit comments