//go:build generate package main import ( "bytes" "fmt" "go/ast" "go/parser" "go/printer" "go/token" "log" "os" "strings" "text/template" "golang.org/x/tools/imports" ) func main() { if len(os.Args) != 5 { log.Fatalf("Usage: %s ", os.Args[0]) } structName := os.Args[1] inputFile := os.Args[2] templateFile := os.Args[3] outputFile := os.Args[4] fset := token.NewFileSet() // Parse the input file containing the struct type file, err := parser.ParseFile(fset, inputFile, nil, parser.AllErrors) if err != nil { log.Fatalf("Failed to parse file: %v", err) } var fields []*ast.Field // Find the specified struct type in the AST for _, decl := range file.Decls { genDecl, ok := decl.(*ast.GenDecl) if !ok || genDecl.Tok != token.TYPE { continue } for _, spec := range genDecl.Specs { typeSpec, ok := spec.(*ast.TypeSpec) if !ok || typeSpec.Name.Name != structName { continue } structType, ok := typeSpec.Type.(*ast.StructType) if !ok { log.Fatalf("%s is not a struct", structName) } fields = structType.Fields.List break } } if fields == nil { log.Fatalf("Could not find %s type", structName) } // Prepare data for the template type FieldData struct { Name string Params string Args string HasParams bool ReturnTypes string HasReturn bool } var fieldDataList []FieldData for _, field := range fields { funcType, ok := field.Type.(*ast.FuncType) if !ok { continue } for _, name := range field.Names { fieldData := FieldData{Name: name.Name} // extract parameters var params []string var args []string if funcType.Params != nil { for i, param := range funcType.Params.List { // We intentionally reject unnamed (and, further down, "_") function parameters. // We could auto-generate parameter names, // but having meaningful variable names will be more helpful for the user. if len(param.Names) == 0 { log.Fatalf("encountered unnamed parameter at position %d in function %s", i, fieldData.Name) } var buf bytes.Buffer printer.Fprint(&buf, fset, param.Type) paramType := buf.String() for _, paramName := range param.Names { if paramName.Name == "_" { log.Fatalf("encountered underscore parameter at position %d in function %s", i, fieldData.Name) } params = append(params, fmt.Sprintf("%s %s", paramName.Name, paramType)) args = append(args, paramName.Name) } } } fieldData.Params = strings.Join(params, ", ") fieldData.Args = strings.Join(args, ", ") fieldData.HasParams = len(params) > 0 // extract return types if funcType.Results != nil && len(funcType.Results.List) > 0 { fieldData.HasReturn = true var returns []string for _, result := range funcType.Results.List { var buf bytes.Buffer printer.Fprint(&buf, fset, result.Type) returns = append(returns, buf.String()) } if len(returns) == 1 { fieldData.ReturnTypes = fmt.Sprintf(" %s", returns[0]) } else { fieldData.ReturnTypes = fmt.Sprintf(" (%s)", strings.Join(returns, ", ")) } } fieldDataList = append(fieldDataList, fieldData) } } // Read the template from file templateContent, err := os.ReadFile(templateFile) if err != nil { log.Fatalf("Failed to read template file: %v", err) } // Generate the code using the template tmpl, err := template.New("multiplexer").Funcs(template.FuncMap{"join": strings.Join}).Parse(string(templateContent)) if err != nil { log.Fatalf("Failed to parse template: %v", err) } var generatedCode bytes.Buffer generatedCode.WriteString("// Code generated by generate_multiplexer.go; DO NOT EDIT.\n\n") if err = tmpl.Execute(&generatedCode, map[string]interface{}{ "Fields": fieldDataList, "StructName": structName, }); err != nil { log.Fatalf("Failed to execute template: %v", err) } // Format the generated code and add imports formattedCode, err := imports.Process(outputFile, generatedCode.Bytes(), nil) if err != nil { log.Fatalf("Failed to process imports: %v", err) } if err := os.WriteFile(outputFile, formattedCode, 0o644); err != nil { log.Fatalf("Failed to write output file: %v", err) } }