diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index 4462bf64db97ea238e17df03774310e716387b3d..b7d4650c3e91b1f29da4b4d055aba044f8cecb1e 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -1659,7 +1659,7 @@ local function tryCallArg(state, position, results) return end ---@diagnostic disable-next-line: missing-fields - local node = vm.compileCallArg({ type = 'dummyarg' }, call, argIndex) + local node = vm.compileCallArg({ type = 'dummyarg', uri = state.uri }, call, argIndex) if not node then return end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 7e026474004a52c87164651bd2dc4c3851272b06..8a1fa96ac8b8482d3fccdaa409d5462dc94324c0 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -882,52 +882,69 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) end end - for n in callNode:eachObject() do - if n.type == 'function' then - ---@cast n parser.object - local sign = vm.getSign(n) + ---@param n parser.object + local function dealDocFunc(n) + local myEvent + if n.args[eventIndex] then + local argNode = vm.compileNode(n.args[eventIndex]) + myEvent = argNode:get(1) + end + if not myEvent + or not eventMap + or myIndex <= eventIndex + or myEvent.type ~= 'doc.type.string' + or eventMap[myEvent[1]] then local farg = getFuncArg(n, myIndex) if farg then for fn in vm.compileNode(farg):eachObject() do if isValidCallArgNode(arg, fn) then - if fn.type == 'doc.type.function' then - ---@cast fn parser.object - if sign then - local generic = vm.createGeneric(fn, sign) - local args = {} - for i = fixIndex + 1, myIndex - 1 do - args[#args+1] = call.args[i] - end - local resolvedNode = generic:resolve(guide.getUri(call), args) - vm.setNode(arg, resolvedNode) - goto CONTINUE - end - end vm.setNode(arg, fn) - ::CONTINUE:: end end end end - if n.type == 'doc.type.function' then - ---@cast n parser.object - local myEvent - if n.args[eventIndex] then - local argNode = vm.compileNode(n.args[eventIndex]) - myEvent = argNode:get(1) - end - if not myEvent - or not eventMap - or myIndex <= eventIndex - or myEvent.type ~= 'doc.type.string' - or eventMap[myEvent[1]] then - local farg = getFuncArg(n, myIndex) - if farg then - for fn in vm.compileNode(farg):eachObject() do - if isValidCallArgNode(arg, fn) then - vm.setNode(arg, fn) + end + + ---@param n parser.object + local function dealFunction(n) + local sign = vm.getSign(n) + local farg = getFuncArg(n, myIndex) + if farg then + for fn in vm.compileNode(farg):eachObject() do + if isValidCallArgNode(arg, fn) then + if fn.type == 'doc.type.function' then + ---@cast fn parser.object + if sign then + local generic = vm.createGeneric(fn, sign) + local args = {} + for i = fixIndex + 1, myIndex - 1 do + args[#args+1] = call.args[i] + end + local resolvedNode = generic:resolve(guide.getUri(call), args) + vm.setNode(arg, resolvedNode) + goto CONTINUE end end + vm.setNode(arg, fn) + ::CONTINUE:: + end + end + end + end + + for n in callNode:eachObject() do + if n.type == 'function' then + ---@cast n parser.object + dealFunction(n) + elseif n.type == 'doc.type.function' then + ---@cast n parser.object + dealDocFunc(n) + elseif n.type == 'global' and n.cate == 'type' then + ---@cast n vm.global + local overloads = vm.getOverloadsByTypeName(n.name, guide.getUri(arg)) + if overloads then + for _, func in ipairs(overloads) do + dealDocFunc(func) end end end diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 1f43447581661ebf92e11eadb8259587825ce04d..3cd6bc5d77d49a9122f6ce7816253667777fe491 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -254,7 +254,7 @@ function mt:resolve(uri, args) local argNode = vm.compileNode(arg) local knownTypes, genericNames = getSignInfo(sign) if not isAllResolved(genericNames) then - local newArgNode = buildArgNode(argNode,sign, knownTypes) + local newArgNode = buildArgNode(argNode, sign, knownTypes) resolve(sign, newArgNode) end end diff --git a/script/vm/type.lua b/script/vm/type.lua index 910d79600792d584be904fb922bb42749f4ed55d..545d2de594de84e1d0a0257981b86d3234fccd4b 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -767,3 +767,25 @@ function vm.viewTypeErrorMessage(uri, errs) return table.concat(lines, '\n') end end + +---@param name string +---@param uri uri +---@return parser.object[]? +function vm.getOverloadsByTypeName(name, uri) + local global = vm.getGlobal('type', name) + if not global then + return nil + end + local results + for _, set in ipairs(global:getSets(uri)) do + for _, doc in ipairs(set.bindGroup) do + if doc.type == 'doc.overload' then + if not results then + results = {} + end + results[#results+1] = doc.overload + end + end + end + return results +end diff --git a/test/completion/common.lua b/test/completion/common.lua index 7de1c325bf3a6a18e2f399c4ffb8ad10844ad531..90037c279f20f987c2f69240b7b5856547f6b175 100644 --- a/test/completion/common.lua +++ b/test/completion/common.lua @@ -4393,3 +4393,23 @@ f { kind = define.CompletionItemKind.Property, }, } + +TEST [[ +---@class A +---@overload fun(x: {id: string}) + +---@generic T +---@param t `T` +---@return T +local function new(t) end + +new 'A' { + <??> +} +]] +{ + { + label = 'id', + kind = define.CompletionItemKind.Property, + } +}