Skip to content

Commit cd3396a

Browse files
authored
Merge pull request #124 from hinshun/select-option-func
Allow option funcs to be used from imported
2 parents 77bc530 + b527031 commit cd3396a

File tree

6 files changed

+89
-42
lines changed

6 files changed

+89
-42
lines changed

checker/checker.go

+29-23
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,6 @@ func (c *checker) CheckSelectors(mod *parser.Module) error {
200200
obj := mod.Scope.Lookup(n.Name())
201201
if obj.Kind == parser.DeclKind {
202202
if _, ok := obj.Node.(*parser.ImportDecl); ok {
203-
importScope := obj.Data.(*parser.Scope)
204-
205203
typ := fun.Type.ObjType
206204
if typ == parser.Option {
207205
// Inherit the secondary type from the calling function name.
@@ -210,7 +208,7 @@ func (c *checker) CheckSelectors(mod *parser.Module) error {
210208

211209
// Check call signature against the imported module's scope since it was
212210
// declared there.
213-
params, err := c.checkCallSignature(importScope, typ, n, args)
211+
params, err := c.checkCallSignature(mod.Scope, typ, n, args)
214212
if err != nil {
215213
c.errs = append(c.errs, err)
216214
return false
@@ -300,7 +298,7 @@ func (c *checker) checkBlockStmt(scope *parser.Scope, typ parser.ObjType, block
300298
}
301299

302300
call := stmt.Call
303-
if call.Func == nil || (call.Func.Ident != nil && call.Func.Ident.Name == "breakpoint") {
301+
if call.Func == nil || call.Func.Name() == "breakpoint" {
304302
continue
305303
}
306304

@@ -343,7 +341,7 @@ func (c *checker) checkBlockStmt(scope *parser.Scope, typ parser.ObjType, block
343341
if !ok {
344342
obj := scope.Lookup(name)
345343
if obj == nil {
346-
return ErrIdentNotDefined{Ident: call.Func.Ident}
344+
return ErrIdentNotDefined{Ident: call.Func.IdentNode()}
347345
}
348346

349347
// The retrieved object may be either a function declaration or a field
@@ -357,7 +355,7 @@ func (c *checker) checkBlockStmt(scope *parser.Scope, typ parser.ObjType, block
357355
case *parser.AliasDecl:
358356
callType = n.Func.Type
359357
case *parser.ImportDecl:
360-
c.errs = append(c.errs, ErrUseModuleWithoutSelector{Ident: call.Func.Ident})
358+
c.errs = append(c.errs, ErrUseModuleWithoutSelector{Ident: call.Func.IdentNode()})
361359
continue
362360
}
363361
case parser.FieldKind:
@@ -383,6 +381,10 @@ func (c *checker) checkBlockStmt(scope *parser.Scope, typ parser.ObjType, block
383381
}
384382

385383
func (c *checker) checkCallStmt(scope *parser.Scope, typ parser.ObjType, call *parser.CallStmt) error {
384+
if call.Func.Selector != nil {
385+
return nil
386+
}
387+
386388
params, err := c.checkCallSignature(scope, typ, call.Func, call.Args)
387389
if err != nil {
388390
return err
@@ -392,30 +394,18 @@ func (c *checker) checkCallStmt(scope *parser.Scope, typ parser.ObjType, call *p
392394
}
393395

394396
func (c *checker) checkCallSignature(scope *parser.Scope, typ parser.ObjType, expr *parser.Expr, args []*parser.Expr) ([]*parser.Field, error) {
395-
var ident *parser.Ident
396-
switch {
397-
case expr.Ident != nil:
398-
ident = expr.Ident
399-
case expr.Selector != nil:
400-
ident = expr.Selector.Select
401-
}
402-
403397
var signature []*parser.Field
404-
fun, ok := builtin.Lookup.ByType[typ].Func[ident.Name]
398+
fun, ok := builtin.Lookup.ByType[typ].Func[expr.Name()]
405399
if !ok && typ == parser.Group {
406-
fun, ok = builtin.Lookup.ByType[parser.Filesystem].Func[ident.Name]
400+
fun, ok = builtin.Lookup.ByType[parser.Filesystem].Func[expr.Name()]
407401
}
408402

409403
if ok {
410404
signature = fun.Params
411405
} else {
412-
obj := scope.Lookup(ident.Name)
406+
obj := scope.Lookup(expr.Name())
413407
if obj == nil {
414-
return nil, ErrIdentUndefined{ident}
415-
}
416-
417-
if expr.Selector != nil && !obj.Exported {
418-
return nil, ErrCallUnexported{expr.Selector}
408+
return nil, ErrIdentUndefined{expr.IdentNode()}
419409
}
420410

421411
if obj.Kind == parser.DeclKind {
@@ -425,7 +415,23 @@ func (c *checker) checkCallSignature(scope *parser.Scope, typ parser.ObjType, ex
425415
case *parser.AliasDecl:
426416
signature = n.Func.Params.List
427417
case *parser.ImportDecl:
428-
panic("todo: ErrCallImport")
418+
importScope := obj.Data.(*parser.Scope)
419+
importObj := importScope.Lookup(expr.Selector.Select.Name)
420+
if importObj == nil {
421+
return nil, ErrIdentUndefined{expr.Selector.Select}
422+
}
423+
if !importObj.Exported {
424+
return nil, ErrCallUnexported{expr.Selector}
425+
}
426+
427+
switch m := importObj.Node.(type) {
428+
case *parser.FuncDecl:
429+
signature = m.Params.List
430+
case *parser.AliasDecl:
431+
signature = m.Func.Params.List
432+
default:
433+
panic("implementation error")
434+
}
429435
default:
430436
panic("implementation error")
431437
}

checker/checker_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -419,14 +419,38 @@ func TestChecker_CheckSelectors(t *testing.T) {
419419
}
420420
`,
421421
nil,
422+
}, {
423+
"use imported option",
424+
`
425+
import myImportedModule "./myModule.hlb"
426+
427+
fs default(string foo) {
428+
image "busybox" with myImportedModule.resolveImage
429+
}
430+
`,
431+
nil,
432+
}, {
433+
"merge imported option",
434+
`
435+
import myImportedModule "./myModule.hlb"
436+
437+
fs default(string foo) {
438+
image "busybox" with option {
439+
myImportedModule.resolveImage
440+
}
441+
}
442+
`,
443+
nil,
422444
}} {
423445
tc := tc
424446
t.Run(tc.name, func(t *testing.T) {
425447
importedModuleDefinition := `
426448
export validSelector
427449
export validSelectorWithArgs
450+
export resolveImage
428451
fs validSelector() {}
429452
fs validSelectorWithArgs(string bar) {}
453+
option::image resolveImage() { resolve; }
430454
`
431455

432456
importedModule, err := parser.Parse(strings.NewReader(importedModuleDefinition))

codegen/chain.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ func (cg *CodeGen) EmitFilesystemChainStmt(ctx context.Context, scope *parser.Sc
8282
case *parser.ImportDecl:
8383
importScope := obj.Data.(*parser.Scope)
8484
importObj := importScope.Lookup(expr.Selector.Select.Name)
85+
if importObj == nil {
86+
return nil, errors.WithStack(ErrCodeGen{expr, errors.Errorf("could not find reference")})
87+
}
88+
8589
switch m := importObj.Node.(type) {
8690
case *parser.FuncDecl:
8791
v, err = cg.EmitFuncDecl(ctx, scope, m, args, ac, chainStart)
@@ -314,7 +318,7 @@ func (cg *CodeGen) EmitFilesystemBuiltinChainStmt(ctx context.Context, scope *pa
314318
// to be in the context of a specific function run is in.
315319
case with.Expr.FuncLit != nil:
316320
for _, stmt := range with.Expr.FuncLit.Body.NonEmptyStmts() {
317-
if stmt.Call.Func.Ident.Name != "mount" || stmt.Call.Alias == nil {
321+
if stmt.Call.Func.Name() != "mount" || stmt.Call.Alias == nil {
318322
continue
319323
}
320324

codegen/codegen.go

+23-14
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ type aliasCallback func(*parser.CallStmt, interface{}) bool
285285
func noopAliasCallback(_ *parser.CallStmt, _ interface{}) bool { return true }
286286

287287
func isBreakpoint(call *parser.CallStmt) bool {
288-
return call.Func.Ident != nil && call.Func.Ident.Name == "breakpoint"
288+
return call.Func.Name() == "breakpoint"
289289
}
290290

291291
func (cg *CodeGen) EmitBlock(ctx context.Context, scope *parser.Scope, typ parser.ObjType, stmts []*parser.Stmt, ac aliasCallback, chainStart interface{}) (interface{}, error) {
@@ -435,7 +435,7 @@ func (cg *CodeGen) EmitImageOptions(ctx context.Context, scope *parser.Scope, op
435435
for _, stmt := range stmts {
436436
if stmt.Call != nil {
437437
args := stmt.Call.Args
438-
switch stmt.Call.Func.Ident.Name {
438+
switch stmt.Call.Func.Name() {
439439
case "resolve":
440440
v, err := cg.MaybeEmitBoolExpr(ctx, scope, args)
441441
if err != nil {
@@ -463,6 +463,15 @@ func (cg *CodeGen) EmitOptionLookup(ctx context.Context, scope *parser.Scope, ex
463463
switch n := obj.Node.(type) {
464464
case *parser.FuncDecl:
465465
return cg.EmitOptionFuncDecl(ctx, scope, n, args)
466+
case *parser.ImportDecl:
467+
importScope := obj.Data.(*parser.Scope)
468+
importObj := importScope.Lookup(expr.Selector.Select.Name)
469+
switch m := importObj.Node.(type) {
470+
case *parser.FuncDecl:
471+
return cg.EmitOptionFuncDecl(ctx, scope, m, args)
472+
default:
473+
return opts, errors.WithStack(ErrCodeGen{expr, errors.Errorf("unknown option decl kind")})
474+
}
466475
default:
467476
return opts, errors.WithStack(ErrCodeGen{expr, errors.Errorf("unknown option decl kind")})
468477
}
@@ -480,7 +489,7 @@ func (cg *CodeGen) EmitHTTPOptions(ctx context.Context, scope *parser.Scope, op
480489
for _, stmt := range stmts {
481490
if stmt.Call != nil {
482491
args := stmt.Call.Args
483-
switch stmt.Call.Func.Ident.Name {
492+
switch stmt.Call.Func.Name() {
484493
case "checksum":
485494
dgst, err := cg.EmitStringExpr(ctx, scope, args[0])
486495
if err != nil {
@@ -515,7 +524,7 @@ func (cg *CodeGen) EmitGitOptions(ctx context.Context, scope *parser.Scope, op s
515524
for _, stmt := range stmts {
516525
if stmt.Call != nil {
517526
args := stmt.Call.Args
518-
switch stmt.Call.Func.Ident.Name {
527+
switch stmt.Call.Func.Name() {
519528
case "keepGitDir":
520529
v, err := cg.MaybeEmitBoolExpr(ctx, scope, args)
521530
if err != nil {
@@ -540,7 +549,7 @@ func (cg *CodeGen) EmitLocalOptions(ctx context.Context, scope *parser.Scope, op
540549
for _, stmt := range stmts {
541550
if stmt.Call != nil {
542551
args := stmt.Call.Args
543-
switch stmt.Call.Func.Ident.Name {
552+
switch stmt.Call.Func.Name() {
544553
case "includePatterns":
545554
patterns := make([]string, len(args))
546555
for i, arg := range args {
@@ -598,7 +607,7 @@ func (cg *CodeGen) EmitFrontendOptions(ctx context.Context, scope *parser.Scope,
598607
for _, stmt := range stmts {
599608
if stmt.Call != nil {
600609
args := stmt.Call.Args
601-
switch stmt.Call.Func.Ident.Name {
610+
switch stmt.Call.Func.Name() {
602611
case "input":
603612
key, err := cg.EmitStringExpr(ctx, scope, args[0])
604613
if err != nil {
@@ -639,7 +648,7 @@ func (cg *CodeGen) EmitMkdirOptions(ctx context.Context, scope *parser.Scope, op
639648
for _, stmt := range stmts {
640649
if stmt.Call != nil {
641650
args := stmt.Call.Args
642-
switch stmt.Call.Func.Ident.Name {
651+
switch stmt.Call.Func.Name() {
643652
case "createParents":
644653
v, err := cg.MaybeEmitBoolExpr(ctx, scope, args)
645654
if err != nil {
@@ -680,7 +689,7 @@ func (cg *CodeGen) EmitMkfileOptions(ctx context.Context, scope *parser.Scope, o
680689
for _, stmt := range stmts {
681690
if stmt.Call != nil {
682691
args := stmt.Call.Args
683-
switch stmt.Call.Func.Ident.Name {
692+
switch stmt.Call.Func.Name() {
684693
case "chown":
685694
owner, err := cg.EmitStringExpr(ctx, scope, args[0])
686695
if err != nil {
@@ -715,7 +724,7 @@ func (cg *CodeGen) EmitRmOptions(ctx context.Context, scope *parser.Scope, op st
715724
for _, stmt := range stmts {
716725
if stmt.Call != nil {
717726
args := stmt.Call.Args
718-
switch stmt.Call.Func.Ident.Name {
727+
switch stmt.Call.Func.Name() {
719728
case "allowNotFound":
720729
v, err := cg.MaybeEmitBoolExpr(ctx, scope, args)
721730
if err != nil {
@@ -801,7 +810,7 @@ func (cg *CodeGen) EmitCopyOptions(ctx context.Context, scope *parser.Scope, op
801810
for _, stmt := range stmts {
802811
if stmt.Call != nil {
803812
args := stmt.Call.Args
804-
switch stmt.Call.Func.Ident.Name {
813+
switch stmt.Call.Func.Name() {
805814
case "followSymlinks":
806815
follow, err := cg.MaybeEmitBoolExpr(ctx, scope, args)
807816
if err != nil {
@@ -882,7 +891,7 @@ func (cg *CodeGen) EmitTemplateOptions(ctx context.Context, scope *parser.Scope,
882891
for _, stmt := range stmts {
883892
if stmt.Call != nil {
884893
args := stmt.Call.Args
885-
switch stmt.Call.Func.Ident.Name {
894+
switch stmt.Call.Func.Name() {
886895
case "stringField":
887896
name, err := cg.EmitStringExpr(ctx, scope, args[0])
888897
if err != nil {
@@ -1307,7 +1316,7 @@ func (cg *CodeGen) EmitSSHOptions(ctx context.Context, scope *parser.Scope, op s
13071316
for _, stmt := range stmts {
13081317
if stmt.Call != nil {
13091318
args := stmt.Call.Args
1310-
switch stmt.Call.Func.Ident.Name {
1319+
switch stmt.Call.Func.Name() {
13111320
case "target":
13121321
target, err := cg.EmitStringExpr(ctx, scope, args[0])
13131322
if err != nil {
@@ -1399,7 +1408,7 @@ func (cg *CodeGen) EmitSecretOptions(ctx context.Context, scope *parser.Scope, o
13991408
for _, stmt := range stmts {
14001409
if stmt.Call != nil {
14011410
args := stmt.Call.Args
1402-
switch stmt.Call.Func.Ident.Name {
1411+
switch stmt.Call.Func.Name() {
14031412
case "id":
14041413
id, err := cg.EmitStringExpr(ctx, scope, args[0])
14051414
if err != nil {
@@ -1476,7 +1485,7 @@ func (cg *CodeGen) EmitMountOptions(ctx context.Context, scope *parser.Scope, op
14761485
for _, stmt := range stmts {
14771486
if stmt.Call != nil {
14781487
args := stmt.Call.Args
1479-
switch stmt.Call.Func.Ident.Name {
1488+
switch stmt.Call.Func.Name() {
14801489
case "readonly":
14811490
v, err := cg.MaybeEmitBoolExpr(ctx, scope, args)
14821491
if err != nil {

codegen/debug.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func NewDebugger(c *client.Client, w io.Writer, r *bufio.Reader, ibs map[string]
149149
Func: n,
150150
}
151151
case *parser.CallStmt:
152-
if n.Func.Ident.Name == "breakpoint" {
152+
if n.Func.Name() == "breakpoint" {
153153
fmt.Fprintf(w, "%s cannot break at breakpoint\n", checker.FormatPos(n.Pos))
154154
continue
155155
}

parser/cst.go

+7-3
Original file line numberDiff line numberDiff line change
@@ -357,13 +357,17 @@ func (e *Expr) End() lexer.Position {
357357
}
358358

359359
func (e *Expr) Name() string {
360+
return e.IdentNode().Name
361+
}
362+
363+
func (e *Expr) IdentNode() *Ident {
360364
switch {
361365
case e.Selector != nil:
362-
return e.Selector.Ident.Name
366+
return e.Selector.Ident
363367
case e.Ident != nil:
364-
return e.Ident.Name
368+
return e.Ident
365369
default:
366-
return ""
370+
return &Ident{}
367371
}
368372
}
369373

0 commit comments

Comments
 (0)