package main

import (
	"database/sql"
	"flag"
	"fmt"
	"log"
	"os"
	"os/exec"
	"regexp"
	"sort"
	"strconv"
	"strings"

	"golang.org/x/text/cases"
	"golang.org/x/text/language"

	_ "github.com/go-sql-driver/mysql"
)

// Fallback settings used by main: testName is the table generated when -name is
// empty, and testGenDir is the output directory when -mdir is empty.
var (
	testName   = "fs_auth_item"
	testGenDir = "model/gmodel"
)

// toPascalCase converts a snake_case identifier such as "fs_auth_item" into
// PascalCase ("FsAuthItem"). The input is split on "_", and each segment is
// lower-cased and then title-cased with English rules. Initialisms are not
// special-cased ("user_id" becomes "UserId").
func toPascalCase(s string) string {
	// Caser.String resets the caser before transforming, so one instance can be
	// reused for every segment instead of being rebuilt per word.
	caser := cases.Title(language.English)
	words := strings.Split(s, "_")
	for i, word := range words {
		words[i] = caser.String(strings.ToLower(word))
	}
	return strings.Join(words, "")
}

// GetAllTableNames connects to the MySQL DSN uri and returns the names of the base
// tables in its current database, in the order MySQL reports them.
//
// Views are skipped. SHOW CREATE TABLE on a view returns the view definition
// (four columns) rather than a CREATE TABLE statement, so GetColsFromTable
// cannot scan it. Any database error panics.
func GetAllTableNames(uri string) []string {
	db, err := sql.Open("mysql", uri)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	// SHOW FULL TABLES adds a Table_type column ("BASE TABLE" or "VIEW").
	rows, err := db.Query("SHOW FULL TABLES WHERE Table_type = 'BASE TABLE'")
	if err != nil {
		panic(err)
	}
	defer rows.Close()

	var tableNames []string
	for rows.Next() {
		var tableName, tableType string
		if err := rows.Scan(&tableName, &tableType); err != nil {
			panic(err)
		}
		tableNames = append(tableNames, tableName)
	}
	// Next returns false on iteration errors as well as at the end; surface them.
	if err := rows.Err(); err != nil {
		panic(err)
	}
	return tableNames
}

// GetColsFromTable fetches the DDL of table tname with SHOW CREATE TABLE on db and
// parses it with ParserDDL. It returns the columns, the table name and the table
// comment. tname is interpolated into the statement unquoted, so it may also be
// schema-qualified (db.table). It panics if the query fails.
func GetColsFromTable(tname string, db *sql.DB) (result []Column, tableName, tableComment string) {
	// SHOW CREATE TABLE yields two columns: the table name and its CREATE TABLE statement.
	var shownName, ddl string
	if err := db.QueryRow("SHOW CREATE TABLE "+tname).Scan(&shownName, &ddl); err != nil {
		panic(fmt.Errorf("show create table %s: %w", tname, err))
	}
	return ParserDDL(ddl)
}

// gmodelVarStr is an empty var_gen.go template: an AllModelsGen struct with no
// fields and a NewAllModels constructor that returns it unpopulated. The text is
// emitted verbatim, so its embedded comment is part of the generated file.
var gmodelVarStr = `
package gmodel

import "gorm.io/gorm"

// AllModelsGen 所有Model集合,修改单行,只要不改字段名,不会根据新的内容修改,需要修改的话手动删除
type AllModelsGen struct {
}

func NewAllModels(gdb *gorm.DB) *AllModelsGen {
	models := &AllModelsGen{
	}
	return models
}
`

// gmodelVarStrFormat is the fmt template for var_gen.go. The first %s receives the
// AllModelsGen field lines ("GoName *GoNameModel // table comment"). The second %s
// receives the matching NewAllModels composite-literal entries
// ("GoName: NewGoNameModel(gdb),"). The output is normalised by gofmt afterwards,
// so the indentation here does not matter.
var gmodelVarStrFormat = `
package gmodel

import "gorm.io/gorm"

// AllModelsGen 所有Model集合,修改单行,只要不改字段名,不会根据新的内容修改,需要修改的话手动删除
type AllModelsGen struct {
	 %s
}

func NewAllModels(gdb *gorm.DB) *AllModelsGen {
	models := &AllModelsGen{
		%s
	}
	return models
}
`

