Skip to content

Commit 742117c

Browse files
committed
fix: improve function type narrow by checking params' literal identical
1 parent ba8f90e commit 742117c

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

changelog.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Unreleased
44
<!-- Add all new changes here. They will be moved under a version at release -->
5+
* `FIX` Improve type narrow by checking exact match on literal type params
56

67
## 3.10.5
78
`2024-8-19`

script/vm/function.lua

+46-11
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,33 @@ local function isAllParamMatched(uri, args, params)
353353
return true
354354
end
355355

356+
---@param uri uri
357+
---@param args parser.object[]
358+
---@param func parser.object
359+
---@return integer
360+
local function calcFunctionMatchScore(uri, args, func)
361+
if vm.isVarargFunctionWithOverloads(func)
362+
or not isAllParamMatched(uri, args, func.args)
363+
then
364+
return -1
365+
end
366+
local matchScore = 0
367+
for i = 1, math.min(#args, #func.args) do
368+
local arg, param = args[i], func.args[i]
369+
local defLiterals = vm.getLiterals(param)
370+
if defLiterals then
371+
for n in vm.compileNode(arg):eachObject() do
372+
-- if param's literals map contains arg's literal, this is the most narrowed exact match
373+
if defLiterals[guide.getLiteral(n)] then
374+
matchScore = matchScore + 1
375+
break
376+
end
377+
end
378+
end
379+
end
380+
return matchScore
381+
end
382+
356383
---@param func parser.object
357384
---@param args? parser.object[]
358385
---@return parser.object[]?
@@ -365,21 +392,29 @@ function vm.getExactMatchedFunctions(func, args)
365392
return funcs
366393
end
367394
local uri = guide.getUri(func)
368-
local needRemove
395+
local matchScores = {}
369396
for i, n in ipairs(funcs) do
370-
if vm.isVarargFunctionWithOverloads(n)
371-
or not isAllParamMatched(uri, args, n.args) then
372-
if not needRemove then
373-
needRemove = {}
374-
end
375-
needRemove[#needRemove+1] = i
376-
end
397+
matchScores[i] = calcFunctionMatchScore(uri, args, n)
398+
end
399+
400+
local maxMatchScore = math.max(table.unpack(matchScores))
401+
if maxMatchScore == -1 then
402+
-- all should be removed
403+
return nil
377404
end
378-
if not needRemove then
405+
406+
local minMatchScore = math.min(table.unpack(matchScores))
407+
if minMatchScore == maxMatchScore then
408+
-- all should be kept
379409
return funcs
380410
end
381-
if #needRemove == #funcs then
382-
return nil
411+
412+
-- remove functions that have matchScore < maxMatchScore
413+
local needRemove = {}
414+
for i, matchScore in ipairs(matchScores) do
415+
if matchScore < maxMatchScore then
416+
needRemove[#needRemove + 1] = i
417+
end
383418
end
384419
util.tableMultiRemove(funcs, needRemove)
385420
return funcs

0 commit comments

Comments
 (0)