Przeglądaj źródła

feature: swagger 支持自动导入、嵌入字段

lai 3 lat temu
rodzic
commit
fb364f80f3
1 zmienionych plików z 84 dodań i 2 usunięć
  1. 84 2
      opms_parent/swaggerui/swagger.go

+ 84 - 2
opms_parent/swaggerui/swagger.go

@@ -8,8 +8,10 @@ import (
 	"go/token"
 	"go/types"
 	"io"
+	"io/ioutil"
 	"log"
 	"os"
+	"path"
 	"reflect"
 	"strconv"
 	"strings"
@@ -149,6 +151,7 @@ type SwaggerModel struct {
 	Properties  map[string]*SwaggerModel `yaml:"properties,omitempty"`
 	Items       *SwaggerModel            `yaml:"items,omitempty"`
 	Ref         string                   `yaml:"ref,omitempty"`
+	Embedded    bool                     `yaml:"embedded,omitempty"`
 }
 
 func RangeSwaggerModel(node *SwaggerModel) []*SwaggerModel {
@@ -189,16 +192,50 @@ type entityDecl struct {
 }
 
 var outfile string
+var modname string
 
 func init() {
 	flag.StringVar(&outfile, "o", "swagger.yml", "output file")
+	flag.StringVar(&modname, "m", "dashoo.cn/micro", "mod name")
 }
 
 func main() {
 	// testGoParser()
 	// return
 	flag.Parse()
-	toload := []string{"dashoo.cn/micro/app/model/contract", "dashoo.cn/micro/app/handler/contract"}
+	toload := []string{
+		// "dashoo.cn/micro/app/model/contract",
+		// "dashoo.cn/micro/app/handler/contract",
+	}
+
+	for _, d := range []string{"app/model", "app/handler"} {
+		files, err := ioutil.ReadDir(d)
+		if err != nil {
+
+			if os.IsNotExist(err) {
+				fmt.Printf("%s not found\n", d)
+				return
+			}
+			panic(err)
+		}
+		for _, f := range files {
+			if f.IsDir() {
+				toload = append(toload, fmt.Sprintf("%s/%s/%s", modname, d, f.Name()))
+			}
+			info, err := os.Stat(path.Join(d, f.Name(), "internal"))
+			if os.IsNotExist(err) {
+				continue
+			}
+			if info.IsDir() {
+				toload = append(toload, fmt.Sprintf("%s/%s/%s/internal", modname, d, f.Name()))
+			}
+
+		}
+	}
+	for _, p := range toload {
+		fmt.Println(p)
+	}
+
 	cfg := &packages.Config{Mode: pkgLoadMode}
 	pkgs, err := packages.Load(cfg, toload...)
 	if err != nil {
@@ -257,6 +294,10 @@ func main() {
 		}
 		// fmt.Println(m.Ref, AllStructModel[m.Ref])
 		for _, subm := range RangeSwaggerModel(m) {
+			// 属性引用自己的循环引用,两个不同 struct 的循环引用这里没有检测
+			if n == subm.Ref {
+				continue
+			}
 			if subm.Ref != "" && AllStructModel[subm.Ref] != nil {
 				subm.Items = AllStructModel[subm.Ref].Items
 				subm.Properties = AllStructModel[subm.Ref].Properties
@@ -270,6 +311,29 @@ func main() {
 		// fmt.Println(n)
 		// fmt.Println(string(b))
 	}
+
+	// 展开嵌入字段
+	for n := range AllStructModel {
+		m := AllStructModel[n]
+		if m == nil {
+			continue
+		}
+		for _, subm := range RangeSwaggerModel(m) {
+			for pn, p := range subm.Properties {
+				if !p.Embedded {
+					continue
+				}
+				// fmt.Println("--", n, "|", pn)
+				delete(subm.Properties, pn)
+				for ssn, ssp := range p.Properties {
+					if _, ok := subm.Properties[ssn]; ok {
+						continue
+					}
+					subm.Properties[ssn] = ssp
+				}
+			}
+		}
+	}
 	genconfig(outfile)
 }
 
@@ -280,6 +344,7 @@ func genconfig(filename string) {
 	}
 	defer f.Close()
 
+	debugLog("genconfig AllStructModel")
 	for n := range AllStructModel {
 		m := AllStructModel[n]
 		if m == nil {
@@ -296,12 +361,14 @@ func genconfig(filename string) {
 		}
 	}
 
+	debugLog("genconfig Routes")
 	for p, m := range Routes {
 		swagger["paths"].(map[string]interface{})[p] = newPath(m)
 	}
 
 	encoder := yaml.NewEncoder(f)
 	encoder.SetIndent(2)
+	debugLog("genconfig Encode")
 	err = encoder.Encode(swagger)
 	if err != nil {
 		panic(err)
@@ -541,6 +608,11 @@ func buildFromType(tpe types.Type, declFile *ast.File, desc string) (error, *Swa
 			// return buildFromType(utitpe.Elem(), tgt.Items())
 
 		}
+	case *types.Interface:
+		return nil, &SwaggerModel{
+			Type:        "object",
+			Description: desc,
+		}
 
 	default:
 		panic(fmt.Sprintf("WARNING: can't determine refined type %s (%T)", titpe.String(), titpe))
@@ -601,7 +673,7 @@ func buildFromStruct(declFile *ast.File, st *types.Struct, desc string) (error,
 				properties["orderBy"] = &SwaggerModel{Type: "string", Description: "排序方式"}
 			}
 			debugLog("Embedded %s, %s", fld.Name(), fld.Pkg())
-			continue
+			// continue
 		}
 
 		if !fld.Exported() {
@@ -632,6 +704,9 @@ func buildFromStruct(declFile *ast.File, st *types.Struct, desc string) (error,
 		if err != nil {
 			return err, nil
 		}
+		if name == "" {
+			name = fld.Name()
+		}
 		commentText := ""
 		if afld.Comment != nil {
 			commentText = afld.Comment.Text()
@@ -640,6 +715,13 @@ func buildFromStruct(declFile *ast.File, st *types.Struct, desc string) (error,
 		if err != nil {
 			return err, nil
 		}
+		// if fld.Embedded() {
+		// 	fmt.Println(fld.Name(), ret)
+		// }
+		if ret == nil {
+			continue
+		}
+		ret.Embedded = fld.Embedded()
 		debugLog("****** %s %v %s", fld.Type(), ret, commentText)
 		if isString {
 			return nil, &SwaggerModel{