package main

import (
	"database/sql"
	"flag"
	"fmt"
	"log"
	"os"
	"os/exec"
	"regexp"
	"strconv"
	"strings"

	"golang.org/x/text/cases"
	"golang.org/x/text/language"

	_ "github.com/go-sql-driver/mysql"
)

// Package-level defaults; main overrides them from the -name and -mdir flags.
var (
	// testName is the table that is generated when -name is left empty.
	testName = "fs_auth_item"
	// testGenDir is the directory the generated files are written to when
	// -mdir is left empty.
	testGenDir = "model/gmodel"
)

// toPascalCase converts a snake_case identifier into PascalCase: the input is
// split on "_", every piece is lower-cased and then title-cased with English
// casing rules, and the pieces are joined without a separator.
func toPascalCase(s string) string {
	// One caser serves every piece; String processes each piece independently.
	title := cases.Title(language.English)

	var b strings.Builder
	for _, part := range strings.Split(s, "_") {
		b.WriteString(title.String(strings.ToLower(part)))
	}
	return b.String()
}

// GetAllTableNames opens its own connection to the MySQL database identified
// by uri, runs SHOW TABLES and returns the table names in the order the server
// reports them. The connection is closed before returning.
//
// Any failure (opening the handle, running the query, scanning a row, or an
// error that interrupts the row iteration) panics.
func GetAllTableNames(uri string) []string {
	db, err := sql.Open("mysql", uri)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	rows, err := db.Query("SHOW TABLES")
	if err != nil {
		panic(err)
	}
	// Release the result set even when a Scan below panics.
	defer rows.Close()

	var tableNames []string
	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			panic(err)
		}
		tableNames = append(tableNames, tableName)
	}
	// Next returns false both at the end of the result set and on failure;
	// without this check a broken connection would yield a truncated list.
	if err := rows.Err(); err != nil {
		panic(err)
	}

	return tableNames
}

// GetColsFromTable fetches the DDL of table tname with SHOW CREATE TABLE on db
// and returns what ParserDDL extracts from it: the columns, the table name and
// the table comment.
//
// db is an open handle to the target database (DSN shape for the mysql driver:
// "user:password@tcp(host:3306)/dbname").
//
// A tname without backticks is treated as one identifier and is backtick-quoted,
// so reserved words and unusual names returned by SHOW TABLES work. A tname
// that already contains backticks is assumed to be quoted by the caller and is
// passed through unchanged.
//
// Panics when the statement fails or its result cannot be scanned.
func GetColsFromTable(tname string, db *sql.DB) (result []Column, tableName, tableComment string) {
	// Identifiers cannot be sent as query placeholders, so quote in place.
	ident := tname
	if !strings.Contains(ident, "`") {
		ident = "`" + ident + "`"
	}

	// SHOW CREATE TABLE yields two columns: the table name and its DDL.
	var name, ddl string
	if err := db.QueryRow("SHOW CREATE TABLE "+ident).Scan(&name, &ddl); err != nil {
		panic(err)
	}

	return ParserDDL(ddl)
}

// main parses the command-line flags, opens the MySQL connection and
// generates model files.
//
// Flags:
//
//	-uri   MySQL DSN to connect to.
//	-name  table to generate; "-" generates every table reported by
//	       GetAllTableNames; empty keeps the package default testName.
//	-mdir  output directory; empty keeps the package default testGenDir.
//
// Every failure panics.
func main() {
	var (
		mysqluri string // MySQL DSN
		name     string // single table to generate, or "-" for all tables
		mdir     string // output directory for the generated model files
	)

	// NOTE(review): the default DSN below embeds a user name, password and
	// host in source. It is kept so existing invocations without -uri keep
	// working, but it should be moved to configuration and the credential
	// rotated.
	flag.StringVar(&mysqluri, "uri", "fusentest:XErSYmLELKMnf3Dh@tcp(110.41.19.98:3306)/fusentest", "输入MySQL连接URI, 格式: user:password@tcp(host:port)/dbname")
	flag.StringVar(&name, "name", "", "输入需要序列化的ddl文件名, 不需要后缀.ddl")
	flag.StringVar(&mdir, "mdir", "", "输入需要生成model的Go文件所在目录")

	flag.Parse()

	if mdir != "" {
		testGenDir = mdir
	}

	db, err := sql.Open("mysql", mysqluri)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	if name == "-" {
		// Generate every table of the database.
		for _, table := range GetAllTableNames(mysqluri) {
			cols, tname, tcomment := GetColsFromTable(table, db)
			GenFromPath(testGenDir, cols, tname, tcomment)
		}
		return
	}

	if name != "" {
		testName = name
	}
	cols, tname, tcomment := GetColsFromTable(testName, db)
	GenFromPath(testGenDir, cols, tname, tcomment)
}

