package main import ( "flag" "fmt" "io/ioutil" "os" "os/exec" "path/filepath" "regexp" "strings" "github.com/zeromicro/ddl-parser/parser" "golang.org/x/text/cases" "golang.org/x/text/language" ) var ddlDir = "ddl" var genDir = "model/gmodel_gen" func toPascalCase(s string) string { words := strings.Split(s, "_") for i, word := range words { words[i] = cases.Title(language.English).String(strings.ToLower(word)) } return strings.Join(words, "") } func main() { var name string // 需要序列化的单独文件名 var gdir string // 需要修改的序列化路径 model var ddir string // 需要修改的序列化路径 ddl flag.StringVar(&name, "name", "", "输入需要序列化的ddl文件名, 不需要后缀.ddl") flag.StringVar(&gdir, "mdir", "", "输入需要生成model的Go文件所在目录") flag.StringVar(&ddir, "ddir", "", "输入需要生成ddl的Go文件所在目录") flag.Parse() if gdir != "" { genDir = gdir } if ddir != "" { ddlDir = ddir } if name != "" { name = fmt.Sprintf("%s/%s.sql", ddlDir, name) GenFromPath(name) } else { matches, err := filepath.Glob(fmt.Sprintf("%s/*.sql", ddlDir)) if err != nil { panic(err) } for _, pth := range matches { GenFromPath(pth) } } } func GenFromPath(pth string) { p, err := filepath.Abs(pth) if err != nil { panic(err) } ddlfilestr, err := ioutil.ReadFile(pth) if err != nil { panic(err) } // PRIMARY KEY (`guest_id`) USING BTREE re := regexp.MustCompile("PRIMARY\\s+KEY\\s+\\(\\s*`([^`]+)`\\s*\\)|`([^`]+)` [^\n]+PRIMARY\\s+KEY\\s+") matches := re.FindStringSubmatch(string(ddlfilestr)) PrimaryStr := "" if len(matches) > 0 { PrimaryStr = matches[1] } var importstr = "import (\"gorm.io/gorm\"\n" // 匹配到主键定义 parser.NewParser() result, err := parser.NewParser().From(p) if err != nil { panic(err) } fcontent := "package model\n" for _, table := range result { structstr := "type %s struct {%s\n}\n" tableName := toPascalCase(table.Name) fieldstr := "" for _, col := range table.Columns { fieldName := toPascalCase(col.Name) typeName := SQLTypeToGoTypeMap[col.DataType.Type()] if typeName == "*time.Time" { importstr += "\"time\"\n" } tagstr := "`gorm:" if col.Name == PrimaryStr { tagstr += "\"primary_key\"" typeName = typeName[1:] } else { tagstr += "\"\"" } tagstr += fmt.Sprintf(" json:\"%s\"`", col.Name) fieldColStr := fmt.Sprintf("\n%s %s %s// %s", fieldName, typeName, tagstr, col.Constraint.Comment) fieldstr += fieldColStr } fcontent += importstr + ")\n" fcontent += fmt.Sprintf(structstr, tableName, fieldstr) modelstr := fmt.Sprintf(`type %sModel struct {db *gorm.DB}`, tableName) fcontent += modelstr fcontent += "\n" newfuncstr := fmt.Sprintf(`func New%sModel(db *gorm.DB) *%sModel {return &%sModel{db}}`, tableName, tableName, tableName) fcontent += newfuncstr fcontent += "\n" genGoFileName := fmt.Sprintf("%s/%s_gen.go", genDir, table.Name) f, err := os.OpenFile(genGoFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { panic(err) } f.WriteString(fcontent) err = f.Close() if err != nil { panic(err) } err = exec.Command("gofmt", "-w", genGoFileName).Run() if err != nil { panic(err) } fcontent = "package model\n// TODO: 使用model的属性做你想做的" genGoLogicFileName := fmt.Sprintf("%s/%s_logic.go", genDir, table.Name) // 使用 os.Stat 函数获取文件信息 _, err = os.Stat(genGoLogicFileName) // 判断文件是否存在并输出结果 if os.IsNotExist(err) { f2, err := os.OpenFile(genGoLogicFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { panic(err) } f2.WriteString(fcontent) err = f2.Close() if err != nil { panic(err) } fmt.Println(genGoLogicFileName, "create!") } else { fmt.Println(genGoLogicFileName, "exists") } } } const ( _ int = iota LongVarBinary LongVarChar GeometryCollection GeomCollection LineString MultiLineString MultiPoint MultiPolygon Point Polygon Json Geometry Enum Set Bit Time Timestamp DateTime Binary VarBinary Blob Year Decimal Dec Fixed Numeric Float Float4 Float8 Double Real TinyInt SmallInt MediumInt Int Integer BigInt MiddleInt Int1 Int2 Int3 Int4 Int8 Date TinyBlob MediumBlob LongBlob Bool Boolean Serial NVarChar NChar Char Character VarChar TinyText Text MediumText LongText ) var SQLTypeToGoTypeMap = map[int]string{ LongVarBinary: "[]byte", Binary: "[]byte", VarBinary: "[]byte", Blob: "[]byte", TinyBlob: "[]byte", MediumBlob: "[]byte", LongBlob: "[]byte", LongVarChar: "*string", NVarChar: "*string", NChar: "*string", Char: "*string", Character: "*string", VarChar: "*string", TinyText: "*string", Text: "*string", MediumText: "*string", LongText: "*string", Time: "*time.Time", Timestamp: "*time.Time", DateTime: "*time.Time", Date: "*time.Time", Year: "*int64", TinyInt: "*int64", SmallInt: "*int64", MediumInt: "*int64", Int: "*int64", Integer: "*int64", BigInt: "*int64", MiddleInt: "*int64", Int1: "*int64", Int2: "*int64", Int3: "*int64", Int4: "*int64", Int8: "*int64", Serial: "*int64", Decimal: "*float64", Dec: "*float64", Fixed: "*float64", Numeric: "*float64", Float: "*float64", Float4: "*float64", Float8: "*float64", Double: "*float64", Real: "*float64", Bool: "*bool", Boolean: "*bool", }