// TableNameComment is one entry of the AllModelsGen registry written to var_gen.go.
type TableNameComment struct {
	Name    string // MySQL table name, e.g. fs_auth_item
	GoName  string // Go identifier used for the model field and its *Model type
	Comment string // table comment shown after the table name in var_gen.go
}

// TMCS is a slice of TableNameComment implementing sort.Interface; it orders
// entries by table Name in ascending byte order.
type TMCS []TableNameComment

// Len reports how many entries the slice holds.
func (t TMCS) Len() int { return len(t) }

// Less reports whether entry a's table name sorts before entry b's.
func (t TMCS) Less(a, b int) bool {
	left, right := t[a].Name, t[b].Name
	return left < right
}

// Swap exchanges the entries at positions a and b.
func (t TMCS) Swap(a, b int) {
	tmp := t[a]
	t[a] = t[b]
	t[b] = tmp
}

func GenAllModels(filedir string, tmcs ...TableNameComment) {
	fileName := filedir + "/var_gen.go"
	var dupMap map[string]TableNameComment = make(map[string]TableNameComment)
	for _, tmc := range tmcs {
		dupMap[tmc.Name] = tmc
	}

	if _, err := os.Stat(fileName); err == nil {
		log.Printf("%s exists!", fileName)
		data, err := os.ReadFile(fileName)
		if err != nil {
			panic(err)
		}
		filestr := string(data)
		filelines := strings.Split(filestr, "\n")
		re := regexp.MustCompile(`([A-Za-z0-9_]+) [^/]+ // ([^ ]+) (.+)$`)
		for _, line := range filelines {
			result := re.FindStringSubmatch(line)
			if len(result) > 0 {
				// key := result[0]
				if len(result) != 4 {
					log.Println(result)
				}
				log.Println(result)

				tmc := TableNameComment{
					Name:    result[2],
					GoName:  result[1],
					Comment: result[3],
				}

				if newTmc, ok := dupMap[tmc.Name]; ok {
					log.Printf("not change: (old)%v -> (new)%v", tmc, newTmc)
				}

				dupMap[tmc.Name] = tmc
			}
		}

		tmcs = nil

		for _, tmc := range dupMap {
			tmcs = append(tmcs, tmc)
		}

		sort.Sort(TMCS(tmcs))

		structStr := ""
		newModelsStr := ""
		for _, tmc := range tmcs {
			fsline := fmt.Sprintf("%s *%sModel // %s %s\n", tmc.GoName, tmc.GoName, tmc.Name, tmc.Comment)
			structStr += fsline
			nmline := fmt.Sprintf("%s: New%sModel(gdb),\n", tmc.GoName, tmc.GoName)
			newModelsStr += nmline
		}

		content := fmt.Sprintf(gmodelVarStrFormat, structStr, newModelsStr)
		f, err := os.OpenFile(fileName, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0644)
		if err != nil {
			panic(err)
		}
		_, err = f.WriteString(content)
		if err != nil {
			panic(err)
		}

	} else if os.IsExist(err) {
		f, err := os.Create(fileName)
		if err != nil {
			panic(err)
		}
		_, err = f.WriteString(gmodelVarStr)
		if err != nil {
			panic(err)
		}
	} else {
		panic(err)
	}

	err := exec.Command("gofmt", "-w", fileName).Run()
	if err != nil {
		panic(err)
	}
}