// GenFromPath renders the GORM model source for one table and writes it to mdir.
//
// Two files are involved:
//   - <mdir>/<tableName>_gen.go holds the struct, the <Table>Model wrapper and
//     its New<Table>Model constructor. It is overwritten on every call and then
//     formatted in place by running "gofmt -w".
//   - <mdir>/<tableName>_logic.go is a stub for hand-written code. It is created
//     only when it does not exist yet; an existing file is left untouched.
//
// Parameters:
//   - mdir: directory the files are written to (it is not created here).
//   - cols: the table's columns; fields are emitted in this order.
//   - tableName: raw table name, used for the file names and, converted with
//     toPascalCase, for the generated type names.
//   - tableComment: text placed in the comment above the struct.
//
// A column whose type has no entry in typeForMysqlToGo is written as a
// commented-out line rather than as a field. Any file-system or gofmt failure
// panics.
func GenFromPath(mdir string, cols []Column, tableName string, tableComment string) {
	pTableName := toPascalCase(tableName)

	var (
		fields   strings.Builder
		needTime bool // at least one field is a *time.Time
	)
	for _, col := range cols {
		fieldName := toPascalCase(col.Name)
		typeName, known := typeForMysqlToGo[col.GetType()]
		if !known {
			// No Go mapping for this MySQL type: turn the whole generated
			// line into a comment so the output still compiles.
			fieldName = "// " + fieldName + " " + col.Type
		}

		var defaultString string
		if col.DefaultValue != nil {
			switch typeName {
			case "*int64", "*uint64", "*float64", "*bool":
				// Parsed defaults keep their quotes; numeric tags must not.
				defaultString = "default:" + strings.Trim(*col.DefaultValue, "'") + ";"
			default:
				defaultString = "default:" + *col.DefaultValue + ";"
			}
		} else {
			// No default was parsed: fall back to a per-type zero value.
			switch typeName {
			case "*string", "*[]byte":
				defaultString = "default:'';"
			case "*time.Time":
				defaultString = "default:'0000-00-00 00:00:00';"
			case "*int64", "*uint64", "*bool":
				defaultString = "default:0;"
			case "*float64":
				defaultString = "default: 0.0;"
			}
		}

		if typeName == "*time.Time" {
			needTime = true
		}

		if col.IndexType == "primary_key" {
			// Primary keys are emitted as plain values, not pointers.
			// TrimPrefix is safe for an unmapped (empty) type name.
			typeName = strings.TrimPrefix(typeName, "*")
		}

		gormTag := defaultString
		if col.IndexType != "" {
			gormTag = col.IndexType + ";" + defaultString
		}

		fmt.Fprintf(&fields, "\n%s %s `gorm:\"%s\" json:\"%s\"`// %s",
			fieldName, typeName, gormTag, col.Name, col.Comment)
	}

	var src strings.Builder
	src.WriteString("package gmodel\n")
	src.WriteString("import (\"gorm.io/gorm\"\n")
	if needTime {
		// Import "time" once, no matter how many time columns exist.
		src.WriteString("\"time\"\n")
	}
	src.WriteString(")\n")
	fmt.Fprintf(&src, "// %s %s\ntype %s struct {%s\n}\n", tableName, tableComment, pTableName, fields.String())
	fmt.Fprintf(&src, "type %sModel struct {db *gorm.DB}\n", pTableName)
	fmt.Fprintf(&src, "func New%[1]sModel(db *gorm.DB) *%[1]sModel {return &%[1]sModel{db}}\n", pTableName)

	genGoFileName := fmt.Sprintf("%s/%s_gen.go", mdir, tableName)
	if err := os.WriteFile(genGoFileName, []byte(src.String()), 0644); err != nil {
		panic(err)
	}
	if out, err := exec.Command("gofmt", "-w", genGoFileName).CombinedOutput(); err != nil {
		// Keep gofmt's diagnostics; a bare exit status is not actionable.
		panic(fmt.Errorf("gofmt -w %s: %w: %s", genGoFileName, err, out))
	}

	genGoLogicFileName := fmt.Sprintf("%s/%s_logic.go", mdir, tableName)
	_, err := os.Stat(genGoLogicFileName)
	switch {
	case err == nil:
		fmt.Println(genGoLogicFileName, "exists")
	case os.IsNotExist(err):
		stub := "package gmodel\n// TODO: 使用model的属性做你想做的"
		if err := os.WriteFile(genGoLogicFileName, []byte(stub), 0644); err != nil {
			panic(err)
		}
		fmt.Println(genGoLogicFileName, "create!")
	default:
		// Neither "exists" nor "missing" (e.g. permission denied): reporting
		// "exists" here would hide the real problem.
		panic(err)
	}
}

