feat: add enter and leave checks for traversers
- Improved performance (can short circuit/not continue when done) - Backwards compatible because added interfaces - Now generating the implementation for easier updates
This commit is contained in:
File diff suppressed because it is too large
Load Diff
173
pkg/visitor/traverser/traverser_gen.go
Normal file
173
pkg/visitor/traverser/traverser_gen.go
Normal file
@@ -0,0 +1,173 @@
|
||||
//go:build ignore
|
||||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"text/template"
|
||||
|
||||
phpast "github.com/VKCOM/php-parser/pkg/ast"
|
||||
"golang.org/x/tools/go/packages"
|
||||
)
|
||||
|
||||
type nodeType struct {
|
||||
// The type name.
|
||||
Name string
|
||||
// The function name.
|
||||
FuncName string
|
||||
// Map of field names to whether the field is an array or not.
|
||||
Traversable map[string]bool
|
||||
}
|
||||
|
||||
type templData struct {
|
||||
Types []nodeType
|
||||
}
|
||||
|
||||
var fileTempl = template.Must(
|
||||
template.New("").Parse(`// Code generated by "go generate go run traverser_gen.go"; DO NOT EDIT.
|
||||
|
||||
package traverser
|
||||
|
||||
import (
|
||||
"github.com/VKCOM/php-parser/pkg/ast"
|
||||
// Importing packages here, so that go mod tidy does not remove the dependency on it.
|
||||
// It is used in traverser_gen.go but that is ignored with go mod tidy.
|
||||
"golang.org/x/tools/go/packages"
|
||||
)
|
||||
{{range $typ := .Types}}
|
||||
func (t *Traverser) {{$typ.FuncName}}(n *ast.{{$typ.Name}}) {
|
||||
if !t.checkEntrance(n) {
|
||||
return
|
||||
}
|
||||
|
||||
n.Accept(t.v)
|
||||
{{range $name, $isArray := $typ.Traversable}}
|
||||
{{- if $isArray}}
|
||||
for _, nn := range n.{{$name}} {
|
||||
nn.Accept(t)
|
||||
}
|
||||
{{else}}
|
||||
t.Traverse(n.{{$name}})
|
||||
{{end}}
|
||||
{{- end}}
|
||||
t.leave(n)
|
||||
}
|
||||
{{end}}`),
|
||||
)
|
||||
|
||||
func main() {
|
||||
ast := astOfFile("../../ast/node.go")
|
||||
nodeTypes := getNodeTypes(ast)
|
||||
|
||||
file, err := os.Create("traverser_impl.go")
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("open 'traverser_impl.go': %w", err))
|
||||
}
|
||||
|
||||
fileTempl.Execute(file, templData{Types: nodeTypes})
|
||||
}
|
||||
|
||||
func astOfFile(path string) *ast.File {
|
||||
nodesFile, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("getting '%s' absolute path: %w", path, err))
|
||||
}
|
||||
|
||||
cfg := &packages.Config{
|
||||
Mode: packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
|
||||
Tests: false,
|
||||
}
|
||||
pkgs, err := packages.Load(cfg, fmt.Sprintf("file=%s", nodesFile))
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("loading '%s' package: %w", nodesFile, err))
|
||||
}
|
||||
|
||||
if len(pkgs) == 0 {
|
||||
panic(fmt.Errorf("could not get any package for file '%s'", nodesFile))
|
||||
}
|
||||
|
||||
pkg := pkgs[0]
|
||||
var syntax *ast.File
|
||||
for i, fn := range pkg.GoFiles {
|
||||
if fn == nodesFile {
|
||||
syntax = pkg.Syntax[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if syntax == nil {
|
||||
panic(fmt.Errorf("could not get ast of file: '%s'", nodesFile))
|
||||
}
|
||||
|
||||
return syntax
|
||||
}
|
||||
|
||||
func getNodeTypes(syntax *ast.File) []nodeType {
|
||||
types := []nodeType{}
|
||||
for _, decl := range syntax.Decls {
|
||||
typedDecl, ok := decl.(*ast.GenDecl)
|
||||
if !ok {
|
||||
log.Printf("%T not *ast.GenDecl", decl)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, spec := range typedDecl.Specs {
|
||||
typedSpec, ok := spec.(*ast.TypeSpec)
|
||||
if !ok {
|
||||
if _, ok := spec.(*ast.ImportSpec); !ok {
|
||||
log.Printf("%T not *ast.TypeSpec", spec)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
structType, ok := typedSpec.Type.(*ast.StructType)
|
||||
if !ok {
|
||||
log.Printf("%T not *ast.StructType", spec)
|
||||
continue
|
||||
}
|
||||
|
||||
nType := nodeType{
|
||||
Name: typedSpec.Name.String(),
|
||||
FuncName: typedSpec.Name.String(),
|
||||
Traversable: map[string]bool{},
|
||||
}
|
||||
if funcName, ok := phpast.TypeToVisitorNameMap[typedSpec.Name.String()]; ok {
|
||||
nType.FuncName = funcName
|
||||
}
|
||||
|
||||
for _, field := range structType.Fields.List {
|
||||
if traversable, isArray := checkType(field.Type); traversable {
|
||||
for _, n := range field.Names {
|
||||
nType.Traversable[n.String()] = isArray
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
types = append(types, nType)
|
||||
}
|
||||
}
|
||||
|
||||
return types
|
||||
}
|
||||
|
||||
func checkType(t ast.Expr) (traversable bool, isArray bool) {
|
||||
switch ft := t.(type) {
|
||||
case *ast.ArrayType:
|
||||
ok, _ := checkType(ft.Elt)
|
||||
return ok, true
|
||||
case *ast.Ident:
|
||||
switch ft.Name {
|
||||
case "Vertex":
|
||||
return true, false
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
2424
pkg/visitor/traverser/traverser_impl.go
Normal file
2424
pkg/visitor/traverser/traverser_impl.go
Normal file
File diff suppressed because it is too large
Load Diff
67
pkg/visitor/traverser/traverser_test.go
Normal file
67
pkg/visitor/traverser/traverser_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package traverser_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/VKCOM/php-parser/pkg/ast"
|
||||
"github.com/VKCOM/php-parser/pkg/visitor"
|
||||
"github.com/VKCOM/php-parser/pkg/visitor/traverser"
|
||||
)
|
||||
|
||||
// testCase is a traverser that does not go into any class statement or its children.
|
||||
// If it does, the test fails.
|
||||
type testCase struct {
|
||||
t *testing.T
|
||||
visitor.Null
|
||||
traversedFunction bool
|
||||
}
|
||||
|
||||
var _ ast.Visitor = &testCase{}
|
||||
|
||||
func (t *testCase) EnterNode(n ast.Vertex) bool {
|
||||
t.t.Logf("EnterNode: %T", n)
|
||||
if _, ok := n.(*ast.StmtClass); ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *testCase) LeaveNode(n ast.Vertex) {
|
||||
t.t.Logf("LeaveNode: %T", n)
|
||||
if _, ok := n.(*ast.Root); ok {
|
||||
if !t.traversedFunction {
|
||||
t.t.Error("traverser did not traverse function")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *testCase) StmtClass(n *ast.StmtClass) {
|
||||
t.t.Errorf("traverser got to class")
|
||||
}
|
||||
|
||||
func (t *testCase) StmtClassMethod(n *ast.StmtClassMethod) {
|
||||
t.t.Errorf("traverser got to method")
|
||||
}
|
||||
|
||||
func (t *testCase) StmtFunction(n *ast.StmtFunction) {
|
||||
t.traversedFunction = true
|
||||
}
|
||||
|
||||
func TestEnterNodeIsRespected(t *testing.T) {
|
||||
tc := &testCase{t: t}
|
||||
tv := traverser.NewTraverser(tc)
|
||||
|
||||
root := &ast.Root{
|
||||
Stmts: []ast.Vertex{
|
||||
&ast.StmtFunction{},
|
||||
&ast.StmtClass{
|
||||
Stmts: []ast.Vertex{
|
||||
&ast.StmtClassMethod{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
root.Accept(tv)
|
||||
}
|
||||
Reference in New Issue
Block a user