// main generates gorm model files from a live MySQL schema.
//
// Flags:
//
//	-uri  MySQL DSN to read the schema from.
//	-name table to generate; "-" generates every table returned by
//	      GetAllTableNames; empty falls back to testName.
//	-mdir output directory; empty falls back to testGenDir.
//
// For each selected table it calls GenFromPath, then refreshes the var_gen.go
// registry with GenAllModels.
func main() {
	var mysqluri string
	var name string // table to generate; "-" means all tables, "" means testName
	var mdir string // directory the generated model files are written to

	// NOTE(review): the default DSN embeds credentials for a public host. Consider
	// requiring -uri (or reading it from the environment) and rotating them.
	flag.StringVar(&mysqluri, "uri", "fusentest:XErSYmLELKMnf3Dh@tcp(110.41.19.98:3306)/fusentest", "MySQL 连接串(DSN), 例如 user:password@tcp(host:3306)/dbname")
	flag.StringVar(&name, "name", "", "需要生成model的表名, 传 - 表示生成库中所有表")
	flag.StringVar(&mdir, "mdir", "", "输入需要生成model的Go文件所在目录")
	flag.Parse()

	if mdir != "" {
		testGenDir = mdir
	}

	db, err := sql.Open("mysql", mysqluri)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	var tables []string
	if name == "-" {
		tables = GetAllTableNames(mysqluri)
	} else {
		if name != "" {
			testName = name
		}
		tables = []string{testName}
	}

	tmcs := make([]TableNameComment, 0, len(tables))
	for _, table := range tables {
		tmcs = append(tmcs, genTableModel(db, table))
	}

	GenAllModels(testGenDir, tmcs...)
}

// genTableModel generates the model files for one table into testGenDir and
// returns the entry to register in var_gen.go.
func genTableModel(db *sql.DB, table string) TableNameComment {
	cols, tname, tcomment := GetColsFromTable(table, db)
	GenFromPath(testGenDir, cols, tname, tcomment)
	return TableNameComment{
		Name:    tname,
		GoName:  toPascalCase(tname),
		Comment: tcomment,
	}
}

// GenFromPath writes <mdir>/<tableName>_gen.go and formats it with gofmt. The file
// contains:
//   - a struct named toPascalCase(tableName) with one gorm/json-tagged field per
//     column in cols;
//   - a <Name>Model wrapper holding the *gorm.DB and the table name;
//   - a New<Name>Model constructor.
//
// It also creates a stub <mdir>/<tableName>_logic.go for hand-written code if that
// file does not exist yet; an existing logic file is never overwritten. Any I/O or
// gofmt failure panics.
func GenFromPath(mdir string, cols []Column, tableName string, tableComment string) {
	pTableName := toPascalCase(tableName)

	var fields strings.Builder
	needTime := false
	for _, col := range cols {
		line, usesTime := genModelField(col)
		fields.WriteString(line)
		needTime = needTime || usesTime
	}

	var src strings.Builder
	src.WriteString("package gmodel\n")
	src.WriteString("import (\"gorm.io/gorm\"\n")
	if needTime {
		// Import "time" once per file. One import per time column would produce
		// duplicate imports, which do not compile.
		src.WriteString("\"time\"\n")
	}
	src.WriteString(")\n")
	fmt.Fprintf(&src, "// %s %s\ntype %s struct {%s\n}\n", tableName, tableComment, pTableName, fields.String())
	fmt.Fprintf(&src, "type %sModel struct {db *gorm.DB\n\t\tname string}", pTableName)
	src.WriteString("\n")
	fmt.Fprintf(&src, `func New%sModel(db *gorm.DB) *%sModel {return &%sModel{db:db,name:"%s"}}`, pTableName, pTableName, pTableName, tableName)
	src.WriteString("\n")

	genGoFileName := fmt.Sprintf("%s/%s_gen.go", mdir, tableName)
	if err := os.WriteFile(genGoFileName, []byte(src.String()), 0644); err != nil {
		panic(err)
	}
	if err := exec.Command("gofmt", "-w", genGoFileName).Run(); err != nil {
		panic(err)
	}

	genGoLogicFileName := fmt.Sprintf("%s/%s_logic.go", mdir, tableName)
	if _, err := os.Stat(genGoLogicFileName); !os.IsNotExist(err) {
		fmt.Println(genGoLogicFileName, "exists")
		return
	}
	logic := "package gmodel\n// TODO: 使用model的属性做你想做的"
	if err := os.WriteFile(genGoLogicFileName, []byte(logic), 0644); err != nil {
		panic(err)
	}
	fmt.Println(genGoLogicFileName, "create!")
}