// Column describes one column definition extracted from a CREATE TABLE
// statement by ParserDDL.
type Column struct {
	// Name is the column name; Type is the type keyword (ParserDDL stores it
	// lower-cased, without the parenthesised part or the unsigned modifier).
	Name, Type string
	// DefaultValue is nil when ParserDDL captured no default; otherwise it
	// points at the default text, surrounding single quotes included.
	DefaultValue *string
	// Length and Decimal are the first and second number of a "(n,m)" suffix.
	Length, Decimal int
	// Modifier flags found on the column definition.
	Unsigned, NotNull, AutoIncrement bool
	// Comment is the text of the column's COMMENT clause.
	Comment string

	// IndexType is "primary_key", "unique_key" or "index" when ParserDDL found
	// a key definition for this column, and empty otherwise.
	IndexType string
}

// GetType returns the key under which the column's type is looked up in
// typeForMysqlToGo: Type itself, followed by " unsigned" for unsigned columns.
func (c *Column) GetType() string {
	if !c.Unsigned {
		return c.Type
	}
	return c.Type + " unsigned"
}

// typeForMysqlToGo maps the key produced by Column.GetType (the lower-case
// MySQL type keyword, plus " unsigned" for unsigned columns) to the Go type
// written into the generated struct. All Go types are pointers; GenFromPath
// removes the "*" for primary-key columns. A type missing from this map is
// treated as unsupported by GenFromPath.
var typeForMysqlToGo = map[string]string{
	// Integers
	"int":       "*int64",
	"integer":   "*int64",
	"tinyint":   "*int64",
	"smallint":  "*int64",
	"mediumint": "*int64",
	"bigint":    "*int64",
	"year":      "*int64",

	// Unsigned integers.
	// NOTE(review): these share *int64 with the signed types, so a
	// "bigint unsigned" value above math.MaxInt64 does not fit. GenFromPath
	// has "*uint64" cases that no entry here produces; confirm the intent
	// before changing the mapping, since it alters generated code.
	"int unsigned":       "*int64",
	"integer unsigned":   "*int64",
	"tinyint unsigned":   "*int64",
	"smallint unsigned":  "*int64",
	"mediumint unsigned": "*int64",
	"bigint unsigned":    "*int64",
	"bit":                "*int64",

	// Booleans
	"bool": "*bool",

	// Strings
	"enum":       "*string",
	"set":        "*string",
	"varchar":    "*string",
	"char":       "*string",
	"tinytext":   "*string",
	"mediumtext": "*string",
	"text":       "*string",
	"longtext":   "*string",

	// Binary data
	"binary":     "*[]byte",
	"varbinary":  "*[]byte",
	"blob":       "*[]byte",
	"tinyblob":   "*[]byte",
	"mediumblob": "*[]byte",
	"longblob":   "*[]byte",

	// Dates and times
	"date":      "*time.Time",
	"datetime":  "*time.Time",
	"timestamp": "*time.Time",
	"time":      "*time.Time",

	// Floating point and fixed point
	"float":   "*float64",
	"double":  "*float64",
	"decimal": "*float64",

	// Unsigned variants: Column.GetType appends " unsigned", so without these
	// keys such columns would be reported as unsupported.
	"float unsigned":   "*float64",
	"double unsigned":  "*float64",
	"decimal unsigned": "*float64",
}

