Skip to content

Commit 09ba965

Browse files
committed
fix: improve function type narrow by checking params' literal identical
1 parent ba8f90e commit 09ba965

File tree

2 files changed

+45
-11
lines changed

2 files changed

+45
-11
lines changed

changelog.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Unreleased
44
<!-- Add all new changes here. They will be moved under a version at release -->
5+
* `FIX` Improve type narrow by checking exact match on literal type params
56

67
## 3.10.5
78
`2024-8-19`

script/vm/function.lua

+44-11
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,31 @@ local function isAllParamMatched(uri, args, params)
353353
return true
354354
end
355355

356+
---@param uri uri
357+
---@param args parser.object[]
358+
---@param func parser.object
359+
---@return integer
360+
local function calcFunctionMatchScore(uri, args, func)
361+
if vm.isVarargFunctionWithOverloads(func)
362+
or not isAllParamMatched(uri, args, func.args)
363+
then
364+
return -1
365+
end
366+
local matchScore = 0
367+
for i = 1, math.min(#args, #func.args) do
368+
local arg, param = args[i], func.args[i]
369+
-- if param's literals map contains arg's literal, this is the most narrowed exact match
370+
local argLiteral = guide.getLiteral(arg)
371+
if argLiteral ~= nil then
372+
local defLiterals = vm.getLiterals(param)
373+
if defLiterals and defLiterals[argLiteral] then
374+
matchScore = matchScore + 1
375+
end
376+
end
377+
end
378+
return matchScore
379+
end
380+
356381
---@param func parser.object
357382
---@param args? parser.object[]
358383
---@return parser.object[]?
@@ -365,21 +390,29 @@ function vm.getExactMatchedFunctions(func, args)
365390
return funcs
366391
end
367392
local uri = guide.getUri(func)
368-
local needRemove
393+
local matchScores = {}
369394
for i, n in ipairs(funcs) do
370-
if vm.isVarargFunctionWithOverloads(n)
371-
or not isAllParamMatched(uri, args, n.args) then
372-
if not needRemove then
373-
needRemove = {}
374-
end
375-
needRemove[#needRemove+1] = i
376-
end
395+
matchScores[i] = calcFunctionMatchScore(uri, args, n)
396+
end
397+
398+
local maxMatchScore = math.max(table.unpack(matchScores))
399+
if maxMatchScore == -1 then
400+
-- all should be removed
401+
return nil
377402
end
378-
if not needRemove then
403+
404+
local minMatchScore = math.min(table.unpack(matchScores))
405+
if minMatchScore == maxMatchScore then
406+
-- all should be kept
379407
return funcs
380408
end
381-
if #needRemove == #funcs then
382-
return nil
409+
410+
-- remove functions that have matchScore < maxMatchScore
411+
local needRemove = {}
412+
for i, matchScore in ipairs(matchScores) do
413+
if matchScore < maxMatchScore then
414+
needRemove[#needRemove + 1] = i
415+
end
383416
end
384417
util.tableMultiRemove(funcs, needRemove)
385418
return funcs

0 commit comments

Comments
 (0)