// genModelField renders the struct field for col, prefixed with a newline, and
// reports whether its Go type needs the "time" import. Columns whose MySQL type has
// no entry in typeForMysqlToGo are rendered as a comment.
func genModelField(col Column) (line string, usesTime bool) {
	fieldName := toPascalCase(col.Name)
	typeName := typeForMysqlToGo[col.GetType()]

	var defaultString string
	if col.DefaultValue != nil {
		switch typeName {
		case "*int64", "*uint64", "*float64", "*bool":
			// Numeric/bool defaults are emitted without their SQL quotes.
			defaultString = "default:" + strings.Trim(*col.DefaultValue, "'") + ";"
		default:
			defaultString = "default:" + *col.DefaultValue + ";"
		}
	} else {
		// No quoted DEFAULT was parsed: fall back to a zero value for the Go type.
		switch typeName {
		case "*string", "*[]byte":
			defaultString = "default:'';"
		case "*time.Time":
			defaultString = "default:'0000-00-00 00:00:00';"
		case "*int64", "*uint64", "*bool":
			defaultString = "default:0;"
		case "*float64":
			defaultString = "default: 0.0;"
		}
	}

	if typeName == "" {
		// Unmapped MySQL type: emit the field as a comment, since a field without
		// a Go type would not compile.
		fieldName = "// " + fieldName + " " + col.Type
	}
	usesTime = typeName == "*time.Time"

	if col.IndexType == "primary_key" {
		// Primary keys are generated as plain values rather than pointers.
		typeName = strings.TrimPrefix(typeName, "*")
	}

	gormTag := ""
	if col.IndexType != "" {
		gormTag += col.IndexType + ";"
	}
	gormTag += defaultString
	if col.AutoIncrement {
		gormTag += "auto_increment;"
	}
	tag := fmt.Sprintf("`gorm:\"%s\" json:\"%s\"`", gormTag, col.Name)

	return fmt.Sprintf("\n%s %s %s// %s", fieldName, typeName, tag, col.Comment), usesTime
}

// Column is one column definition as extracted by ParserDDL from a CREATE TABLE
// statement.
type Column struct {
	Name          string  // column name without backticks
	Type          string  // lower-cased base type, e.g. "varchar"
	DefaultValue  *string // quoted DEFAULT literal (quotes kept); nil if none was parsed
	Length        int     // first number inside the type's parentheses
	Decimal       int     // second number inside the parentheses, if present
	Unsigned      bool    // UNSIGNED was present
	NotNull       bool    // NOT NULL was present
	AutoIncrement bool    // AUTO_INCREMENT was present
	Comment       string  // column COMMENT text without quotes

	IndexType string // "primary_key", "unique_key", "index" or "" (gorm tag prefix)
}

// GetType returns the key used to look the column up in typeForMysqlToGo: the base
// type, with " unsigned" appended for UNSIGNED columns.
func (col *Column) GetType() string {
	if !col.Unsigned {
		return col.Type
	}
	return col.Type + " unsigned"
}

// typeForMysqlToGo maps a Column.GetType() key (lower-cased MySQL base type, plus
// " unsigned" for UNSIGNED columns) to the Go field type used in generated models.
// Types missing from this map are emitted as commented-out fields.
var typeForMysqlToGo = map[string]string{
	// integers
	"int":       "*int64",
	"integer":   "*int64",
	"tinyint":   "*int64",
	"smallint":  "*int64",
	"mediumint": "*int64",
	"bigint":    "*int64",
	"year":      "*int64",

	"int unsigned":       "*int64",
	"integer unsigned":   "*int64",
	"tinyint unsigned":   "*int64",
	"smallint unsigned":  "*int64",
	"mediumint unsigned": "*int64",
	"bigint unsigned":    "*int64",
	"bit":                "*int64",

	// booleans
	"bool": "*bool",

	// strings
	"enum":       "*string",
	"set":        "*string",
	"varchar":    "*string",
	"char":       "*string",
	"tinytext":   "*string",
	"mediumtext": "*string",
	"text":       "*string",
	"longtext":   "*string",

	// binary
	"binary":     "*[]byte",
	"varbinary":  "*[]byte",
	"blob":       "*[]byte",
	"tinyblob":   "*[]byte",
	"mediumblob": "*[]byte",
	"longblob":   "*[]byte",

	// date / time
	"date":      "*time.Time",
	"datetime":  "*time.Time",
	"timestamp": "*time.Time",
	"time":      "*time.Time",

	// floating point / fixed point
	"float":   "*float64",
	"double":  "*float64",
	"decimal": "*float64",

	// MySQL still accepts UNSIGNED on FLOAT/DOUBLE/DECIMAL (deprecated since 8.0.17),
	// e.g. "decimal(10,2) unsigned". Without these keys such columns had no Go type.
	"float unsigned":   "*float64",
	"double unsigned":  "*float64",
	"decimal unsigned": "*float64",
}

