| |
| |
| |
|
|
| package dialect |
|
|
import (
	"sort"
	"strings"

	"github.com/GoAdminGroup/go-admin/modules/config"
)
|
|
| |
// Dialect abstracts the SQL text that differs between database drivers.
// GetDialect and GetDialectByDriver return the implementation for a driver.
//
// NOTE(review): the implementations live outside this file; the per-method
// descriptions below are inferred from the method names — confirm there.
type Dialect interface {
	// GetName returns the dialect's name (presumably the driver name it was
	// selected by — TODO confirm).
	GetName() string

	// ShowColumns presumably returns the SQL used to list the columns of
	// table.
	ShowColumns(table string) string

	// ShowColumnsWithComment presumably returns the SQL used to list the
	// columns of table in schema together with their comments — TODO confirm.
	ShowColumnsWithComment(schema, table string) string

	// ShowTables presumably returns the SQL used to list the tables.
	ShowTables() string

	// Insert presumably returns the INSERT statement described by comp.
	Insert(comp *SQLComponent) string

	// Delete presumably returns the DELETE statement described by comp.
	Delete(comp *SQLComponent) string

	// Update presumably returns the UPDATE statement described by comp.
	Update(comp *SQLComponent) string

	// Select presumably returns the SELECT statement described by comp.
	Select(comp *SQLComponent) string

	// GetDelimiter presumably returns the opening identifier quote for the
	// dialect (GetDialectByDriver configures "`", "[" or `"`) — TODO confirm.
	GetDelimiter() string
}
|
|
| |
| func GetDialect() Dialect { |
| return GetDialectByDriver(config.GetDatabases().GetDefault().Driver) |
| } |
|
|
| |
| func GetDialectByDriver(driver string) Dialect { |
| switch driver { |
| case "mysql": |
| return mysql{ |
| commonDialect: commonDialect{delimiter: "`", delimiter2: "`"}, |
| } |
| case "mssql": |
| return mssql{ |
| commonDialect: commonDialect{delimiter: "[", delimiter2: "]"}, |
| } |
| case "postgresql": |
| return postgresql{ |
| commonDialect: commonDialect{delimiter: `"`, delimiter2: `"`}, |
| } |
| case "sqlite": |
| return sqlite{ |
| commonDialect: commonDialect{delimiter: "`", delimiter2: "`"}, |
| } |
| case "oceanbase": |
| return oceanbase{ |
| commonDialect: commonDialect{delimiter: "`", delimiter2: "`"}, |
| } |
| default: |
| return commonDialect{delimiter: "`", delimiter2: "`"} |
| } |
| } |
|
|
| |
// H is a column-name → value map. SQLComponent.Values uses it to carry the
// data for INSERT and UPDATE statements.
type H map[string]interface{}
|
|
| |
// SQLComponent collects the pieces of a single SQL statement. The get*
// helpers in this file render individual clauses from it, and
// prepareInsert/prepareUpdate write the finished text into Statement.
//
// NOTE(review): Order, Offset, Limit, Group, WhereRaws and the UpdateRaws
// expressions are concatenated into the SQL verbatim, without escaping; they
// must never carry untrusted input.
type SQLComponent struct {
	// Fields lists the selected columns; an empty list renders as "*".
	Fields []string
	// Functions is read by getFields at the same index as Fields; a
	// non-empty entry renders the field as Functions[k](field).
	Functions []string
	// TableName is the unquoted table name; the prepare* methods surround it
	// with the dialect delimiters.
	TableName string
	// Wheres are structured conditions that getWheres joins with "and".
	Wheres []Where
	// Leftjoins are rendered as "left join" clauses by getJoins.
	Leftjoins []Join
	// Args holds the positional placeholder arguments. prepareUpdate puts the
	// SET arguments in front of it; prepareInsert appends the value arguments.
	Args []interface{}
	// Order is raw text placed after "order by"; empty omits the clause.
	Order string
	// Offset is raw text placed after "offset"; empty omits the clause.
	Offset string
	// Limit is raw text placed after "limit"; empty omits the clause.
	Limit string
	// WhereRaws is a raw condition appended to the WHERE clause (after an
	// "and" when Wheres is non-empty).
	WhereRaws string
	// UpdateRaws are raw SET expressions (with their own arguments) used by
	// prepareUpdate.
	UpdateRaws []RawUpdate
	// Group is raw text placed after "group by"; empty omits the clause.
	Group string
	// Statement receives the finished SQL text from prepareInsert and
	// prepareUpdate (presumably also from the dialects' other builders).
	Statement string
	// Values maps column names to the values for INSERT/UPDATE.
	Values H
}
|
|
| |
// Where is one structured condition, rendered by getWheres as
// "<Field> <Operation> <Qmark>".
type Where struct {
	// Operation is the comparison operator, inserted verbatim.
	Operation string
	// Field is the column name. For a "table.column" form only the column
	// part is quoted.
	Field string
	// Qmark is the placeholder text, inserted verbatim (presumably "?").
	Qmark string
}
|
|
| |
// Join describes one LEFT JOIN, rendered by getJoins as
// "left join <Table> on <FieldA> <Operation> <FieldB>".
type Join struct {
	// Table is the joined table, quoted with the dialect delimiters.
	Table string
	// FieldA is the left-hand operand, expected in "table.column" form.
	FieldA string
	// Operation is the comparison operator, inserted verbatim.
	Operation string
	// FieldB is the right-hand operand, expected in "table.column" form.
	FieldB string
}
|
|
| |
// RawUpdate is a raw SET expression for UPDATE statements. prepareUpdate
// inserts Expression verbatim and adds Args to the statement arguments.
type RawUpdate struct {
	// Expression is the SQL fragment, e.g. "count = count + ?".
	Expression string
	// Args are the arguments for the placeholders in Expression.
	Args []interface{}
}
|
|
| |
| |
| |
|
|
| func (sql *SQLComponent) getLimit() string { |
| if sql.Limit == "" { |
| return "" |
| } |
| return " limit " + sql.Limit + " " |
| } |
|
|
| func (sql *SQLComponent) getOffset() string { |
| if sql.Offset == "" { |
| return "" |
| } |
| return " offset " + sql.Offset + " " |
| } |
|
|
| func (sql *SQLComponent) getOrderBy() string { |
| if sql.Order == "" { |
| return "" |
| } |
| return " order by " + sql.Order + " " |
| } |
|
|
| func (sql *SQLComponent) getGroupBy() string { |
| if sql.Group == "" { |
| return "" |
| } |
| return " group by " + sql.Group + " " |
| } |
|
|
| func (sql *SQLComponent) getJoins(delimiter, delimiter2 string) string { |
| if len(sql.Leftjoins) == 0 { |
| return "" |
| } |
| joins := "" |
| for _, join := range sql.Leftjoins { |
| joins += " left join " + wrap(delimiter, delimiter2, join.Table) + " on " + |
| sql.processLeftJoinField(join.FieldA, delimiter, delimiter2) + " " + join.Operation + " " + |
| sql.processLeftJoinField(join.FieldB, delimiter, delimiter2) + " " |
| } |
| return joins |
| } |
|
|
| func (sql *SQLComponent) processLeftJoinField(field, delimiter, delimiter2 string) string { |
| arr := strings.Split(field, ".") |
| if len(arr) > 0 { |
| return delimiter + arr[0] + delimiter2 + "." + delimiter + arr[1] + delimiter2 |
| } |
| return field |
| } |
|
|
| func (sql *SQLComponent) getFields(delimiter, delimiter2 string) string { |
| if len(sql.Fields) == 0 { |
| return "*" |
| } |
| fields := "" |
| if len(sql.Leftjoins) == 0 { |
| for k, field := range sql.Fields { |
| if sql.Functions[k] != "" { |
| fields += sql.Functions[k] + "(" + wrap(delimiter, delimiter2, field) + ")," |
| } else { |
| fields += wrap(delimiter, delimiter2, field) + "," |
| } |
| } |
| } else { |
| for _, field := range sql.Fields { |
| arr := strings.Split(field, ".") |
| if len(arr) > 1 { |
| fields += wrap(delimiter, delimiter2, arr[0]) + "." + wrap(delimiter, delimiter2, arr[1]) + "," |
| } else { |
| fields += wrap(delimiter, delimiter2, field) + "," |
| } |
| } |
| } |
| return fields[:len(fields)-1] |
| } |
|
|
// wrap quotes field as an identifier using the given opening and closing
// delimiters. The wildcard "*" is returned bare.
func wrap(delimiter, delimiter2, field string) string {
	if field != "*" {
		return delimiter + field + delimiter2
	}
	return field
}
|
|
| func (sql *SQLComponent) getWheres(delimiter, delimiter2 string) string { |
| if len(sql.Wheres) == 0 { |
| if sql.WhereRaws != "" { |
| return " where " + sql.WhereRaws |
| } |
| return "" |
| } |
| wheres := " where " |
| var arr []string |
| for _, where := range sql.Wheres { |
| arr = strings.Split(where.Field, ".") |
| if len(arr) > 1 { |
| wheres += arr[0] + "." + wrap(delimiter, delimiter2, arr[1]) + " " + where.Operation + " " + where.Qmark + " and " |
| } else { |
| wheres += wrap(delimiter, delimiter2, where.Field) + " " + where.Operation + " " + where.Qmark + " and " |
| } |
| } |
|
|
| if sql.WhereRaws != "" { |
| return wheres + sql.WhereRaws |
| } |
| return wheres[:len(wheres)-5] |
| } |
|
|
| func (sql *SQLComponent) prepareUpdate(delimiter, delimiter2 string) { |
| fields := "" |
| args := make([]interface{}, 0) |
|
|
| if len(sql.Values) != 0 { |
|
|
| for key, value := range sql.Values { |
| fields += wrap(delimiter, delimiter2, key) + " = ?, " |
| args = append(args, value) |
| } |
|
|
| if len(sql.UpdateRaws) == 0 { |
| fields = fields[:len(fields)-2] |
| } else { |
| for i := 0; i < len(sql.UpdateRaws); i++ { |
| if i == len(sql.UpdateRaws)-1 { |
| fields += sql.UpdateRaws[i].Expression + " " |
| } else { |
| fields += sql.UpdateRaws[i].Expression + "," |
| } |
| args = append(args, sql.UpdateRaws[i].Args...) |
| } |
| } |
|
|
| sql.Args = append(args, sql.Args...) |
| } else { |
| if len(sql.UpdateRaws) == 0 { |
| panic("prepareUpdate: wrong parameter") |
| } else { |
| for i := 0; i < len(sql.UpdateRaws); i++ { |
| if i == len(sql.UpdateRaws)-1 { |
| fields += sql.UpdateRaws[i].Expression + " " |
| } else { |
| fields += sql.UpdateRaws[i].Expression + "," |
| } |
| args = append(args, sql.UpdateRaws[i].Args...) |
| } |
| } |
| sql.Args = append(args, sql.Args...) |
| } |
|
|
| sql.Statement = "update " + delimiter + sql.TableName + delimiter2 + " set " + fields + sql.getWheres(delimiter, delimiter2) |
| } |
|
|
| func (sql *SQLComponent) prepareInsert(delimiter, delimiter2 string) { |
| fields := " (" |
| quesMark := "(" |
|
|
| for key, value := range sql.Values { |
| fields += wrap(delimiter, delimiter2, key) + "," |
| quesMark += "?," |
| sql.Args = append(sql.Args, value) |
| } |
| fields = fields[:len(fields)-1] + ")" |
| quesMark = quesMark[:len(quesMark)-1] + ")" |
|
|
| sql.Statement = "insert into " + delimiter + sql.TableName + delimiter2 + fields + " values " + quesMark |
| } |
|
|