2023-03-26 00:33:03 +00:00
|
|
|
//go:build tools
|
2023-03-25 20:55:24 +00:00
|
|
|
|
|
|
|
package main
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"go/ast"
|
|
|
|
"log"
|
|
|
|
"os"
|
|
|
|
"path/filepath"
|
|
|
|
"text/template"
|
|
|
|
|
2023-12-09 21:36:19 +00:00
|
|
|
phpast "git.maride.cc/maride/php-parser/pkg/ast"
|
2023-03-25 20:55:24 +00:00
|
|
|
"golang.org/x/tools/go/packages"
|
|
|
|
)
|
|
|
|
|
|
|
|
// nodeType describes one AST node struct found in node.go, in the shape the
// generator template (fileTempl) consumes.
type nodeType struct {
	// Name is the Go type name of the node struct; the template emits it as
	// ast.<Name> for the method parameter.
	Name string
	// FuncName is the name given to the generated Traverser method.
	FuncName string
	// Traversable maps each child field that must be visited to true when the
	// field is an array/slice of nodes, or false when it holds a single node.
	Traversable map[string]bool
}
|
|
|
|
|
|
|
|
type templData struct {
|
|
|
|
Types []nodeType
|
|
|
|
}
|
|
|
|
|
|
|
|
// fileTempl renders traverser_impl.go: a generated-code header, the traverser
// package clause, the php-parser ast import, and one Traverser method per
// entry in templData.Types.
//
// Each generated method:
//   - returns early when t.checkEntrance(n) is false,
//   - calls n.Accept(t.v),
//   - descends into every Traversable field (ranging over slice fields and
//     calling nn.Accept(t), or calling t.Traverse for single-node fields),
//   - finally calls t.leave(n).
//
// The `{{-` markers trim the preceding newline/indent so the per-field blocks
// line up under the method body. text/template ranges over the Traversable map
// in sorted key order, so the output is deterministic.
var fileTempl = template.Must(
	template.New("").Parse(`// Code generated by "go generate go run traverser_gen.go"; DO NOT EDIT.

package traverser

import "git.maride.cc/maride/php-parser/pkg/ast"

{{range $typ := .Types}}
func (t *Traverser) {{$typ.FuncName}}(n *ast.{{$typ.Name}}) {
	if !t.checkEntrance(n) {
		return
	}

	n.Accept(t.v)
{{range $name, $isArray := $typ.Traversable}}
	{{- if $isArray}}
	for _, nn := range n.{{$name}} {
		nn.Accept(t)
	}
	{{else}}
	t.Traverse(n.{{$name}})
	{{end}}
{{- end}}
	t.leave(n)
}
{{end}}`),
)
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
ast := astOfFile("../../ast/node.go")
|
|
|
|
nodeTypes := getNodeTypes(ast)
|
|
|
|
|
|
|
|
file, err := os.Create("traverser_impl.go")
|
|
|
|
if err != nil {
|
|
|
|
panic(fmt.Errorf("open 'traverser_impl.go': %w", err))
|
|
|
|
}
|
|
|
|
|
|
|
|
fileTempl.Execute(file, templData{Types: nodeTypes})
|
|
|
|
}
|
|
|
|
|
|
|
|
func astOfFile(path string) *ast.File {
|
|
|
|
nodesFile, err := filepath.Abs(path)
|
|
|
|
if err != nil {
|
|
|
|
panic(fmt.Errorf("getting '%s' absolute path: %w", path, err))
|
|
|
|
}
|
|
|
|
|
|
|
|
cfg := &packages.Config{
|
|
|
|
Mode: packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
|
|
|
|
Tests: false,
|
|
|
|
}
|
|
|
|
pkgs, err := packages.Load(cfg, fmt.Sprintf("file=%s", nodesFile))
|
|
|
|
if err != nil {
|
|
|
|
panic(fmt.Errorf("loading '%s' package: %w", nodesFile, err))
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(pkgs) == 0 {
|
|
|
|
panic(fmt.Errorf("could not get any package for file '%s'", nodesFile))
|
|
|
|
}
|
|
|
|
|
|
|
|
pkg := pkgs[0]
|
|
|
|
var syntax *ast.File
|
|
|
|
for i, fn := range pkg.GoFiles {
|
|
|
|
if fn == nodesFile {
|
|
|
|
syntax = pkg.Syntax[i]
|
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if syntax == nil {
|
|
|
|
panic(fmt.Errorf("could not get ast of file: '%s'", nodesFile))
|
|
|
|
}
|
|
|
|
|
|
|
|
return syntax
|
|
|
|
}
|
|
|
|
|
|
|
|
func getNodeTypes(syntax *ast.File) []nodeType {
|
|
|
|
types := []nodeType{}
|
|
|
|
for _, decl := range syntax.Decls {
|
|
|
|
typedDecl, ok := decl.(*ast.GenDecl)
|
|
|
|
if !ok {
|
|
|
|
log.Printf("%T not *ast.GenDecl", decl)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, spec := range typedDecl.Specs {
|
|
|
|
typedSpec, ok := spec.(*ast.TypeSpec)
|
|
|
|
if !ok {
|
|
|
|
if _, ok := spec.(*ast.ImportSpec); !ok {
|
|
|
|
log.Printf("%T not *ast.TypeSpec", spec)
|
|
|
|
}
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
structType, ok := typedSpec.Type.(*ast.StructType)
|
|
|
|
if !ok {
|
|
|
|
log.Printf("%T not *ast.StructType", spec)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
nType := nodeType{
|
|
|
|
Name: typedSpec.Name.String(),
|
|
|
|
FuncName: typedSpec.Name.String(),
|
|
|
|
Traversable: map[string]bool{},
|
|
|
|
}
|
|
|
|
if funcName, ok := phpast.TypeToVisitorNameMap[typedSpec.Name.String()]; ok {
|
|
|
|
nType.FuncName = funcName
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, field := range structType.Fields.List {
|
|
|
|
if traversable, isArray := checkType(field.Type); traversable {
|
|
|
|
for _, n := range field.Names {
|
|
|
|
nType.Traversable[n.String()] = isArray
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
types = append(types, nType)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return types
|
|
|
|
}
|
|
|
|
|
|
|
|
// checkType decides whether a struct field whose type expression is t should
// be visited by the generated traverser, and whether it must be ranged over.
//
// The bare identifier Vertex is traversable. An array/slice type is traversable
// when its element type is (checked recursively), and it always reports
// isArray == true, even when the element is not traversable. Every other
// expression (other identifiers, pointers, selectors, maps, ...) yields
// (false, false).
func checkType(t ast.Expr) (traversable bool, isArray bool) {
	if arr, ok := t.(*ast.ArrayType); ok {
		traversable, _ = checkType(arr.Elt)
		return traversable, true
	}

	ident, ok := t.(*ast.Ident)
	return ok && ident.Name == "Vertex", false
}
|