diff --git a/printer/printer.go b/printer/printer.go index 8b0cb45..a7fd1fa 100644 --- a/printer/printer.go +++ b/printer/printer.go @@ -19,6 +19,21 @@ func Print(o io.Writer, n node.Node) { fn(o, n) } +func printStmt(o io.Writer, n node.Node) { + switch nn := n.(type) { + case *stmt.Nop: + Print(o, nn) + break + case *stmt.StmtList: + io.WriteString(o, " {\n") + printNodes(o, nn.Stmts) + io.WriteString(o, "}\n") + default: + io.WriteString(o, "\n") + Print(o, nn) + } +} + func joinPrint(glue string, o io.Writer, nn []node.Node) { for k, n := range nn { if k > 0 { @@ -300,11 +315,15 @@ func getPrintFuncByNode(n node.Node) func(o io.Writer, n node.Node) { return printStmtConstant case *stmt.Continue: return printStmtContinue + case *stmt.Declare: + return printStmtDeclare case *stmt.StmtList: return printStmtStmtList case *stmt.Expression: return printStmtExpression + case *stmt.Nop: + return printStmtNop } panic("printer is missing for the node") @@ -1430,6 +1449,15 @@ func printStmtContinue(o io.Writer, n node.Node) { io.WriteString(o, ";\n") } +func printStmtDeclare(o io.Writer, n node.Node) { + nn := n.(*stmt.Declare) + + io.WriteString(o, "declare(") + joinPrint(", ", o, nn.Consts) + io.WriteString(o, ")") + printStmt(o, nn.Stmt) +} + func printStmtStmtList(o io.Writer, n node.Node) { nn := n.(*stmt.StmtList) @@ -1443,3 +1471,7 @@ func printStmtExpression(o io.Writer, n node.Node) { io.WriteString(o, ";\n") } + +func printStmtNop(o io.Writer, n node.Node) { + io.WriteString(o, ";\n") +} diff --git a/printer/printer_test.go b/printer/printer_test.go index ad0595a..a47c6f6 100644 --- a/printer/printer_test.go +++ b/printer/printer_test.go @@ -2259,6 +2259,73 @@ func TestPrintStmtContinue(t *testing.T) { } } +func TestPrintStmtDeclareStmts(t *testing.T) { + o := bytes.NewBufferString("") + + printer.Print(o, &stmt.Declare{ + Consts: []node.Node{ + &stmt.Constant{ + ConstantName: &node.Identifier{Value: "FOO"}, + Expr: &scalar.String{Value: "bar"}, + }, + }, + Stmt: &stmt.StmtList{ + Stmts: []node.Node{ + &stmt.Nop{}, + }, + }, + }) + + expected := "declare(FOO = 'bar') {\n;\n}\n" + actual := o.String() + + if expected != actual { + t.Errorf("\nexpected: %s\ngot: %s\n", expected, actual) + } +} + +func TestPrintStmtDeclareExpr(t *testing.T) { + o := bytes.NewBufferString("") + + printer.Print(o, &stmt.Declare{ + Consts: []node.Node{ + &stmt.Constant{ + ConstantName: &node.Identifier{Value: "FOO"}, + Expr: &scalar.String{Value: "bar"}, + }, + }, + Stmt: &stmt.Expression{Expr: &scalar.String{Value: "bar"}}, + }) + + expected := "declare(FOO = 'bar')\n'bar';\n" + actual := o.String() + + if expected != actual { + t.Errorf("\nexpected: %s\ngot: %s\n", expected, actual) + } +} + +func TestPrintStmtDeclareNop(t *testing.T) { + o := bytes.NewBufferString("") + + printer.Print(o, &stmt.Declare{ + Consts: []node.Node{ + &stmt.Constant{ + ConstantName: &node.Identifier{Value: "FOO"}, + Expr: &scalar.String{Value: "bar"}, + }, + }, + Stmt: &stmt.Nop{}, + }) + + expected := "declare(FOO = 'bar');\n" + actual := o.String() + + if expected != actual { + t.Errorf("\nexpected: %s\ngot: %s\n", expected, actual) + } +} + func TestPrintStmtList(t *testing.T) { o := bytes.NewBufferString("") @@ -2289,3 +2356,16 @@ func TestPrintExpression(t *testing.T) { t.Errorf("\nexpected: %s\ngot: %s\n", expected, actual) } } + +func TestPrintNop(t *testing.T) { + o := bytes.NewBufferString("") + + printer.Print(o, &stmt.Nop{}) + + expected := ";\n" + actual := o.String() + + if expected != actual { + t.Errorf("\nexpected: %s\ngot: %s\n", expected, actual) + } +}