// Regular expressions used by ParserDDL, compiled once instead of per call.
var (
	// ddlTableRe captures the (possibly backtick-quoted) table name.
	ddlTableRe = regexp.MustCompile(`CREATE TABLE +([^ ]+) +\(`)
	// ddlTableCommentRe captures a COMMENT='...' that ends the DDL text.
	ddlTableCommentRe = regexp.MustCompile(`.+COMMENT='(.+)'$`)
	// ddlFieldRe captures one column definition. Groups: 1 name, 2 type,
	// 3 parenthesised part, 4 unsigned, 5 not null, 6 default, 7 auto
	// increment, 8 comment, 9 trailing comma.
	ddlFieldRe = regexp.MustCompile("`([^`]+)` +([^ \n\\(\\,]+)(?:\\(([^)]+)\\))?( +unsigned| +UNSIGNED)?( +not +null| +NOT +NULL)?( +default +\\'[^\\']*'| +DEFAULT +\\'[^\\']*')?( +auto_increment| +AUTO_INCREMENT)?( comment '[^']*'| COMMENT '[^']*')?(,)?")
	// ddlIndexRe captures key definitions. Group 1 is PRIMARY/UNIQUE (or
	// empty), group 5 the parenthesised column list.
	ddlIndexRe = regexp.MustCompile(`(?i)(PRIMARY|UNIQUE)?\s*(INDEX|KEY)\s*(` + "`([^`]*)`" + `)?\s*\(([^)]+)\)`)
	// ddlQuotedTailRe extracts the text between the quotes of a trailing '...'.
	ddlQuotedTailRe = regexp.MustCompile(` '(.+)'$`)
	// ddlDefaultRe extracts a trailing '...' including its quotes.
	ddlDefaultRe = regexp.MustCompile(` ('.+')$`)
)

// ParserDDL extracts the column definitions, the table name and the table
// comment from a MySQL CREATE TABLE statement.
//
// Returns:
//   - result: one Column per matched column definition, in DDL order.
//   - tableName: the table name with surrounding backticks removed.
//   - tableComment: the text of a COMMENT='...' ending the DDL, or "".
//
// Limitations that follow from the patterns above:
//   - only single-quoted DEFAULT values are recognised, and the optional
//     clauses must appear in the order unsigned, not null, default,
//     auto_increment, comment;
//   - a key over several columns is recorded under its whole column list and
//     therefore does not set IndexType on any individual column.
//
// Panics when the DDL has no CREATE TABLE header or when a length/precision
// of a non-enum, non-set column is not an integer.
func ParserDDL(ddl string) (result []Column, tableName, tableComment string) {
	// Index kind per column name, later copied into Column.IndexType.
	indexByColumn := make(map[string]string)
	for _, m := range ddlIndexRe.FindAllStringSubmatch(ddl, -1) {
		idxAttr := strings.Trim(m[5], "`")
		switch prefix := strings.ToUpper(m[1]); prefix {
		case "PRIMARY":
			indexByColumn[idxAttr] = "primary_key"
		case "UNIQUE":
			indexByColumn[idxAttr] = "unique_key"
		case "":
			indexByColumn[idxAttr] = "index"
		default:
			// The pattern only admits the three cases above; kept as a guard.
			log.Fatal(prefix)
		}
	}

	tableMatches := ddlTableRe.FindStringSubmatch(ddl)
	if tableMatches == nil {
		// Fail with a clear message instead of an index-out-of-range panic.
		panic("ParserDDL: no CREATE TABLE header found in DDL")
	}
	tableName = strings.Trim(tableMatches[1], "`")

	if m := ddlTableCommentRe.FindStringSubmatch(ddl); len(m) > 0 {
		tableComment = strings.Trim(m[1], "`")
	}

	for _, m := range ddlFieldRe.FindAllStringSubmatch(ddl, -1) {
		if m[0] == "" {
			continue
		}

		col := Column{
			Name:      m[1],
			Type:      strings.ToLower(m[2]),
			IndexType: indexByColumn[m[1]],
		}

		// For enum/set the parenthesised part is a list of quoted values,
		// not a length, so it must not be parsed as numbers.
		if m[3] != "" && col.Type != "enum" && col.Type != "set" {
			parts := strings.Split(m[3], ",")
			n, err := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 64)
			if err != nil {
				panic(err)
			}
			col.Length = int(n)
			if len(parts) >= 2 {
				d, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64)
				if err != nil {
					panic(err)
				}
				col.Decimal = int(d)
			}
		}

		col.Unsigned = len(m[4]) > 0
		col.NotNull = len(m[5]) > 0
		col.AutoIncrement = len(m[7]) > 0

		if len(m[6]) > 0 {
			if v := ddlDefaultRe.FindStringSubmatch(m[6]); len(v) > 0 {
				dv := v[1]
				col.DefaultValue = &dv
			}
		}

		if len(m[8]) > 0 {
			if v := ddlQuotedTailRe.FindStringSubmatch(m[8]); len(v) > 0 {
				col.Comment = v[1]
			}
		}

		result = append(result, col)
	}
	return result, tableName, tableComment
}