forked from quic-go/quic-go
logging: use code generation to generate the multiplexed tracers (#4677)
* logging: use code generation to generate the multiplexed tracers * logging: ignore auto-generated files for code coverage * logging: move code generation comment to generator script
This commit is contained in:
161
logging/generate_multiplexer.go
Normal file
161
logging/generate_multiplexer.go
Normal file
@@ -0,0 +1,161 @@
|
||||
//go:build generate
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/printer"
|
||||
"go/token"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"golang.org/x/tools/imports"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) != 5 {
|
||||
log.Fatalf("Usage: %s <struct_name> <input_file> <template_file> <output_file>", os.Args[0])
|
||||
}
|
||||
|
||||
structName := os.Args[1]
|
||||
inputFile := os.Args[2]
|
||||
templateFile := os.Args[3]
|
||||
outputFile := os.Args[4]
|
||||
|
||||
fset := token.NewFileSet()
|
||||
|
||||
// Parse the input file containing the struct type
|
||||
file, err := parser.ParseFile(fset, inputFile, nil, parser.AllErrors)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse file: %v", err)
|
||||
}
|
||||
|
||||
var fields []*ast.Field
|
||||
|
||||
// Find the specified struct type in the AST
|
||||
for _, decl := range file.Decls {
|
||||
genDecl, ok := decl.(*ast.GenDecl)
|
||||
if !ok || genDecl.Tok != token.TYPE {
|
||||
continue
|
||||
}
|
||||
for _, spec := range genDecl.Specs {
|
||||
typeSpec, ok := spec.(*ast.TypeSpec)
|
||||
if !ok || typeSpec.Name.Name != structName {
|
||||
continue
|
||||
}
|
||||
structType, ok := typeSpec.Type.(*ast.StructType)
|
||||
if !ok {
|
||||
log.Fatalf("%s is not a struct", structName)
|
||||
}
|
||||
fields = structType.Fields.List
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if fields == nil {
|
||||
log.Fatalf("Could not find %s type", structName)
|
||||
}
|
||||
|
||||
// Prepare data for the template
|
||||
type FieldData struct {
|
||||
Name string
|
||||
Params string
|
||||
Args string
|
||||
HasParams bool
|
||||
ReturnTypes string
|
||||
HasReturn bool
|
||||
}
|
||||
|
||||
var fieldDataList []FieldData
|
||||
|
||||
for _, field := range fields {
|
||||
funcType, ok := field.Type.(*ast.FuncType)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, name := range field.Names {
|
||||
fieldData := FieldData{Name: name.Name}
|
||||
|
||||
// extract parameters
|
||||
var params []string
|
||||
var args []string
|
||||
if funcType.Params != nil {
|
||||
for i, param := range funcType.Params.List {
|
||||
// We intentionally reject unnamed (and, further down, "_") function parameters.
|
||||
// We could auto-generate parameter names,
|
||||
// but having meaningful variable names will be more helpful for the user.
|
||||
if len(param.Names) == 0 {
|
||||
log.Fatalf("encountered unnamed parameter at position %d in function %s", i, fieldData.Name)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
printer.Fprint(&buf, fset, param.Type)
|
||||
paramType := buf.String()
|
||||
for _, paramName := range param.Names {
|
||||
if paramName.Name == "_" {
|
||||
log.Fatalf("encountered underscore parameter at position %d in function %s", i, fieldData.Name)
|
||||
}
|
||||
params = append(params, fmt.Sprintf("%s %s", paramName.Name, paramType))
|
||||
args = append(args, paramName.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
fieldData.Params = strings.Join(params, ", ")
|
||||
fieldData.Args = strings.Join(args, ", ")
|
||||
fieldData.HasParams = len(params) > 0
|
||||
|
||||
// extract return types
|
||||
if funcType.Results != nil && len(funcType.Results.List) > 0 {
|
||||
fieldData.HasReturn = true
|
||||
var returns []string
|
||||
for _, result := range funcType.Results.List {
|
||||
var buf bytes.Buffer
|
||||
printer.Fprint(&buf, fset, result.Type)
|
||||
returns = append(returns, buf.String())
|
||||
}
|
||||
if len(returns) == 1 {
|
||||
fieldData.ReturnTypes = fmt.Sprintf(" %s", returns[0])
|
||||
} else {
|
||||
fieldData.ReturnTypes = fmt.Sprintf(" (%s)", strings.Join(returns, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
fieldDataList = append(fieldDataList, fieldData)
|
||||
}
|
||||
}
|
||||
|
||||
// Read the template from file
|
||||
templateContent, err := os.ReadFile(templateFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to read template file: %v", err)
|
||||
}
|
||||
|
||||
// Generate the code using the template
|
||||
tmpl, err := template.New("multiplexer").Funcs(template.FuncMap{"join": strings.Join}).Parse(string(templateContent))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
var generatedCode bytes.Buffer
|
||||
generatedCode.WriteString("// Code generated by generate_multiplexer.go; DO NOT EDIT.\n\n")
|
||||
if err = tmpl.Execute(&generatedCode, map[string]interface{}{
|
||||
"Fields": fieldDataList,
|
||||
"StructName": structName,
|
||||
}); err != nil {
|
||||
log.Fatalf("Failed to execute template: %v", err)
|
||||
}
|
||||
|
||||
// Format the generated code and add imports
|
||||
formattedCode, err := imports.Process(outputFile, generatedCode.Bytes(), nil)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to process imports: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(outputFile, formattedCode, 0o644); err != nil {
|
||||
log.Fatalf("Failed to write output file: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user