Files
quic-go/logging/generate_multiplexer.go
Marten Seemann a7795541c9 logging: use code generation to generate the multiplexed tracers (#4677)
* logging: use code generation to generate the multiplexed tracers

* logging: ignore auto-generated files for code coverage

* logging: move code generation comment to generator script
2024-11-28 23:57:11 +08:00

162 lines
4.2 KiB
Go

//go:build generate
package main
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"log"
"os"
"strings"
"text/template"
"golang.org/x/tools/imports"
)
// main generates a Go source file from a struct of function-typed fields.
//
// Usage: <program> <struct_name> <input_file> <template_file> <output_file>
//
// It parses input_file, locates the struct type named struct_name, and collects
// one entry for every field whose type is a function type (all other fields are
// skipped). The entries are handed to the text/template read from template_file
// as "Fields", together with "StructName". The rendered result is prefixed with
// a "Code generated ... DO NOT EDIT." header, run through imports.Process and
// written to output_file.
//
// Any failure terminates the program via the log package's Fatal functions.
func main() {
	if len(os.Args) != 5 {
		log.Fatalf("Usage: %s <struct_name> <input_file> <template_file> <output_file>", os.Args[0])
	}
	structName := os.Args[1]
	inputFile := os.Args[2]
	templateFile := os.Args[3]
	outputFile := os.Args[4]

	fset := token.NewFileSet()
	// Parse the input file containing the struct type
	file, err := parser.ParseFile(fset, inputFile, nil, parser.AllErrors)
	if err != nil {
		log.Fatalf("Failed to parse file: %v", err)
	}

	// Find the specified struct type in the AST.
	// "Not found" is decided by the lookup result rather than by an empty field
	// list, so that an empty struct is not misreported as missing.
	typeSpec := findTypeSpec(file, structName)
	if typeSpec == nil {
		log.Fatalf("Could not find %s type", structName)
	}
	structType, ok := typeSpec.Type.(*ast.StructType)
	if !ok {
		log.Fatalf("%s is not a struct", structName)
	}

	// Prepare data for the template.
	// The field names are part of the template contract and must stay exported.
	type FieldData struct {
		Name        string
		Params      string
		Args        string
		HasParams   bool
		ReturnTypes string
		HasReturn   bool
	}
	var fieldDataList []FieldData
	for _, field := range structType.Fields.List {
		funcType, ok := field.Type.(*ast.FuncType)
		if !ok {
			continue
		}
		// A single field declaration may introduce several names ("A, B func()").
		for _, name := range field.Names {
			params, args, err := collectParams(fset, funcType, name.Name)
			if err != nil {
				log.Fatal(err)
			}
			returns, err := collectResults(fset, funcType)
			if err != nil {
				log.Fatal(err)
			}

			fieldData := FieldData{
				Name:      name.Name,
				Params:    strings.Join(params, ", "),
				Args:      strings.Join(args, ", "),
				HasParams: len(params) > 0,
				HasReturn: len(returns) > 0,
			}
			// A single result is written bare, multiple results are parenthesized.
			// The leading space separates the result list from the parameter list.
			switch len(returns) {
			case 0:
			case 1:
				fieldData.ReturnTypes = fmt.Sprintf(" %s", returns[0])
			default:
				fieldData.ReturnTypes = fmt.Sprintf(" (%s)", strings.Join(returns, ", "))
			}
			fieldDataList = append(fieldDataList, fieldData)
		}
	}

	// Read the template from file
	templateContent, err := os.ReadFile(templateFile)
	if err != nil {
		log.Fatalf("Failed to read template file: %v", err)
	}

	// Generate the code using the template
	tmpl, err := template.New("multiplexer").Funcs(template.FuncMap{"join": strings.Join}).Parse(string(templateContent))
	if err != nil {
		log.Fatalf("Failed to parse template: %v", err)
	}
	var generatedCode bytes.Buffer
	generatedCode.WriteString("// Code generated by generate_multiplexer.go; DO NOT EDIT.\n\n")
	if err = tmpl.Execute(&generatedCode, map[string]interface{}{
		"Fields":     fieldDataList,
		"StructName": structName,
	}); err != nil {
		log.Fatalf("Failed to execute template: %v", err)
	}

	// Format the generated code and add imports
	formattedCode, err := imports.Process(outputFile, generatedCode.Bytes(), nil)
	if err != nil {
		log.Fatalf("Failed to process imports: %v", err)
	}
	if err := os.WriteFile(outputFile, formattedCode, 0o644); err != nil {
		log.Fatalf("Failed to write output file: %v", err)
	}
}

// findTypeSpec returns the first top-level type declaration in file whose name
// is typeName, or nil if there is none.
func findTypeSpec(file *ast.File, typeName string) *ast.TypeSpec {
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}
		for _, spec := range genDecl.Specs {
			if typeSpec, ok := spec.(*ast.TypeSpec); ok && typeSpec.Name.Name == typeName {
				return typeSpec
			}
		}
	}
	return nil
}

// typeString renders expr as Go source text.
func typeString(fset *token.FileSet, expr ast.Expr) (string, error) {
	var buf bytes.Buffer
	if err := printer.Fprint(&buf, fset, expr); err != nil {
		return "", fmt.Errorf("failed to print type expression: %w", err)
	}
	return buf.String(), nil
}

// collectParams returns, for every parameter of funcType, its "name type"
// declaration (params) and the expression that passes it on in a call (args).
// fieldName is only used in error messages.
//
// Unnamed and "_" parameters are rejected on purpose: names could be
// auto-generated, but meaningful variable names are more helpful for the user.
func collectParams(fset *token.FileSet, funcType *ast.FuncType, fieldName string) (params, args []string, err error) {
	if funcType.Params == nil {
		return nil, nil, nil
	}
	for i, param := range funcType.Params.List {
		if len(param.Names) == 0 {
			return nil, nil, fmt.Errorf("encountered unnamed parameter at position %d in function %s", i, fieldName)
		}
		paramType, err := typeString(fset, param.Type)
		if err != nil {
			return nil, nil, err
		}
		// A variadic parameter ("name ...T") is a slice inside the function and
		// has to be spread again ("name...") when it is passed on in a call.
		// NOTE(review): assumes the template uses Args as a call argument list;
		// confirm against the template file.
		_, variadic := param.Type.(*ast.Ellipsis)
		for _, paramName := range param.Names {
			if paramName.Name == "_" {
				return nil, nil, fmt.Errorf("encountered underscore parameter at position %d in function %s", i, fieldName)
			}
			arg := paramName.Name
			if variadic {
				arg += "..."
			}
			params = append(params, fmt.Sprintf("%s %s", paramName.Name, paramType))
			args = append(args, arg)
		}
	}
	return params, args, nil
}

// collectResults returns the result types of funcType, one entry per result
// value. Result names are dropped; only the types are kept.
func collectResults(fset *token.FileSet, funcType *ast.FuncType) ([]string, error) {
	if funcType.Results == nil {
		return nil, nil
	}
	var returns []string
	for _, result := range funcType.Results.List {
		resultType, err := typeString(fset, result.Type)
		if err != nil {
			return nil, err
		}
		// Grouped named results such as "(a, b int)" are a single AST field
		// carrying two names, so the type has to be emitted once per name.
		// An unnamed result has no names and counts once.
		count := len(result.Names)
		if count == 0 {
			count = 1
		}
		for j := 0; j < count; j++ {
			returns = append(returns, resultType)
		}
	}
	return returns, nil
}