php-parser/pkg/visitor/traverser/traverser_gen.go
2023-12-10 00:36:22 +01:00

168 lines
3.4 KiB
Go

//go:build tools
package main
import (
"fmt"
"go/ast"
"log"
"os"
"path/filepath"
"text/template"
phpast "git.maride.cc/maride/php-parser/pkg/ast"
"golang.org/x/tools/go/packages"
)
type (
	// nodeType describes one AST node struct for which a Traverser method is
	// generated.
	nodeType struct {
		// Name is the struct type name; the generated method takes *ast.<Name>.
		Name string
		// FuncName is the name of the generated Traverser method.
		FuncName string
		// Traversable maps each child-holding field name to whether the field
		// is a slice (ranged over, calling Accept on each element) or a single
		// value (passed to Traverse).
		//
		// The template ranges over this map, and text/template visits map keys
		// in sorted order, so the generated method descends into fields
		// alphabetically by name rather than in declaration order.
		Traversable map[string]bool
	}

	// templData is the root value handed to the file template.
	templData struct {
		// Types holds one entry per generated Traverser method.
		Types []nodeType
	}
)
// fileTempl renders the contents of traverser_impl.go (package traverser).
//
// For each entry in templData.Types it emits a Traverser method named
// FuncName that takes *ast.<Name>. The method:
//   - returns early when t.checkEntrance(n) is false;
//   - otherwise calls n.Accept(t.v);
//   - for every Traversable field, ranges over the slice calling Accept(t) on
//     each element, or passes the single value to t.Traverse;
//   - finally calls t.leave(n).
//
// Fields are emitted in the order text/template ranges the Traversable map,
// i.e. sorted by field name.
//
// NOTE(review): main writes the rendered text directly without passing it
// through go/format, so the generated file's layout follows the whitespace in
// this literal exactly.
var fileTempl = template.Must(
template.New("").Parse(`// Code generated by "go generate go run traverser_gen.go"; DO NOT EDIT.
package traverser
import "git.maride.cc/maride/php-parser/pkg/ast"
{{range $typ := .Types}}
func (t *Traverser) {{$typ.FuncName}}(n *ast.{{$typ.Name}}) {
if !t.checkEntrance(n) {
return
}
n.Accept(t.v)
{{range $name, $isArray := $typ.Traversable}}
{{- if $isArray}}
for _, nn := range n.{{$name}} {
nn.Accept(t)
}
{{else}}
t.Traverse(n.{{$name}})
{{end}}
{{- end}}
t.leave(n)
}
{{end}}`),
)
// main generates traverser_impl.go in the current directory.
//
// It parses ../../ast/node.go, collects every struct type declared there, and
// renders one Traverser method per type with fileTempl. Any failure panics so
// that `go generate` reports it.
func main() {
	// Named nodesAST rather than ast to avoid shadowing the go/ast import.
	nodesAST := astOfFile("../../ast/node.go")
	nodeTypes := getNodeTypes(nodesAST)

	file, err := os.Create("traverser_impl.go")
	if err != nil {
		panic(fmt.Errorf("open 'traverser_impl.go': %w", err))
	}

	// text/template may already have written partial output when Execute
	// fails, so both the render error and the close error must be surfaced.
	// Previously both were dropped, silently leaving a truncated file behind.
	// Close is called before checking either error so the handle is never
	// leaked.
	execErr := fileTempl.Execute(file, templData{Types: nodeTypes})
	closeErr := file.Close()
	if execErr != nil {
		panic(fmt.Errorf("render 'traverser_impl.go': %w", execErr))
	}
	if closeErr != nil {
		panic(fmt.Errorf("close 'traverser_impl.go': %w", closeErr))
	}
}
// astOfFile loads the Go package that contains the file at path, using the
// go/packages "file=" query, and returns that file's parsed syntax tree.
//
// A relative path is first resolved against the current working directory.
// This lets it be compared with the absolute file names go/packages records.
// It panics if:
//   - the path cannot be resolved,
//   - the package cannot be loaded or no package is returned, or
//   - no syntax tree for the file is found.
func astOfFile(path string) *ast.File {
	nodesFile, err := filepath.Abs(path)
	if err != nil {
		panic(fmt.Errorf("getting '%s' absolute path: %w", path, err))
	}

	cfg := &packages.Config{
		Mode:  packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
		Tests: false,
	}
	pkgs, err := packages.Load(cfg, fmt.Sprintf("file=%s", nodesFile))
	if err != nil {
		panic(fmt.Errorf("loading '%s' package: %w", nodesFile, err))
	}
	if len(pkgs) == 0 {
		panic(fmt.Errorf("could not get any package for file '%s'", nodesFile))
	}

	pkg := pkgs[0]
	// pkg.Syntax is documented as parallel to pkg.CompiledGoFiles, not
	// pkg.GoFiles. Indexing Syntax by a GoFiles position can therefore select
	// the wrong tree or run out of range. Identify each tree by the file name
	// stored in the package FileSet instead. go/packages documents Fset as
	// set whenever Types is, and NeedTypes is requested above.
	//
	// The position is left unadjusted (false) so that //line directives in
	// the source cannot change the reported file name.
	for _, file := range pkg.Syntax {
		if pkg.Fset.PositionFor(file.Pos(), false).Filename == nodesFile {
			return file
		}
	}
	panic(fmt.Errorf("could not get ast of file: '%s'", nodesFile))
}
// getNodeTypes collects one nodeType for every struct type declared in syntax.
//
// The Traverser method name comes from phpast.TypeToVisitorNameMap when the
// type has an entry there; otherwise it is the type name itself. Only named
// fields whose type checkType reports as traversable are recorded in
// Traversable. Embedded fields have no names and are never recorded.
//
// Skipped declarations are logged, except for import specs, which are
// skipped silently:
//   - declarations that are not *ast.GenDecl (e.g. functions);
//   - specs that are neither type nor import specs;
//   - type specs whose type is not a struct.
func getNodeTypes(syntax *ast.File) []nodeType {
	types := []nodeType{}
	for _, decl := range syntax.Decls {
		typedDecl, ok := decl.(*ast.GenDecl)
		if !ok {
			log.Printf("%T not *ast.GenDecl", decl)
			continue
		}
		for _, spec := range typedDecl.Specs {
			typedSpec, ok := spec.(*ast.TypeSpec)
			if !ok {
				// Import specs are expected; only report value (const/var) specs.
				if _, ok := spec.(*ast.ImportSpec); !ok {
					log.Printf("%T not *ast.TypeSpec", spec)
				}
				continue
			}
			structType, ok := typedSpec.Type.(*ast.StructType)
			if !ok {
				// Report the offending type expression. Logging spec here
				// would always print *ast.TypeSpec, which says nothing useful.
				log.Printf("%T not *ast.StructType", typedSpec.Type)
				continue
			}

			typeName := typedSpec.Name.String()
			nType := nodeType{
				Name:        typeName,
				FuncName:    typeName,
				Traversable: map[string]bool{},
			}
			if funcName, ok := phpast.TypeToVisitorNameMap[typeName]; ok {
				nType.FuncName = funcName
			}

			for _, field := range structType.Fields.List {
				traversable, isArray := checkType(field.Type)
				if !traversable {
					continue
				}
				for _, n := range field.Names {
					nType.Traversable[n.String()] = isArray
				}
			}
			types = append(types, nType)
		}
	}
	return types
}
// checkType reports whether a struct field of type t holds child nodes that
// the traverser must descend into, and whether the field is an array/slice.
//
// Only a bare identifier named Vertex is traversable. For an array or slice,
// traversability is decided by the element type, and isArray is true even
// when the element is not traversable. Any other expression (pointers,
// qualified selectors, maps, nil, …) is reported as neither.
func checkType(t ast.Expr) (traversable bool, isArray bool) {
	if arr, isArr := t.(*ast.ArrayType); isArr {
		elemTraversable, _ := checkType(arr.Elt)
		return elemTraversable, true
	}

	ident, isIdent := t.(*ast.Ident)
	if !isIdent {
		return false, false
	}
	return ident.Name == "Vertex", false
}