package config

import (
	"database/sql"
	"fmt"
	"net/url"
	"time"

	"ariga.io/entcache"
	entsql "entgo.io/ent/dialect/sql"
	"git.noahlan.cn/noahlan/ntool-biz/core/config"
	"git.noahlan.cn/noahlan/ntool/nlog"
	"github.com/go-redis/redis/v8"
	// Blank imports make the "mysql" and "postgres" database/sql drivers available to sql.Open.
	_ "github.com/go-sql-driver/mysql"
	_ "github.com/lib/pq"
)

// Database describes how to reach a SQL database and how the ent drivers
// built from it should be tuned.
type Database struct {
	// Host is the database server host (env: DB_HOST).
	Host string `json:",env=DB_HOST"`

	// Port is the database server port (env: DB_PORT).
	Port int `json:",env=DB_PORT"`

	// Username is the login user (env: DB_USERNAME).
	Username string `json:",optional,env=DB_USERNAME"`

	// Password is the login password (env: DB_PASSWORD).
	Password string `json:",optional,env=DB_PASSWORD"`

	// DBName is the name of the database to connect to (env: DB_DBNAME).
	DBName string `json:",optional,env=DB_DBNAME"`

	// Timezone is the time zone used for the MySQL "loc" DSN parameter,
	// e.g. "Local" or "Asia/Shanghai".
	Timezone string `json:",optional,default=Local"`

	// SSLMode is the PostgreSQL "sslmode" connection parameter.
	SSLMode string `json:",optional"`

	// Type selects the database kind: "mysql" or "postgres".
	Type string `json:",default=mysql,options=[mysql,postgres]"`

	// MaxOpenConns is the maximum number of open connections. It is a
	// pointer and may be nil when the struct is not populated with the
	// tag defaults.
	MaxOpenConns *int `json:",optional,default=200"`

	// Debug toggles debug mode. It is not read by the code in this file.
	Debug bool `json:",optional,default=false"`

	// CacheTime is the cache TTL in seconds.
	CacheTime int `json:",optional,default=10"`
}

// openDB opens a database/sql handle for c.Type using c.GetDSN and applies
// the connection limit. The error from sql.Open is handed to nlog.Must.
func (c Database) openDB() *sql.DB {
	// Matches the default declared in the MaxOpenConns struct tag.
	const fallbackMaxOpenConns = 200

	db, err := sql.Open(c.Type, c.GetDSN())
	nlog.Must(err)

	// Guard the pointer: dereferencing a nil MaxOpenConns would panic.
	maxOpen := fallbackMaxOpenConns
	if c.MaxOpenConns != nil {
		maxOpen = *c.MaxOpenConns
	}
	db.SetMaxOpenConns(maxOpen)
	return db
}

// NewCacheDriver returns an ent driver with cache.
//
// Queries are cached for CacheTime seconds in two levels: an in-process LRU
// of 256 entries, then a Redis client built from redisConf.
func (c Database) NewCacheDriver(redisConf config.RedisConf) *entcache.Driver {
	driver := entsql.OpenDB(c.Type, c.openDB())

	rdb := redis.NewClient(&redis.Options{
		Network:  redisConf.Network,
		Addr:     redisConf.Addr,
		Username: redisConf.Username,
		Password: redisConf.Password,
		DB:       redisConf.DB,
	})

	return entcache.NewDriver(
		driver,
		entcache.TTL(time.Duration(c.CacheTime)*time.Second),
		entcache.Levels(
			entcache.NewLRU(256),
			entcache.NewRedis(rdb),
		),
	)
}

// NewNoCacheDriver returns an ent driver without cache.
func (c Database) NewNoCacheDriver() *entsql.Driver {
	return entsql.OpenDB(c.Type, c.openDB())
}

// MysqlDSN returns the MySQL DSN.
//
// The time zone is query-escaped because the DSN is split on its last '/',
// so a raw value such as "Asia/Shanghai" would corrupt the database name.
func (c Database) MysqlDSN() string {
	return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True&loc=%s",
		c.Username, c.Password, c.Host, c.Port, c.DBName, url.QueryEscape(c.Timezone))
}

// PostgresDSN returns the PostgreSQL connection URL.
//
// The URL is assembled with net/url so that reserved characters in the
// credentials, database name or sslmode value are percent-encoded instead of
// breaking the URL structure.
func (c Database) PostgresDSN() string {
	dsn := url.URL{
		Scheme:   "postgresql",
		User:     url.UserPassword(c.Username, c.Password),
		Host:     fmt.Sprintf("%s:%d", c.Host, c.Port),
		Path:     "/" + c.DBName,
		RawQuery: url.Values{"sslmode": {c.SSLMode}}.Encode(),
	}
	return dsn.String()
}

// GetDSN returns the DSN according to the database type.
func (c Database) GetDSN() string {
	switch c.Type {
	case "mysql":
		return c.MysqlDSN()
	case "postgres":
		return c.PostgresDSN()
	default:
		// NOTE(review): for any other Type this returns the literal string
		// "mysql", which is not a usable DSN. Kept as-is to preserve existing
		// behaviour — confirm whether c.MysqlDSN() was intended.
		return "mysql"
	}
}