func ParserDDL(ddl string) (result []Column, tableName, tableComment string) {

	reTable := regexp.MustCompile(`CREATE TABLE +([^ ]+) +\(`)
	reTableComment := regexp.MustCompile(`.+COMMENT='(.+)'$`)
	reField := regexp.MustCompile("`([^`]+)` +([^ \n\\(\\,]+)(?:\\(([^)]+)\\))?( +unsigned| +UNSIGNED)?( +not +null| +NOT +NULL)?( +default +\\'[^\\']*'| +DEFAULT +\\'[^\\']*')?( +auto_increment| +AUTO_INCREMENT)?( comment '[^']*'| COMMENT '[^']*')?(,)?")
	reIndex := regexp.MustCompile(`(?i)(PRIMARY|UNIQUE)?\s*(INDEX|KEY)\s*(` + "`([^`]*)`" + `)?\s*\(([^)]+)\)`)
	reValue := regexp.MustCompile(` '(.+)'$`)
	reDefaultValue := regexp.MustCompile(` ('.+')$`)

	var fieldmap map[string]string = make(map[string]string)
	indexMatches := reIndex.FindAllStringSubmatch(ddl, -1)
	for _, m := range indexMatches {
		idxAttr := strings.Trim(m[5], "`")
		PrefixName := strings.ToUpper(m[1])
		if PrefixName == "PRIMARY" {
			fieldmap[idxAttr] = "primary_key"
		} else if PrefixName == "UNIQUE" {
			fieldmap[idxAttr] = "unique_key"
		} else if PrefixName == "" {
			fieldmap[idxAttr] = "index"
		} else {
			log.Fatal(PrefixName)
		}
	}

	tableMatches := reTable.FindStringSubmatch(ddl)
	tableName = strings.Trim(tableMatches[1], "`")

	tableCommentMatches := reTableComment.FindStringSubmatch(ddl)
	if len(tableCommentMatches) > 0 {
		tableComment = strings.Trim(tableCommentMatches[1], "`")
	}

	// log.Println(tableName, tableComment)
	fieldMatches := reField.FindAllStringSubmatch(ddl, -1)
	for _, m := range fieldMatches {
		if m[0] == "" {
			continue
		}

		col := Column{
			Name: m[1],
			Type: strings.ToLower(m[2]),
		}

		col.IndexType = fieldmap[col.Name]

		if m[3] != "" {
			maylen := strings.Split(m[3], ",")
			if len(maylen) >= 1 {
				clen, err := strconv.ParseInt(maylen[0], 10, 64)
				if err != nil {
					panic(err)
				}
				col.Length = int(clen)
			}
			if len(maylen) >= 2 {
				clen, err := strconv.ParseInt(maylen[1], 10, 64)
				if err != nil {
					panic(err)
				}
				col.Decimal = int(clen)
			}
		}

		if len(m[4]) > 0 {
			col.Unsigned = true
		}

		if len(m[5]) > 0 {
			col.NotNull = true
		}

		if len(m[6]) > 0 {
			v := reDefaultValue.FindStringSubmatch(m[6])
			if len(v) > 0 {
				dv := string(v[1])
				col.DefaultValue = &dv
			}
		}

		if len(m[7]) > 0 {
			col.AutoIncrement = true
		}

		if len(m[8]) > 0 {
			v := reValue.FindStringSubmatch(m[8])
			if len(v) > 0 {
				col.Comment = v[1]
			}
		}
		result = append(result, col)
		// fmt.Println(col)
	}
	return result, tableName, tableComment
}