diff --git a/src/prometheus/compiler/compiler.lua b/src/prometheus/compiler/compiler.lua index acbbe39..eebb353 100644 --- a/src/prometheus/compiler/compiler.lua +++ b/src/prometheus/compiler/compiler.lua @@ -1969,19 +1969,30 @@ function Compiler:compileExpression(expression, funcDepth, numReturns) retRegs[i] = self:allocRegister(false); end end - local argRegs = {}; - + + local regs = {}; + local args = {}; for i, expr in ipairs(expression.args) do - table.insert(argRegs, self:compileExpression(expr, funcDepth, 1)[1]); + if i == #expression.args and (expr.kind == AstKind.FunctionCallExpression or expr.kind == AstKind.PassSelfFunctionCallExpression or expr.kind == AstKind.VarargExpression) then + local reg = self:compileExpression(expr, funcDepth, self.RETURN_ALL)[1]; + table.insert(args, Ast.FunctionCallExpression( + self:unpack(scope), + {self:register(scope, reg)})); + table.insert(regs, reg); + else + local reg = self:compileExpression(expr, funcDepth, 1)[1]; + table.insert(args, self:register(scope, reg)); + table.insert(regs, reg); + end end if(returnAll) then - self:addStatement(self:setRegister(scope, retRegs[1], Ast.TableConstructorExpression{Ast.TableEntry(Ast.FunctionCallExpression(self:register(scope, baseReg), self:registerList(scope, argRegs)))}), {retRegs[1]}, {baseReg, unpack(argRegs)}, true); + self:addStatement(self:setRegister(scope, retRegs[1], Ast.TableConstructorExpression{Ast.TableEntry(Ast.FunctionCallExpression(self:register(scope, baseReg), args))}), {retRegs[1]}, {baseReg, unpack(regs)}, true); else if(numReturns > 1) then local tmpReg = self:allocRegister(false); - self:addStatement(self:setRegister(scope, tmpReg, Ast.TableConstructorExpression{Ast.TableEntry(Ast.FunctionCallExpression(self:register(scope, baseReg), self:registerList(scope, argRegs)))}), {tmpReg}, {baseReg, unpack(argRegs)}, true); + self:addStatement(self:setRegister(scope, tmpReg, Ast.TableConstructorExpression{Ast.TableEntry(Ast.FunctionCallExpression(self:register(scope, baseReg), args))}), {tmpReg}, {baseReg, unpack(regs)}, true); for i, reg in ipairs(retRegs) do self:addStatement(self:setRegister(scope, reg, Ast.IndexExpression(self:register(scope, tmpReg), Ast.NumberExpression(i))), {reg}, {tmpReg}, false); @@ -1989,7 +2000,7 @@ function Compiler:compileExpression(expression, funcDepth, numReturns) self:freeRegister(tmpReg, false); else - self:addStatement(self:setRegister(scope, retRegs[1], Ast.FunctionCallExpression(self:register(scope, baseReg), self:registerList(scope, argRegs))), {retRegs[1]}, {baseReg, unpack(argRegs)}, true); + self:addStatement(self:setRegister(scope, retRegs[1], Ast.FunctionCallExpression(self:register(scope, baseReg), args)), {retRegs[1]}, {baseReg, unpack(regs)}, true); end end @@ -1997,7 +2008,7 @@ function Compiler:compileExpression(expression, funcDepth, numReturns) self:freeRegister(baseReg, false); - for i, reg in ipairs(argRegs) do + for i, reg in ipairs(regs) do self:freeRegister(reg, false); end