package main import ( "flag" "fmt" "go/ast" "go/parser" "go/token" "go/types" "io" "io/ioutil" "log" "os" "path" "reflect" "strconv" "strings" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" "gopkg.in/yaml.v3" ) var swagger = map[string]interface{}{ "openapi": "3.0.0", "info": map[string]string{ "title": "CRM", "description": "CRM", "version": "0.0.1", }, "paths": map[string]interface{}{}, "security": []map[string]interface{}{ { "bearerAuth": []interface{}{}, }, }, "components": map[string]interface{}{ "securitySchemes": map[string]interface{}{ "basicAuth": map[string]interface{}{ "type": "http", "scheme": "basic", }, "bearerAuth": map[string]interface{}{ "type": "http", "scheme": "bearer", }, }, "schemas": map[string]interface{}{}, "examples": map[string]interface{}{ "success": map[string]interface{}{ "summary": "请求成功", "value": map[string]interface{}{ "code": 200, "msg": "success", }, }, }, }, } type pathModel struct { OperationId string Summary string Tags []string Request string } func newPath(m pathModel) interface{} { schemaref := fmt.Sprintf("#/components/schemas/%s", m.Request) examplesref := fmt.Sprintf("#/components/examples/%s", m.Request) return map[string]interface{}{ "post": map[string]interface{}{ "tags": m.Tags, "operationId": m.OperationId, "summary": m.Summary, "requestBody": map[string]interface{}{ "required": true, "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ "oneOf": []map[string]interface{}{ { "$ref": schemaref, }, }, }, "examples": map[string]interface{}{ m.Request: map[string]interface{}{ "$ref": examplesref, }, }, }, }, }, "responses": map[string]interface{}{ "200": map[string]interface{}{ "description": "请求成功", "content": map[string]interface{}{ "application/json": map[string]interface{}{ "examples": map[string]interface{}{ "success": map[string]interface{}{ "$ref": "#/components/examples/success", }, }, }, }, }, }, }, } } func testGoParser() error { sourceFile, err := os.Open("/home/lai/code/working/opms_backend/opms_parent/app/handler/contract/ctr_contract.go") if err != nil { return err } sourceContent, err := io.ReadAll(sourceFile) if err != nil { return err } // Create the AST by parsing src. fset := token.NewFileSet() // positions are relative to fset f, err := parser.ParseFile(fset, "", sourceContent, parser.ParseComments) if err != nil { panic(err) } // Print the imports from the file's AST. for _, s := range f.Imports { debugLog(s.Path.Value) fmt.Printf("%#v\n", s.Path) } debugLog("-----------------------") if Debug { ast.Print(fset, f) } debugLog("-----------------------") for _, c := range f.Comments { for _, l := range c.List { fmt.Printf("%#v%#v\n", l.Text, l.Slash) } } return nil } type SwaggerModel struct { Type string `yaml:"type"` Description string `yaml:"description"` Properties map[string]*SwaggerModel `yaml:"properties,omitempty"` Items *SwaggerModel `yaml:"items,omitempty"` Ref string `yaml:"ref,omitempty"` Embedded bool `yaml:"embedded,omitempty"` Required []string `yaml:"required,omitempty"` } func RangeSwaggerModel(node *SwaggerModel) []*SwaggerModel { if node == nil { return nil } allnode := []*SwaggerModel{} nodes := []*SwaggerModel{node} for len(nodes) != 0 { newnode := []*SwaggerModel{} for _, n := range nodes { allnode = append(allnode, n) if n.Items != nil { newnode = append(newnode, n.Items) } for _, p := range n.Properties { newnode = append(newnode, p) } } nodes = newnode } return allnode } const pkgLoadMode = packages.NeedName | packages.NeedFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo var AllPackages = map[string]*packages.Package{} var Models = map[*ast.Ident]*entityDecl{} var AllStructModel = map[string]*SwaggerModel{} var AllSwaggerModel = map[string]*SwaggerModel{} var SpecifiedSwaggerModel = []map[string]string{} var Routes = map[string]pathModel{} type entityDecl struct { Type *types.Named File *ast.File Pkg *packages.Package } var outfile string var modname string func init() { flag.StringVar(&outfile, "o", "swagger.yml", "output file") flag.StringVar(&modname, "m", "dashoo.cn/micro", "mod name") } func main() { // testGoParser() // return flag.Parse() toload := []string{ // "dashoo.cn/micro/app/model/contract", // "dashoo.cn/micro/app/handler/contract", } for _, d := range []string{"app/model", "app/handler"} { files, err := ioutil.ReadDir(d) if err != nil { if os.IsNotExist(err) { fmt.Printf("%s not found\n", d) return } panic(err) } for _, f := range files { if f.IsDir() { toload = append(toload, fmt.Sprintf("%s/%s/%s", modname, d, f.Name())) } info, err := os.Stat(path.Join(d, f.Name(), "internal")) if os.IsNotExist(err) { continue } if f.Name() == ".DS_Store" { continue } if info.IsDir() { toload = append(toload, fmt.Sprintf("%s/%s/%s/internal", modname, d, f.Name())) } } } for _, p := range toload { fmt.Println(p) } cfg := &packages.Config{Mode: pkgLoadMode} pkgs, err := packages.Load(cfg, toload...) if err != nil { fmt.Fprintf(os.Stderr, "load: %v\n", err) os.Exit(1) } if packages.PrintErrors(pkgs) > 0 { os.Exit(1) } for _, pkg := range pkgs { AllPackages[pkg.PkgPath] = pkg for _, file := range pkg.Syntax { for _, dt := range file.Decls { switch fd := dt.(type) { case *ast.GenDecl: for _, sp := range fd.Specs { switch ts := sp.(type) { case *ast.TypeSpec: def, ok := pkg.TypesInfo.Defs[ts.Name] if !ok { debugLog("couldn't find type info for %s", ts.Name) continue } nt, isNamed := def.Type().(*types.Named) if !isNamed { debugLog("%s is not a named type but a %T", ts.Name, def.Type()) continue } key := ts.Name Models[key] = &entityDecl{ Type: nt, File: file, Pkg: pkg, } } } } } } } // Print the names of the source files // for each package listed on the command line. for _, pkg := range pkgs { debugLog("-----------%s", pkg.Name) for _, f := range pkg.Syntax { detectModels(pkg, f) } } for n := range AllStructModel { m := AllStructModel[n] if m == nil { continue } // fmt.Println(m.Ref, AllStructModel[m.Ref]) for _, subm := range RangeSwaggerModel(m) { // 属性引用自己的循环引用,两个不同 struct 的循环引用这里没有检测 if n == subm.Ref { continue } if subm.Ref != "" && AllStructModel[subm.Ref] != nil { subm.Items = AllStructModel[subm.Ref].Items subm.Properties = AllStructModel[subm.Ref].Properties } } // b, err := yaml.Marshal(m) // if err != nil { // panic(err) // } // fmt.Println(n) // fmt.Println(string(b)) } // 展开嵌入字段 for n := range AllStructModel { m := AllStructModel[n] if m == nil { continue } for _, subm := range RangeSwaggerModel(m) { for pn, p := range subm.Properties { if !p.Embedded { continue } // fmt.Println("--", n, "|", pn) delete(subm.Properties, pn) for ssn, ssp := range p.Properties { if _, ok := subm.Properties[ssn]; ok { continue } subm.Properties[ssn] = ssp } } } } genconfig(outfile) } func genconfig(filename string) { f, err := os.Create(filename) if err != nil { panic(err) } defer f.Close() debugLog("genconfig AllStructModel") for n := range AllStructModel { m := AllStructModel[n] if m == nil { continue } namelist := strings.Split(n, ".") name := namelist[len(namelist)-1] swagger["components"].(map[string]interface{})["schemas"].(map[string]interface{})[name] = m swagger["components"].(map[string]interface{})["examples"].(map[string]interface{})[name] = map[string]interface{}{ "value": map[string]interface{}{ "placeholder": "", }, } } debugLog("genconfig Routes") for p, m := range Routes { swagger["paths"].(map[string]interface{})[p] = newPath(m) } encoder := yaml.NewEncoder(f) encoder.SetIndent(2) debugLog("genconfig Encode") err = encoder.Encode(swagger) if err != nil { panic(err) } } func getFuncRecv(fd *ast.FuncDecl) string { defer func() { if r := recover(); r != nil { debugLog("%s", r) } }() return fd.Recv.List[0].Type.(*ast.StarExpr).X.(*ast.Ident).Name } func getFuncParams(fd *ast.FuncDecl) string { defer func() { if r := recover(); r != nil { debugLog("%s", r) } }() return fd.Type.Params.List[1].Type.(*ast.StarExpr).X.(*ast.SelectorExpr).Sel.Name } func detectPath(fd *ast.FuncDecl) { recv := getFuncRecv(fd) fname := fd.Name.String() req := getFuncParams(fd) debugLog("//////////%v %s %s %s %s", fd, fname, recv, req, fd.Doc.Text()) commentText := fd.Doc.Text() for _, c := range strings.Split(commentText, "\n") { c = strings.TrimSpace(c) if !strings.HasPrefix(c, "Swagger:") { continue } names := strings.Split(c, " ") tags := []string{} summary := "" if len(names) == 3 { tags = strings.Split(names[1], ",") summary = names[2] recvNamed := strings.Split(names[0], ":") if len(recvNamed) == 2 { recv = recvNamed[1] } } path := fmt.Sprintf("/%s.%s", recv, fname) operationId := recv + fname Routes[path] = pathModel{ OperationId: operationId, Summary: summary, Tags: tags, Request: req, } break } } func detectModels(pkg *packages.Package, file *ast.File) { for _, dt := range file.Decls { switch fd := dt.(type) { case *ast.BadDecl: continue case *ast.FuncDecl: detectPath(fd) case *ast.GenDecl: processDecl(pkg, file, fd) } } } func processDecl(pkg *packages.Package, file *ast.File, gd *ast.GenDecl) error { for _, sp := range gd.Specs { switch ts := sp.(type) { case *ast.ValueSpec: return nil case *ast.ImportSpec: return nil case *ast.TypeSpec: def, ok := pkg.TypesInfo.Defs[ts.Name] if !ok { debugLog("couldn't find type info for %s", ts.Name) continue } nt, isNamed := def.Type().(*types.Named) if !isNamed { debugLog("%s is not a named type but a %T", ts.Name, def.Type()) continue } comments := ts.Doc // type ( /* doc */ Foo struct{} ) if comments == nil { comments = gd.Doc // /* doc */ type ( Foo struct{} ) } // decl := &entityDecl{ // Comments: comments, // Type: nt, // Ident: ts.Name, // Spec: ts, // File: file, // Pkg: pkg, // } // key := ts.Name // if n&modelNode != 0 && decl.HasModelAnnotation() { // a.Models[key] = decl // } // if n¶metersNode != 0 && decl.HasParameterAnnotation() { // a.Parameters = append(a.Parameters, decl) // } // if n&responseNode != 0 && decl.HasResponseAnnotation() { // a.Responses = append(a.Responses, decl) // } switch tpe := nt.Obj().Type().(type) { case *types.Struct: st := tpe for i := 0; i < st.NumFields(); i++ { fld := st.Field(i) if !fld.Anonymous() { debugLog("skipping field %q for allOf scan because not anonymous\n", fld.Name()) continue } tg := st.Tag(i) debugLog(fld.Name(), tg) } case *types.Slice: debugLog("Slice") case *types.Basic: debugLog("Basic") case *types.Interface: debugLog("Interface") case *types.Array: debugLog("Array") case *types.Map: debugLog("Map") case *types.Named: debugLog("Named") o := tpe.Obj() if o != nil { debugLog("got the named type object: %s.%s | isAlias: %t | exported: %t //////////// %s", o.Pkg().Path(), o.Name(), o.IsAlias(), o.Exported(), comments.Text()) if o.Pkg().Name() == "time" && o.Name() == "Time" { // schema.Typed("string", "date-time") return nil } for { ti := pkg.TypesInfo.Types[ts.Type] if ti.IsBuiltin() { break } if ti.IsType() { err, ret := buildFromType(ti.Type, file, comments.Text()) if err != nil { return err } AllStructModel[fmt.Sprintf("%s.%s", o.Pkg().Path(), o.Name())] = ret break } } } } } } return nil } func buildFromType(tpe types.Type, declFile *ast.File, desc string) (error, *SwaggerModel) { switch titpe := tpe.(type) { case *types.Basic: return nil, &SwaggerModel{ Type: swaggerSchemaForType(titpe.String()), Description: desc, } case *types.Pointer: return buildFromType(titpe.Elem(), declFile, desc) case *types.Struct: return buildFromStruct(declFile, titpe, desc) case *types.Slice: err, ret := buildFromType(titpe.Elem(), declFile, desc) if err != nil { return err, nil } return nil, &SwaggerModel{ Type: "array", Items: ret, Description: desc, } case *types.Named: tio := titpe.Obj() debugLog("%s", tio) if tio.Pkg() == nil && tio.Name() == "error" { return nil, &SwaggerModel{ Type: swaggerSchemaForType(tio.Name()), Description: desc, } } if tpe.String() == "github.com/gogf/gf/os/gtime.Time" { return nil, &SwaggerModel{ Type: "string", Description: desc + "DATETIME", } } debugLog("named refined type %s.%s", tio.Pkg().Path(), tio.Name()) pkg, found := PkgForType(tpe) if !found { // this must be a builtin debugLog("skipping because package is nil: %s", tpe.String()) return nil, nil } if pkg.Name == "time" && tio.Name() == "Time" { return nil, &SwaggerModel{ Type: "string", Description: desc, } } switch utitpe := tpe.Underlying().(type) { case *types.Struct: if decl, ok := FindModel(tio.Pkg().Path(), tio.Name()); ok { if decl.Type.Obj().Pkg().Path() == "time" && decl.Type.Obj().Name() == "Time" { return nil, &SwaggerModel{ Type: "string", Description: desc, } } debugLog("!!!!!!!%s", decl.Type.String()) return nil, &SwaggerModel{ Type: "object", Ref: decl.Type.String(), } } case *types.Slice: debugLog("!!!!!!!slice %s", utitpe.Elem()) // decl, ok := FindModel(tio.Pkg().Path(), tio.Name()) // return buildFromType(utitpe.Elem(), tgt.Items()) } case *types.Interface: return nil, &SwaggerModel{ Type: "object", Description: desc, } case *types.Map: return nil, &SwaggerModel{ Type: "object", Description: desc, } default: panic(fmt.Sprintf("WARNING: can't determine refined type %s (%T)", titpe.String(), titpe)) } return nil, nil } func buildFromStruct(declFile *ast.File, st *types.Struct, desc string) (error, *SwaggerModel) { // for i := 0; i < st.NumFields(); i++ { // fld := st.Field(i) // if !fld.Anonymous() { // debugLog("skipping field %q for allOf scan because not anonymous", fld.Name()) // continue // } // tg := st.Tag(i) // debugLog("maybe allof field(%t) %s: %s (%T) [%q](anon: %t, embedded: %t)", fld.IsField(), fld.Name(), fld.Type().String(), fld.Type(), tg, fld.Anonymous(), fld.Embedded()) // var afld *ast.Field // ans, _ := astutil.PathEnclosingInterval(declFile, fld.Pos(), fld.Pos()) // // debugLog("got %d nodes (exact: %t)", len(ans), isExact) // for _, an := range ans { // at, valid := an.(*ast.Field) // if !valid { // continue // } // debugLog("maybe allof field %s: %s(%T) [%q]", fld.Name(), fld.Type().String(), fld.Type(), tg) // afld = at // break // } // if afld == nil { // debugLog("can't find source associated with %s for %s", fld.String(), st.String()) // continue // } // _, ignore, _, err := parseJSONTag(afld) // if err != nil { // return err // } // if ignore { // continue // } // } required := []string{} properties := map[string]*SwaggerModel{} for i := 0; i < st.NumFields(); i++ { fld := st.Field(i) tg := st.Tag(i) if fld.Embedded() { if fld.Name() == "PageReq" { properties["beginTime"] = &SwaggerModel{Type: "string", Description: "开始时间"} properties["endTime"] = &SwaggerModel{Type: "string", Description: "结束时间"} properties["pageNum"] = &SwaggerModel{Type: "int", Description: "当前页码"} properties["pageSize"] = &SwaggerModel{Type: "int", Description: "每页数"} properties["orderBy"] = &SwaggerModel{Type: "string", Description: "排序方式"} } debugLog("Embedded %s, %s", fld.Name(), fld.Pkg()) // continue } if !fld.Exported() { debugLog("skipping field %s because it's not exported", fld.Name()) continue } var afld *ast.Field ans, _ := astutil.PathEnclosingInterval(declFile, fld.Pos(), fld.Pos()) // debugLog("got %d nodes (exact: %t)", len(ans), isExact) for _, an := range ans { at, valid := an.(*ast.Field) if !valid { continue } debugLog("field %s: %s(%T) [%q] ==> %s", fld.Name(), fld.Type().String(), fld.Type(), tg, at.Doc.Text()) afld = at break } if afld == nil { debugLog("can't find source associated with %s", fld.String()) continue } name, _, isString, err := parseJSONTag(afld) if err != nil { return err, nil } if name == "" { name = fld.Name() } commentText := "" if afld.Comment != nil { commentText = afld.Comment.Text() } // if strings.Contains(commentText, "required") { // required = append(required, name) // } if afld.Tag != nil { tagstr := strings.Split(afld.Tag.Value, "v:") if len(tagstr) > 1 { if strings.Contains(tagstr[1], "required") { required = append(required, name) } } // debugLog("------------------------------%s", afld.Tag.Value) } err, ret := buildFromType(fld.Type(), declFile, commentText) if err != nil { return err, nil } // if fld.Embedded() { // fmt.Println(fld.Name(), ret) // } if ret == nil { continue } ret.Embedded = fld.Embedded() debugLog("****** %s %v %s", fld.Type(), ret, commentText) if isString { return nil, &SwaggerModel{ Type: "string", Description: desc, } } properties[name] = ret } return nil, &SwaggerModel{ Type: "object", Properties: properties, Required: required, Description: desc, } } func swaggerSchemaForType(typeName string) string { switch typeName { case "bool": return "boolean" case "byte": return "integer" case "float32": return "number" case "float64": return "number" case "int": return "integer" case "int16": return "integer" case "int32": return "integer" case "int64": return "integer" case "int8": return "integer" case "rune": return "integer" case "string": return "string" case "uint": return "integer" case "uint16": return "integer" case "uint32": return "integer" case "uint64": return "integer" case "uint8": return "integer" } return "" } func FindModel(pkgPath, name string) (*entityDecl, bool) { for _, cand := range Models { ct := cand.Type.Obj() if ct.Name() == name && ct.Pkg().Path() == pkgPath { return cand, true } } return nil, false } func PkgForType(t types.Type) (*packages.Package, bool) { switch tpe := t.(type) { // case *types.Basic: // case *types.Struct: // case *types.Pointer: // case *types.Interface: // case *types.Array: // case *types.Slice: // case *types.Map: case *types.Named: v, ok := AllPackages[tpe.Obj().Pkg().Path()] return v, ok default: log.Printf("unknown type to find the package for [%T]: %s", t, t.String()) return nil, false } } func parseJSONTag(field *ast.Field) (name string, ignore bool, isString bool, err error) { if len(field.Names) > 0 { name = field.Names[0].Name } if field.Tag == nil || len(strings.TrimSpace(field.Tag.Value)) == 0 { return name, false, false, nil } tv, err := strconv.Unquote(field.Tag.Value) if err != nil { return name, false, false, err } if strings.TrimSpace(tv) != "" { st := reflect.StructTag(tv) jsonParts := tagOptions(strings.Split(st.Get("json"), ",")) if jsonParts.Contain("string") { // Need to check if the field type is a scalar. Otherwise, the // ",string" directive doesn't apply. isString = isFieldStringable(field.Type) } switch jsonParts.Name() { case "-": return name, true, isString, nil case "": return name, false, isString, nil default: return jsonParts.Name(), false, isString, nil } } return name, false, false, nil } func isFieldStringable(tpe ast.Expr) bool { if ident, ok := tpe.(*ast.Ident); ok { switch ident.Name { case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float64", "string", "bool": return true } } else if starExpr, ok := tpe.(*ast.StarExpr); ok { return isFieldStringable(starExpr.X) } else { return false } return false } type tagOptions []string func (t tagOptions) Contain(option string) bool { for i := 1; i < len(t); i++ { if t[i] == option { return true } } return false } func (t tagOptions) Name() string { return t[0] } var Debug = false func debugLog(format string, args ...interface{}) { if Debug { log.Printf(format, args...) } }