//go:build tools package main import ( "fmt" "go/ast" "log" "os" "path/filepath" "text/template" phpast "github.com/laytan/php-parser/pkg/ast" "golang.org/x/tools/go/packages" ) type nodeType struct { // The type name. Name string // The function name. FuncName string // Map of field names to whether the field is an array or not. Traversable map[string]bool } type templData struct { Types []nodeType } var fileTempl = template.Must( template.New("").Parse(`// Code generated by "go generate go run traverser_gen.go"; DO NOT EDIT. package traverser import "github.com/laytan/php-parser/pkg/ast" {{range $typ := .Types}} func (t *Traverser) {{$typ.FuncName}}(n *ast.{{$typ.Name}}) { if !t.checkEntrance(n) { return } n.Accept(t.v) {{range $name, $isArray := $typ.Traversable}} {{- if $isArray}} for _, nn := range n.{{$name}} { nn.Accept(t) } {{else}} t.Traverse(n.{{$name}}) {{end}} {{- end}} t.leave(n) } {{end}}`), ) func main() { ast := astOfFile("../../ast/node.go") nodeTypes := getNodeTypes(ast) file, err := os.Create("traverser_impl.go") if err != nil { panic(fmt.Errorf("open 'traverser_impl.go': %w", err)) } fileTempl.Execute(file, templData{Types: nodeTypes}) } func astOfFile(path string) *ast.File { nodesFile, err := filepath.Abs(path) if err != nil { panic(fmt.Errorf("getting '%s' absolute path: %w", path, err)) } cfg := &packages.Config{ Mode: packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax, Tests: false, } pkgs, err := packages.Load(cfg, fmt.Sprintf("file=%s", nodesFile)) if err != nil { panic(fmt.Errorf("loading '%s' package: %w", nodesFile, err)) } if len(pkgs) == 0 { panic(fmt.Errorf("could not get any package for file '%s'", nodesFile)) } pkg := pkgs[0] var syntax *ast.File for i, fn := range pkg.GoFiles { if fn == nodesFile { syntax = pkg.Syntax[i] break } } if syntax == nil { panic(fmt.Errorf("could not get ast of file: '%s'", nodesFile)) } return syntax } func getNodeTypes(syntax *ast.File) []nodeType { types := []nodeType{} for _, decl := range syntax.Decls { typedDecl, ok := decl.(*ast.GenDecl) if !ok { log.Printf("%T not *ast.GenDecl", decl) continue } for _, spec := range typedDecl.Specs { typedSpec, ok := spec.(*ast.TypeSpec) if !ok { if _, ok := spec.(*ast.ImportSpec); !ok { log.Printf("%T not *ast.TypeSpec", spec) } continue } structType, ok := typedSpec.Type.(*ast.StructType) if !ok { log.Printf("%T not *ast.StructType", spec) continue } nType := nodeType{ Name: typedSpec.Name.String(), FuncName: typedSpec.Name.String(), Traversable: map[string]bool{}, } if funcName, ok := phpast.TypeToVisitorNameMap[typedSpec.Name.String()]; ok { nType.FuncName = funcName } for _, field := range structType.Fields.List { if traversable, isArray := checkType(field.Type); traversable { for _, n := range field.Names { nType.Traversable[n.String()] = isArray } } } types = append(types, nType) } } return types } func checkType(t ast.Expr) (traversable bool, isArray bool) { switch ft := t.(type) { case *ast.ArrayType: ok, _ := checkType(ft.Elt) return ok, true case *ast.Ident: switch ft.Name { case "Vertex": return true, false default: return false, false } default: return false, false } }