fusenapi/generator/main.go

284 lines
5.5 KiB
Go

package main
import (
"flag"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"github.com/zeromicro/ddl-parser/parser"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// Default input/output locations, relative to the working directory.
// main replaces them when the -ddir / -mdir flags are given.
var (
	ddlDir = "ddl"              // directory scanned for *.sql DDL files
	genDir = "model/gmodel_gen" // directory that receives the generated Go files
)
// toPascalCase converts a snake_case identifier such as "guest_id" into
// PascalCase ("GuestId"). Each underscore-separated word is lower-cased and
// then title-cased with English rules; empty words (from leading, trailing or
// doubled underscores) contribute nothing to the result.
func toPascalCase(s string) string {
	// Build the caser once per call rather than once per word. A Caser may be
	// stateful, but Caser.String resets it before each use, so sequential reuse
	// within this call is safe.
	title := cases.Title(language.English)
	words := strings.Split(s, "_")
	for i, word := range words {
		words[i] = title.String(strings.ToLower(word))
	}
	return strings.Join(words, "")
}
// main parses the command-line flags and generates models either for the
// single DDL file selected with -name or for every *.sql file in ddlDir.
//
// Flags:
//
//	-name  DDL file name inside ddlDir, without the .sql suffix
//	       (a trailing .sql is tolerated)
//	-mdir  overrides genDir, the output directory for generated model files
//	-ddir  overrides ddlDir, the directory containing the .sql DDL files
func main() {
	var name string // single DDL file to process, without the .sql suffix
	var gdir string // replacement for genDir (model output directory)
	var ddir string // replacement for ddlDir (DDL input directory)
	flag.StringVar(&name, "name", "", "输入需要序列化的ddl文件名, 不需要后缀.sql")
	flag.StringVar(&gdir, "mdir", "", "输入需要生成model的Go文件所在目录")
	flag.StringVar(&ddir, "ddir", "", "输入ddl(.sql)文件所在目录")
	flag.Parse()

	if gdir != "" {
		genDir = gdir
	}
	if ddir != "" {
		ddlDir = ddir
	}

	if name != "" {
		// Accept both "-name foo" and "-name foo.sql" instead of looking for foo.sql.sql.
		GenFromPath(filepath.Join(ddlDir, strings.TrimSuffix(name, ".sql")+".sql"))
		return
	}

	matches, err := filepath.Glob(filepath.Join(ddlDir, "*.sql"))
	if err != nil {
		// Glob only fails on a malformed pattern (e.g. ddlDir containing '[').
		panic(err)
	}
	if len(matches) == 0 {
		fmt.Println("no .sql files found in", ddlDir)
	}
	for _, pth := range matches {
		GenFromPath(pth)
	}
}
// primaryKeyRe locates a single-column primary key in a DDL script.
// The first alternative matches a table-level constraint such as
//
//	PRIMARY KEY (`guest_id`) USING BTREE
//
// and captures the column name in group 1. The second matches an inline
// column definition such as
//
//	`guest_id` bigint NOT NULL PRIMARY KEY
//
// and captures the column name in group 2. Only backtick-quoted identifiers
// are recognized, and composite keys are not matched.
var primaryKeyRe = regexp.MustCompile("PRIMARY\\s+KEY\\s*\\(\\s*`([^`]+)`\\s*\\)|`([^`]+)` [^\n]+PRIMARY\\s+KEY\\b")

// detectPrimaryKey returns the primary-key column name found in ddl by
// primaryKeyRe, or "" if there is none. The name comes from whichever
// alternative matched.
func detectPrimaryKey(ddl string) string {
	m := primaryKeyRe.FindStringSubmatch(ddl)
	switch {
	case m == nil:
		return ""
	case m[1] != "":
		return m[1]
	default:
		return m[2]
	}
}

// writeGoFile creates or truncates name and writes content to it, panicking
// on any open, write or close error.
func writeGoFile(name, content string) {
	f, err := os.OpenFile(name, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
	if err != nil {
		panic(err)
	}
	if _, err := f.WriteString(content); err != nil {
		f.Close()
		panic(err)
	}
	if err := f.Close(); err != nil {
		panic(err)
	}
}

// createLogicFileIfMissing writes a placeholder companion file at name unless
// one already exists. An existing file is never touched, so hand-written code
// in it survives regeneration.
func createLogicFileIfMissing(name string) {
	_, err := os.Stat(name)
	if err == nil {
		fmt.Println(name, "exists")
		return
	}
	if !os.IsNotExist(err) {
		// e.g. permission denied: don't report the file as existing.
		panic(err)
	}
	// Placeholder body; the TODO reads "use the model's fields to do what you want".
	writeGoFile(name, "package model\n// TODO: 使用model的属性做你想做的")
	fmt.Println(name, "create!")
}

// GenFromPath parses the DDL file at pth and handles every table it defines:
//
//   - writes <genDir>/<table>_gen.go, containing the struct and its model
//     constructor. This file is rewritten on every run and then formatted
//     with gofmt.
//   - creates <genDir>/<table>_logic.go, but only if it does not exist yet.
//
// Any failure panics.
//
// NOTE: the primary key is detected once per file with a regular expression.
// In a file with several tables, the first key found is applied to every
// table that has a column of that name.
func GenFromPath(pth string) {
	absPath, err := filepath.Abs(pth)
	if err != nil {
		panic(err)
	}
	ddl, err := ioutil.ReadFile(pth)
	if err != nil {
		panic(err)
	}
	primaryKey := detectPrimaryKey(string(ddl))

	tables, err := parser.NewParser().From(absPath)
	if err != nil {
		panic(err)
	}

	for _, table := range tables {
		tableName := toPascalCase(table.Name)

		// Build the struct fields and record whether the "time" import is
		// needed. Imports are computed per table and "time" is emitted at most
		// once; a duplicate import would not compile.
		var fields strings.Builder
		needTime := false
		for _, col := range table.Columns {
			typeName, ok := SQLTypeToGoTypeMap[col.DataType.Type()]
			if !ok {
				// Without a type, the field would be emitted as a bogus
				// embedded field and the generated code would not compile.
				panic(fmt.Sprintf("%s: table %s, column %s: unsupported SQL data type %v",
					pth, table.Name, col.Name, col.DataType.Type()))
			}
			if typeName == "*time.Time" {
				needTime = true
			}
			gormTag := `""`
			if col.Name == primaryKey {
				gormTag = `"primary_key"`
				// The key column is a plain value rather than a pointer.
				typeName = strings.TrimPrefix(typeName, "*")
			}
			fmt.Fprintf(&fields, "\n%s %s `gorm:%s json:\"%s\"`// %s",
				toPascalCase(col.Name), typeName, gormTag, col.Name, col.Constraint.Comment)
		}

		imports := "import (\"gorm.io/gorm\"\n"
		if needTime {
			imports += "\"time\"\n"
		}
		imports += ")\n"

		// Each table's file content is built from scratch; nothing carries
		// over from the previous iteration.
		content := "package model\n" + imports +
			fmt.Sprintf("type %s struct {%s\n}\n", tableName, fields.String()) +
			fmt.Sprintf("type %sModel struct {db *gorm.DB}\n", tableName) +
			fmt.Sprintf("func New%sModel(db *gorm.DB) *%sModel {return &%sModel{db}}\n", tableName, tableName, tableName)

		genGoFileName := filepath.Join(genDir, table.Name+"_gen.go")
		writeGoFile(genGoFileName, content)
		// CombinedOutput keeps gofmt's diagnostics so a syntax error is explained.
		if out, err := exec.Command("gofmt", "-w", genGoFileName).CombinedOutput(); err != nil {
			panic(fmt.Sprintf("gofmt %s: %v\n%s", genGoFileName, err, out))
		}

		createLogicFileIfMissing(filepath.Join(genDir, table.Name+"_logic.go"))
	}
}
// Data-type codes for SQL column types. GenFromPath looks up
// SQLTypeToGoTypeMap with col.DataType.Type(), so these values have to agree
// with the codes the DDL parser returns for each column type.
//
// The values come from iota in declaration order. The blank identifier takes
// 0, so LongVarBinary is 1. Inserting, removing or reordering an entry
// renumbers everything after it.
// NOTE(review): this list presumably mirrors the type constants of
// github.com/zeromicro/ddl-parser. Confirm against the version in go.mod
// before editing.
const (
_ int = iota // reserve 0 so that no named type code is the zero value
LongVarBinary
LongVarChar
GeometryCollection
GeomCollection
LineString
MultiLineString
MultiPoint
MultiPolygon
Point
Polygon
Json
Geometry
Enum
Set
Bit
Time
Timestamp
DateTime
Binary
VarBinary
Blob
Year
Decimal
Dec
Fixed
Numeric
Float
Float4
Float8
Double
Real
TinyInt
SmallInt
MediumInt
Int
Integer
BigInt
MiddleInt
Int1
Int2
Int3
Int4
Int8
Date
TinyBlob
MediumBlob
LongBlob
Bool
Boolean
Serial
NVarChar
NChar
Char
Character
VarChar
TinyText
Text
MediumText
LongText
)
// SQLTypeToGoTypeMap maps a data-type code (the value of the parser's
// DataType.Type()) to the Go type emitted for the generated struct field.
// Scalar types are pointers so that a field can hold SQL NULL. Binary data
// uses []byte, which is already nil-able.
// Codes without an entry (e.g. Bit and the geometry types) have no mapping.
var SQLTypeToGoTypeMap = map[int]string{
	// Binary data.
	LongVarBinary: "[]byte",
	Binary:        "[]byte",
	VarBinary:     "[]byte",
	Blob:          "[]byte",
	TinyBlob:      "[]byte",
	MediumBlob:    "[]byte",
	LongBlob:      "[]byte",

	// Character data.
	LongVarChar: "*string",
	NVarChar:    "*string",
	NChar:       "*string",
	Char:        "*string",
	Character:   "*string",
	VarChar:     "*string",
	TinyText:    "*string",
	Text:        "*string",
	MediumText:  "*string",
	LongText:    "*string",

	// Textual types with no dedicated Go representation. Their values are
	// returned as text/bytes by the driver and can be scanned into a string.
	Json: "*string",
	Enum: "*string",
	Set:  "*string",

	// Date and time.
	// NOTE(review): MySQL TIME is a time-of-day/duration, which drivers
	// generally cannot scan into time.Time. Confirm before relying on it.
	Time:      "*time.Time",
	Timestamp: "*time.Time",
	DateTime:  "*time.Time",
	Date:      "*time.Time",

	// Integers.
	Year:      "*int64",
	TinyInt:   "*int64",
	SmallInt:  "*int64",
	MediumInt: "*int64",
	Int:       "*int64",
	Integer:   "*int64",
	BigInt:    "*int64",
	MiddleInt: "*int64",
	Int1:      "*int64",
	Int2:      "*int64",
	Int3:      "*int64",
	Int4:      "*int64",
	Int8:      "*int64",
	Serial:    "*int64",

	// Fixed- and floating-point numbers. float64 cannot represent every
	// DECIMAL value exactly.
	Decimal: "*float64",
	Dec:     "*float64",
	Fixed:   "*float64",
	Numeric: "*float64",
	Float:   "*float64",
	Float4:  "*float64",
	Float8:  "*float64",
	Double:  "*float64",
	Real:    "*float64",

	// Booleans.
	Bool:    "*bool",
	Boolean: "*bool",
}