| |
| |
| |
|
|
| package db |
|
|
| import ( |
| "database/sql" |
| "fmt" |
| "regexp" |
| "strconv" |
| "strings" |
|
|
| "github.com/GoAdminGroup/go-admin/modules/config" |
| ) |
|
|
| |
| type Mssql struct { |
| Base |
| } |
|
|
| |
| func GetMssqlDB() *Mssql { |
| return &Mssql{ |
| Base: Base{ |
| DbList: make(map[string]*sql.DB), |
| }, |
| } |
| } |
|
|
| |
| func (db *Mssql) GetDelimiter() string { |
| return "[" |
| } |
|
|
| |
| func (db *Mssql) GetDelimiter2() string { |
| return "]" |
| } |
|
|
| |
| func (db *Mssql) GetDelimiters() []string { |
| return []string{"[", "]"} |
| } |
|
|
| |
| func (db *Mssql) Name() string { |
| return "mssql" |
| } |
|
|
| |
|
|
| func replaceStringFunc(pattern, src string, rpl func(s string) string) (string, error) { |
|
|
| r, err := regexp.Compile(pattern) |
| if err != nil { |
| return "", err |
| } |
|
|
| bytes := r.ReplaceAllFunc([]byte(src), func(bytes []byte) []byte { |
| return []byte(rpl(string(bytes))) |
| }) |
|
|
| return string(bytes), nil |
| } |
|
|
| func replace(pattern string, replace, src []byte) ([]byte, error) { |
|
|
| r, err := regexp.Compile(pattern) |
| if err != nil { |
| return nil, err |
| } |
|
|
| return r.ReplaceAll(src, replace), nil |
| } |
|
|
| func replaceString(pattern, rep, src string) (string, error) { |
| r, e := replace(pattern, []byte(rep), []byte(src)) |
| return string(r), e |
| } |
|
|
| func matchAllString(pattern string, src string) ([][]string, error) { |
| r, err := regexp.Compile(pattern) |
| if err != nil { |
| return nil, err |
| } |
| return r.FindAllStringSubmatch(src, -1), nil |
| } |
|
|
| func isMatch(pattern string, src []byte) bool { |
| r, err := regexp.Compile(pattern) |
| if err != nil { |
| return false |
| } |
| return r.Match(src) |
| } |
|
|
| func isMatchString(pattern string, src string) bool { |
| return isMatch(pattern, []byte(src)) |
| } |
|
|
| func matchString(pattern string, src string) ([]string, error) { |
| r, err := regexp.Compile(pattern) |
| if err != nil { |
| return nil, err |
| } |
| return r.FindStringSubmatch(src), nil |
| } |
|
|
| |
| |
| func (db *Mssql) handleSqlBeforeExec(query string) string { |
| index := 0 |
| str, _ := replaceStringFunc("\\?", query, func(s string) string { |
| index++ |
| return fmt.Sprintf("@p%d", index) |
| }) |
|
|
| str, _ = replaceString("\"", "", str) |
|
|
| return db.parseSql(str) |
| } |
|
|
| |
| |
| func (db *Mssql) parseSql(sql string) string { |
| |
| patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))` |
| if !isMatchString(patten, sql) { |
| |
| return sql |
| } |
|
|
| res, err := matchAllString(patten, sql) |
| if err != nil { |
| |
| return "" |
| } |
|
|
| index := 0 |
| keyword := strings.TrimSpace(res[index][0]) |
| keyword = strings.ToUpper(keyword) |
|
|
| index++ |
| switch keyword { |
| case "SELECT": |
| |
| if len(res) < 2 || (!strings.HasPrefix(res[index][0], "LIMIT") && !strings.HasPrefix(res[index][0], "limit")) { |
| break |
| } |
|
|
| |
| if !isMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) { |
| break |
| } |
|
|
| |
| selectStr := "" |
| orderbyStr := "" |
| haveOrderby := isMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql) |
| if haveOrderby { |
| |
| queryExpr, _ := matchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql) |
|
|
| if len(queryExpr) != 4 || !strings.EqualFold(queryExpr[1], "SELECT") || !strings.EqualFold(queryExpr[3], "ORDER BY") { |
| break |
| } |
| selectStr = queryExpr[2] |
|
|
| |
| orderbyExpr, _ := matchString("((?i)ORDER BY)(.+)((?i)LIMIT)", sql) |
| if len(orderbyExpr) != 4 || !strings.EqualFold(orderbyExpr[1], "ORDER BY") || !strings.EqualFold(orderbyExpr[3], "LIMIT") { |
| break |
| } |
| orderbyStr = orderbyExpr[2] |
| } else { |
| queryExpr, _ := matchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) |
| if len(queryExpr) != 4 || !strings.EqualFold(queryExpr[1], "SELECT") || !strings.EqualFold(queryExpr[3], "LIMIT") { |
| break |
| } |
| selectStr = queryExpr[2] |
| } |
|
|
| |
| first, limit := 0, 0 |
| for i := 1; i < len(res[index]); i++ { |
| if strings.TrimSpace(res[index][i]) == "" { |
| continue |
| } |
|
|
| if strings.HasPrefix(res[index][i], "LIMIT") || strings.HasPrefix(res[index][i], "limit") { |
| first, _ = strconv.Atoi(res[index][i+1]) |
| limit, _ = strconv.Atoi(res[index][i+2]) |
| break |
| } |
| } |
|
|
| if haveOrderby { |
| sql = fmt.Sprintf("SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d", orderbyStr, selectStr, first, limit) |
| } else { |
| if first == 0 { |
| first = limit |
| } else { |
| first = limit - first |
| } |
| sql = fmt.Sprintf("SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ", first, limit, selectStr) |
| } |
| default: |
| } |
| return sql |
| } |
|
|
| |
| func (db *Mssql) QueryWithConnection(con string, query string, args ...interface{}) ([]map[string]interface{}, error) { |
| query = db.handleSqlBeforeExec(query) |
| return CommonQuery(db.DbList[con], query, args...) |
| } |
|
|
| |
| func (db *Mssql) ExecWithConnection(con string, query string, args ...interface{}) (sql.Result, error) { |
| query = db.handleSqlBeforeExec(query) |
| return CommonExec(db.DbList[con], query, args...) |
| } |
|
|
| |
| func (db *Mssql) Query(query string, args ...interface{}) ([]map[string]interface{}, error) { |
| query = db.handleSqlBeforeExec(query) |
| return CommonQuery(db.DbList["default"], query, args...) |
| } |
|
|
| |
| func (db *Mssql) Exec(query string, args ...interface{}) (sql.Result, error) { |
| query = db.handleSqlBeforeExec(query) |
| return CommonExec(db.DbList["default"], query, args...) |
| } |
|
|
| func (db *Mssql) QueryWith(tx *sql.Tx, conn, query string, args ...interface{}) ([]map[string]interface{}, error) { |
| if tx != nil { |
| return db.QueryWithTx(tx, query, args...) |
| } |
| return db.QueryWithConnection(conn, query, args...) |
| } |
|
|
| func (db *Mssql) ExecWith(tx *sql.Tx, conn, query string, args ...interface{}) (sql.Result, error) { |
| if tx != nil { |
| return db.ExecWithTx(tx, query, args...) |
| } |
| return db.ExecWithConnection(conn, query, args...) |
| } |
|
|
| |
| func (db *Mssql) InitDB(cfgs map[string]config.Database) Connection { |
| db.Configs = cfgs |
| db.Once.Do(func() { |
| for conn, cfg := range cfgs { |
|
|
| sqlDB, err := sql.Open("sqlserver", cfg.GetDSN()) |
|
|
| if sqlDB == nil { |
| panic("invalid connection") |
| } |
|
|
| if err != nil { |
| _ = sqlDB.Close() |
| panic(err.Error()) |
| } |
|
|
| sqlDB.SetMaxIdleConns(cfg.MaxIdleConns) |
| sqlDB.SetMaxOpenConns(cfg.MaxOpenConns) |
| sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime) |
| sqlDB.SetConnMaxIdleTime(cfg.ConnMaxIdleTime) |
|
|
| db.DbList[conn] = sqlDB |
|
|
| if err := sqlDB.Ping(); err != nil { |
| panic(err) |
| } |
| } |
| }) |
| return db |
| } |
|
|
| |
| func (db *Mssql) BeginTxWithReadUncommitted() *sql.Tx { |
| return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelReadUncommitted) |
| } |
|
|
| |
| func (db *Mssql) BeginTxWithReadCommitted() *sql.Tx { |
| return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelReadCommitted) |
| } |
|
|
| |
| func (db *Mssql) BeginTxWithRepeatableRead() *sql.Tx { |
| return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelRepeatableRead) |
| } |
|
|
| |
| func (db *Mssql) BeginTx() *sql.Tx { |
| return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelDefault) |
| } |
|
|
| |
| func (db *Mssql) BeginTxWithLevel(level sql.IsolationLevel) *sql.Tx { |
| return CommonBeginTxWithLevel(db.DbList["default"], level) |
| } |
|
|
| |
| func (db *Mssql) BeginTxWithReadUncommittedAndConnection(conn string) *sql.Tx { |
| return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelReadUncommitted) |
| } |
|
|
| |
| func (db *Mssql) BeginTxWithReadCommittedAndConnection(conn string) *sql.Tx { |
| return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelReadCommitted) |
| } |
|
|
| |
| func (db *Mssql) BeginTxWithRepeatableReadAndConnection(conn string) *sql.Tx { |
| return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelRepeatableRead) |
| } |
|
|
| |
| func (db *Mssql) BeginTxAndConnection(conn string) *sql.Tx { |
| return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelDefault) |
| } |
|
|
| |
| func (db *Mssql) BeginTxWithLevelAndConnection(conn string, level sql.IsolationLevel) *sql.Tx { |
| return CommonBeginTxWithLevel(db.DbList[conn], level) |
| } |
|
|
| |
| func (db *Mssql) QueryWithTx(tx *sql.Tx, query string, args ...interface{}) ([]map[string]interface{}, error) { |
| query = db.handleSqlBeforeExec(query) |
| return CommonQueryWithTx(tx, query, args...) |
| } |
|
|
| |
| func (db *Mssql) ExecWithTx(tx *sql.Tx, query string, args ...interface{}) (sql.Result, error) { |
| query = db.handleSqlBeforeExec(query) |
| return CommonExecWithTx(tx, query, args...) |
| } |
|
|