diff --git a/src/transformation/visitors/break-continue.ts b/src/transformation/visitors/break-continue.ts index 9211d4165..9e476b17e 100644 --- a/src/transformation/visitors/break-continue.ts +++ b/src/transformation/visitors/break-continue.ts @@ -6,12 +6,8 @@ import { unsupportedForTarget } from "../utils/diagnostics"; import { findScope, ScopeType } from "../utils/scope"; export const transformBreakStatement: FunctionVisitor = (breakStatement, context) => { - const breakableScope = findScope(context, ScopeType.Loop | ScopeType.Switch); - if (breakableScope?.type === ScopeType.Switch) { - return lua.createGotoStatement(`____switch${breakableScope.id}_end`); - } else { - return lua.createBreakStatement(breakStatement); - } + void context; + return lua.createBreakStatement(breakStatement); }; export const transformContinueStatement: FunctionVisitor = (statement, context) => { diff --git a/src/transformation/visitors/switch.ts b/src/transformation/visitors/switch.ts index a82ff0bf0..1a05bae41 100644 --- a/src/transformation/visitors/switch.ts +++ b/src/transformation/visitors/switch.ts @@ -1,61 +1,172 @@ import * as ts from "typescript"; -import { LuaTarget } from "../../CompilerOptions"; import * as lua from "../../LuaAST"; -import { FunctionVisitor } from "../context"; -import { unsupportedForTarget } from "../utils/diagnostics"; +import { FunctionVisitor, TransformationContext } from "../context"; import { performHoisting, popScope, pushScope, ScopeType } from "../utils/scope"; -export const transformSwitchStatement: FunctionVisitor = (statement, context) => { - if (context.luaTarget === LuaTarget.Universal || context.luaTarget === LuaTarget.Lua51) { - context.diagnostics.push(unsupportedForTarget(statement, "Switch statements", LuaTarget.Lua51)); +const containsBreakOrReturn = (nodes: Iterable): boolean => { + for (const s of nodes) { + if (ts.isBreakStatement(s) || ts.isReturnStatement(s)) { + return true; + } else if (ts.isBlock(s) && containsBreakOrReturn(s.getChildren())) { + return true; + } else if (s.kind === ts.SyntaxKind.SyntaxList && containsBreakOrReturn(s.getChildren())) { + return true; + } + } + + return false; +}; + +const coalesceCondition = ( + condition: lua.Expression | undefined, + switchVariable: lua.Identifier, + expression: ts.Expression, + context: TransformationContext +): lua.Expression => { + // Coalesce skipped statements + if (condition) { + return lua.createBinaryExpression( + condition, + lua.createBinaryExpression( + switchVariable, + context.transformExpression(expression), + lua.SyntaxKind.EqualityOperator + ), + lua.SyntaxKind.OrOperator + ); } + // Next condition + return lua.createBinaryExpression( + switchVariable, + context.transformExpression(expression), + lua.SyntaxKind.EqualityOperator + ); +}; + +export const transformSwitchStatement: FunctionVisitor = (statement, context) => { const scope = pushScope(context, ScopeType.Switch); - // Give the switch a unique name to prevent nested switches from acting up. + // Give the switch and condition accumulator a unique name to prevent nested switches from acting up. const switchName = `____switch${scope.id}`; + const conditionName = `____cond${scope.id}`; const switchVariable = lua.createIdentifier(switchName); + const conditionVariable = lua.createIdentifier(conditionName); + // If the switch only has a default clause, wrap it in a single do. + // Otherwise, we need to generate a set of if statements to emulate the switch. let statements: lua.Statement[] = []; - - // Starting from the back, concatenating ifs into one big if/elseif statement - const concatenatedIf = statement.caseBlock.clauses.reduceRight((previousCondition, clause, index) => { - if (ts.isDefaultClause(clause)) { - // Skip default clause here (needs to be included to ensure index lines up with index later) - return previousCondition; + const clauses = statement.caseBlock.clauses; + if (clauses.length === 1 && ts.isDefaultClause(clauses[0])) { + const defaultClause = clauses[0].statements; + if (defaultClause.length) { + statements.push(lua.createDoStatement(context.transformStatements(defaultClause))); } + } else { + // Build up the condition for each if statement + let isInitialCondition = true; + let condition: lua.Expression | undefined = undefined; + for (let i = 0; i < clauses.length; i++) { + const clause = clauses[i]; + const previousClause: ts.CaseOrDefaultClause | undefined = clauses[i - 1]; - // If the clause condition holds, go to the correct label - const condition = lua.createBinaryExpression( - switchVariable, - context.transformExpression(clause.expression), - lua.SyntaxKind.EqualityOperator - ); + // Skip redundant default clauses, will be handled in final default case + if (i === 0 && ts.isDefaultClause(clause)) continue; + if (ts.isDefaultClause(clause) && previousClause && containsBreakOrReturn(previousClause.statements)) { + continue; + } - const goto = lua.createGotoStatement(`${switchName}_case_${index}`); - return lua.createIfStatement(condition, lua.createBlock([goto]), previousCondition); - }, undefined as lua.IfStatement | undefined); + // Compute the condition for the if statement + if (!ts.isDefaultClause(clause)) { + condition = coalesceCondition(condition, switchVariable, clause.expression, context); - if (concatenatedIf) { - statements.push(concatenatedIf); - } + // Skip empty clauses unless final clause (i.e side-effects) + if (i !== clauses.length - 1 && clause.statements.length === 0) continue; - const hasDefaultCase = statement.caseBlock.clauses.some(ts.isDefaultClause); - statements.push(lua.createGotoStatement(`${switchName}_${hasDefaultCase ? "case_default" : "end"}`)); + // Declare or assign condition variable + statements.push( + isInitialCondition + ? lua.createVariableDeclarationStatement(conditionVariable, condition) + : lua.createAssignmentStatement( + conditionVariable, + lua.createBinaryExpression(conditionVariable, condition, lua.SyntaxKind.OrOperator) + ) + ); + isInitialCondition = false; + } else { + // If the default is proceeded by empty clauses and will be emitted we may need to initialize the condition + if (isInitialCondition) { + statements.push( + lua.createVariableDeclarationStatement( + conditionVariable, + condition ?? lua.createBooleanLiteral(false) + ) + ); - for (const [index, clause] of statement.caseBlock.clauses.entries()) { - const labelName = `${switchName}_case_${ts.isCaseClause(clause) ? index : "default"}`; - statements.push(lua.createLabelStatement(labelName)); - statements.push(lua.createDoStatement(context.transformStatements(clause.statements))); - } + // Clear condition ot ensure it is not evaluated twice + condition = undefined; + isInitialCondition = false; + } + + // Allow default to fallthrough to final default clause + if (i === clauses.length - 1) { + // Evaluate the final condition that we may be skipping + if (condition) { + statements.push( + lua.createAssignmentStatement( + conditionVariable, + lua.createBinaryExpression(conditionVariable, condition, lua.SyntaxKind.OrOperator) + ) + ); + } + continue; + } + } - statements.push(lua.createLabelStatement(`${switchName}_end`)); + // Transform the clause and append the final break statement if necessary + const clauseStatements = context.transformStatements(clause.statements); + if (i === clauses.length - 1 && !containsBreakOrReturn(clause.statements)) { + clauseStatements.push(lua.createBreakStatement()); + } + + // Push if statement for case + statements.push(lua.createIfStatement(conditionVariable, lua.createBlock(clauseStatements))); + + // Clear condition for next clause + condition = undefined; + } + + // If no conditions above match, we need to create the final default case code-path, + // as we only handle fallthrough into defaults in the previous if statement chain + const start = clauses.findIndex(c => ts.isDefaultClause(c)); + if (start >= 0) { + // Find the last clause that we can fallthrough to + const end = clauses.findIndex( + (clause, index) => index >= start && containsBreakOrReturn(clause.statements) + ); + + // Combine the default and all fallthrough statements + const defaultStatements: lua.Statement[] = []; + clauses + .slice(start, end >= 0 ? end + 1 : undefined) + .forEach(c => defaultStatements.push(...context.transformStatements(c.statements))); + + // Add the default clause if it has any statements + // The switch will always break on the final clause and skip execution if valid to do so + if (defaultStatements.length) { + statements.push(lua.createDoStatement(defaultStatements)); + } + } + } + // Hoist the variable, function, and import statements to the top of the switch statements = performHoisting(context, statements); popScope(context); + // Add the switch expression after hoisting const expression = context.transformExpression(statement.expression); statements.unshift(lua.createVariableDeclarationStatement(switchVariable, expression)); - return statements; + // Wrap the statements in a repeat until true statement to facilitate dynamic break/returns + return lua.createRepeatStatement(lua.createBlock(statements), lua.createBooleanLiteral(true)); }; diff --git a/test/unit/__snapshots__/switch.spec.ts.snap b/test/unit/__snapshots__/switch.spec.ts.snap index 19a763566..0fa333d93 100644 --- a/test/unit/__snapshots__/switch.spec.ts.snap +++ b/test/unit/__snapshots__/switch.spec.ts.snap @@ -1,53 +1,80 @@ // Jest Snapshot v1, https://goo.gl/fbAQLP -exports[`switch not allowed in 5.1: code 1`] = ` -"local ____exports = {} +exports[`switch empty fallthrough to default (0) 1`] = ` +"require(\\"lualib_bundle\\"); +local ____exports = {} function ____exports.__main(self) - local ____switch3 = \\"abc\\" - goto ____switch3_end - ::____switch3_end:: + local out = {} + repeat + local ____switch3 = 0 + local ____cond3 = ____switch3 == 1 + do + __TS__ArrayPush(out, \\"default\\") + end + until true + return out end return ____exports" `; -exports[`switch not allowed in 5.1: diagnostics 1`] = `"main.ts(2,9): error TSTL: Switch statements is/are not supported for target Lua 5.1."`; - -exports[`switch uses elseif 1`] = ` -"local ____exports = {} +exports[`switch empty fallthrough to default (1) 1`] = ` +"require(\\"lualib_bundle\\"); +local ____exports = {} function ____exports.__main(self) - local result = -1 - local ____switch3 = 2 - if ____switch3 == 0 then - goto ____switch3_case_0 - elseif ____switch3 == 1 then - goto ____switch3_case_1 - elseif ____switch3 == 2 then - goto ____switch3_case_2 - end - goto ____switch3_end - ::____switch3_case_0:: - do + local out = {} + repeat + local ____switch3 = 1 + local ____cond3 = ____switch3 == 1 do - result = 200 - goto ____switch3_end + __TS__ArrayPush(out, \\"default\\") end - end - ::____switch3_case_1:: - do - do - result = 100 - goto ____switch3_end + until true + return out +end +return ____exports" +`; + +exports[`switch produces optimal output 1`] = ` +"require(\\"lualib_bundle\\"); +local ____exports = {} +function ____exports.__main(self) + local x = 0 + local out = {} + repeat + local ____switch3 = 0 + local ____cond3 = ((____switch3 == 0) or (____switch3 == 1)) or (____switch3 == 2) + if ____cond3 then + __TS__ArrayPush(out, \\"0,1,2\\") + break + end + ____cond3 = ____cond3 or (____switch3 == 3) + if ____cond3 then + do + __TS__ArrayPush(out, \\"3\\") + break + end + end + ____cond3 = ____cond3 or (____switch3 == 4) + if ____cond3 then + break end - end - ::____switch3_case_2:: - do do - result = 1 - goto ____switch3_end + x = x + 1 + __TS__ArrayPush( + out, + \\"default = \\" .. tostring(x) + ) + do + __TS__ArrayPush(out, \\"3\\") + break + end end - end - ::____switch3_end:: - return result + until true + __TS__ArrayPush( + out, + tostring(x) + ) + return out end return ____exports" `; diff --git a/test/unit/switch.spec.ts b/test/unit/switch.spec.ts index c264850c2..6ed965595 100644 --- a/test/unit/switch.spec.ts +++ b/test/unit/switch.spec.ts @@ -1,5 +1,3 @@ -import * as tstl from "../../src"; -import { unsupportedForTarget } from "../../src/transformation/utils/diagnostics"; import * as util from "../util"; test.each([0, 1, 2, 3])("switch (%p)", inp => { @@ -209,7 +207,7 @@ test.each([0, 1, 2, 3])("switchWithBrackets (%p)", inp => { `.expectToMatchJsResult(); }); -test.each([0, 1, 2, 3])("switchWithBracketsBreakInConditional (%p)", inp => { +test.each([0, 1, 2, 3, 4])("switchWithBracketsBreakInConditional (%p)", inp => { util.testFunction` let result: number = -1; @@ -225,6 +223,11 @@ test.each([0, 1, 2, 3])("switchWithBracketsBreakInConditional (%p)", inp => { } case 2: { result = 2; + + if (result != 2) break; + } + case 3: { + result = 3; break; } } @@ -261,10 +264,9 @@ test.each([0, 1, 2, 3])("switchWithBracketsBreakInInternalLoop (%p)", inp => { `.expectToMatchJsResult(); }); -test("switch uses elseif", () => { +test("switch executes only one clause", () => { util.testFunction` let result: number = -1; - switch (2 as number) { case 0: { result = 200; @@ -281,19 +283,8 @@ test("switch uses elseif", () => { break; } } - return result; - ` - .expectLuaToMatchSnapshot() - .expectToMatchJsResult(); -}); - -test("switch not allowed in 5.1", () => { - util.testFunction` - switch ("abc") {} - ` - .setOptions({ luaTarget: tstl.LuaTarget.Lua51 }) - .expectDiagnosticsToMatchSnapshot([unsupportedForTarget.code]); + `.expectToMatchJsResult(); }); // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/967 @@ -321,6 +312,17 @@ test("switch default case not last - second", () => { `.expectToMatchJsResult(); }); +test("switch default case only", () => { + util.testFunction` + let out = 0; + switch (4 as number) { + default: + out = 1 + } + return out; + `.expectToMatchJsResult(); +}); + test("switch fallthrough enters default", () => { util.testFunction` const out = []; @@ -359,3 +361,164 @@ test("switch fallthrough stops after default", () => { return out; `.expectToMatchJsResult(); }); + +test.each([0, 1])("switch empty fallthrough to default (%p)", inp => { + util.testFunction` + const out = []; + switch (${inp} as number) { + case 1: + default: + out.push("default"); + + } + return out; + ` + .expectLuaToMatchSnapshot() + .expectToMatchJsResult(); +}); + +test("switch does not pollute parent scope", () => { + util.testFunction` + let x: number = 0; + let y = 1; + switch (x) { + case 0: + let y = 2; + } + return y; + `.expectToMatchJsResult(); +}); + +test.each([0, 1, 2, 3, 4])("switch handles side-effects (%p)", inp => { + util.testFunction` + const out = []; + + let y = 0; + function foo() { + return y++; + } + + let x = ${inp} as number; + switch (x) { + case foo(): + out.push(1); + case foo(): + out.push(2); + case foo(): + out.push(3); + default: + out.push("default"); + case foo(): + } + + out.push(y); + return out; + `.expectToMatchJsResult(); +}); + +test.each([1, 2])("switch handles side-effects with empty fallthrough (%p)", inp => { + util.testFunction` + const out = []; + + let y = 0; + function foo() { + return y++; + } + + let x = 0 as number; + switch (x) { + // empty fallthrough 1 or many times + ${new Array(inp).fill("case foo():").join("\n")} + default: + out.push("default"); + + } + + out.push(y); + return out; + `.expectToMatchJsResult(); +}); + +test.each([1, 2])("switch handles side-effects with empty fallthrough (preceding clause) (%p)", inp => { + util.testFunction` + const out = []; + + let y = 0; + function foo() { + return y++; + } + + let x = 0 as number; + switch (x) { + case 1: + out.push(1); + // empty fallthrough 1 or many times + ${new Array(inp).fill("case foo():").join("\n")} + default: + out.push("default"); + + } + + out.push(y); + return out; + `.expectToMatchJsResult(); +}); + +test.each([0, 1, 2, 3, 4])("switch handles async side-effects (%p)", inp => { + util.testFunction` + (async () => { + const out = []; + + let y = 0; + async function foo() { + return new Promise((resolve) => y++ && resolve(0)); + } + + let x = ${inp} as number; + switch (x) { + case await foo(): + out.push(1); + case await foo(): + out.push(2); + case await foo(): + out.push(3); + default: + out.push("default"); + case await foo(): + } + + out.push(y); + return out; + })(); + `.expectToMatchJsResult(); +}); + +const optimalOutput = (c: number) => util.testFunction` + let x: number = 0; + const out = []; + switch (${c} as number) { + case 0: + case 1: + case 2: + out.push("0,1,2"); + break; + default: + x++; + out.push("default = " + x); + case 3: { + out.push("3"); + break; + } + case 4: + } + out.push(x.toString()); + return out; +`; + +test("switch produces optimal output", () => { + optimalOutput(0).expectLuaToMatchSnapshot(); +}); + +test.each([0, 1, 2, 3, 4, 5])("switch produces valid optimal output (%p)", inp => { + optimalOutput(inp).expectToMatchJsResult(); +});