本文主要研究一下gorm的Model
Model
gorm.io/gorm@v1.20.10/model.go
// Model is a basic Go struct which includes the following fields: ID,
// CreatedAt, UpdatedAt and DeletedAt.
// It may be embedded into your model, or you may build your own model
// without it:
//
//	type User struct {
//		gorm.Model
//	}
//
// Field order and the gorm tags determine the resulting columns, so the
// struct is quoted exactly as declared in gorm v1.20.10.
type Model struct {
	// ID is tagged as the primary key.
	ID uint `gorm:"primarykey"`
	// CreatedAt is picked up by the name-based rule in ParseField (name
	// "CreatedAt" with a Time/Int/Uint data type); time.Time is assumed to map
	// to the Time data type — that mapping is not shown in this excerpt.
	CreatedAt time.Time
	// UpdatedAt is picked up by the same name-based rule for "UpdatedAt".
	UpdatedAt time.Time
	// DeletedAt uses gorm's DeletedAt type (declared elsewhere) and gets an
	// index; presumably it backs soft deletes — not shown here.
	DeletedAt DeletedAt `gorm:"index"`
}
Model定义了ID、CreatedAt、UpdatedAt、DeletedAt属性
ParseField
gorm.io/gorm@v1.20.10/schema/field.go
// ParseField builds the *Field description for one struct field of the schema.
//
// This is an excerpt from gorm v1.20.10; lines marked "// ......" stand for
// code that was omitted, so the excerpt on its own is not complete.
//
// The visible part:
//   - initialises the Field from the struct field (name, types, tag and the
//     `gorm` tag settings split on ";"), marked creatable, updatable and
//     readable, with AutoIncrementIncrement 1;
//   - strips all pointer levels to obtain IndirectFieldType;
//   - records GORMDataType, then lets a GormDataTypeInterface implementation
//     override DataType;
//   - decides whether and with which precision the field gets automatic
//     create/update timestamps.
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
	var err error // used only in the code omitted from this excerpt

	field := &Field{
		Name:                   fieldStruct.Name,
		BindNames:              []string{fieldStruct.Name},
		FieldType:              fieldStruct.Type,
		IndirectFieldType:      fieldStruct.Type,
		StructField:            fieldStruct,
		Creatable:              true,
		Updatable:              true,
		Readable:               true,
		Tag:                    fieldStruct.Tag,
		TagSettings:            ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"),
		Schema:                 schema,
		AutoIncrementIncrement: 1,
	}

	// Dereference every pointer level so IndirectFieldType is the
	// non-pointer element type (e.g. **T -> T).
	for field.IndirectFieldType.Kind() == reflect.Ptr {
		field.IndirectFieldType = field.IndirectFieldType.Elem()
	}

	// fieldValue is a *T holding a new zero value of the indirect type; in the
	// visible code it is only used for interface checks.
	fieldValue := reflect.New(field.IndirectFieldType)
	// If the field implements driver.Valuer, its value (or its first field) is
	// used to determine the data type; that handling is in the omitted code.
	valuer, isValuer := fieldValue.Interface().(driver.Valuer)
	// ...... (omitted in this excerpt)

	// GORMDataType keeps the data type as determined before the
	// GormDataTypeInterface override below.
	field.GORMDataType = field.DataType

	if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
		field.DataType = DataType(dataTyper.GormDataType())
	}

	// Automatic create time is enabled when the AUTOCREATETIME tag setting is
	// present, or when the field is named CreatedAt and its DataType is Time,
	// Int or Uint. The tag value picks the precision (case-insensitive):
	// "NANO" -> UnixNanosecond, "MILLI" -> UnixMillisecond, anything else
	// (including no value) -> UnixSecond.
	// NOTE(review): only the presence of the tag is checked, so a value such as
	// "false" still enables it (with UnixSecond precision).
	if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
		if strings.ToUpper(v) == "NANO" {
			field.AutoCreateTime = UnixNanosecond
		} else if strings.ToUpper(v) == "MILLI" {
			field.AutoCreateTime = UnixMillisecond
		} else {
			field.AutoCreateTime = UnixSecond
		}
	}

	// Same rule for automatic update time: the AUTOUPDATETIME tag setting, or
	// the name UpdatedAt with a Time/Int/Uint DataType.
	if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
		if strings.ToUpper(v) == "NANO" {
			field.AutoUpdateTime = UnixNanosecond
		} else if strings.ToUpper(v) == "MILLI" {
			field.AutoUpdateTime = UnixMillisecond
		} else {
			field.AutoUpdateTime = UnixSecond
		}
	}

	// ...... (omitted in this excerpt)

	return field
}
ParseField方法会解析field的属性:如果tag中标注了AUTOCREATETIME(或AUTOUPDATETIME),或者field的name为CreatedAt(或UpdatedAt)且DataType为Time、Int、Uint之一,则会设置field.AutoCreateTime(或field.AutoUpdateTime);其中tag值为NANO时精度为UnixNanosecond,为MILLI时为UnixMillisecond,其余情况(包括未指定tag值)均为UnixSecond
TimeType
gorm.io/gorm@v1.20.10/schema/field.go
// TimeType selects the precision used when an automatic create/update time is
// written into an integer field. A value of 0 means no automatic time is set.
type TimeType int64

// Supported precisions for Field.AutoCreateTime and Field.AutoUpdateTime,
// numbered from 1 so that the zero value stays "disabled".
const (
	// UnixSecond stores whole seconds since the Unix epoch.
	UnixSecond TimeType = iota + 1
	// UnixMillisecond stores milliseconds since the Unix epoch.
	UnixMillisecond
	// UnixNanosecond stores nanoseconds since the Unix epoch.
	UnixNanosecond
)
field.AutoCreateTime、AutoUpdateTime属性为TimeType类型,该类型有UnixSecond(1)、UnixMillisecond(2)、UnixNanosecond(3)三种取值;零值表示不自动设置时间
ConvertToCreateValues
gorm.io/gorm@v1.20.10/callbacks/create.go
// ConvertToCreateValues convert to create values func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch value := stmt.Dest.(type) { case map[string]interface{}: values = ConvertMapToValuesForCreate(stmt, value) case *map[string]interface{}: values = ConvertMapToValuesForCreate(stmt, *value) case []map[string]interface{}: values = ConvertSliceOfMapToValuesForCreate(stmt, value) case *[]map[string]interface{}: values = ConvertSliceOfMapToValuesForCreate(stmt, *value) default: var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) curTime = stmt.DB.NowFunc() isZero bool ) values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} for _, db := range stmt.Schema.DBNames { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) { values.Columns = append(values.Columns, clause.Column{Name: db}) } } } switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: stmt.SQL.Grow(stmt.ReflectValue.Len() * 18) values.Values = make([][]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} if stmt.ReflectValue.Len() == 0 { stmt.AddError(gorm.ErrEmptySlice) return } for i := 0; i < stmt.ReflectValue.Len(); i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) if !rv.IsValid() { stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) return } values.Values[i] = make([]interface{}, len(values.Columns)) for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { if field.DefaultValueInterface != nil { values.Values[i][idx] = field.DefaultValueInterface field.Set(rv, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { 
field.Set(rv, curTime) values.Values[i][idx], _ = field.ValueOf(rv) } } else if field.AutoUpdateTime > 0 { if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { field.Set(rv, curTime) values.Values[0][idx], _ = field.ValueOf(rv) } } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, isZero := field.ValueOf(rv); !isZero { if len(defaultValueFieldsHavingValue[field]) == 0 { defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len()) } defaultValueFieldsHavingValue[field][i] = v } } } } for field, vs := range defaultValueFieldsHavingValue { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) for idx := range values.Values { if vs[idx] == nil { values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) } else { values.Values[idx] = append(values.Values[idx], vs[idx]) } } } case reflect.Struct: values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface field.Set(stmt.ReflectValue, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { field.Set(stmt.ReflectValue, curTime) values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) } } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Values[0] = append(values.Values[0], v) } } } default: stmt.AddError(gorm.ErrInvalidData) } } if c, ok := stmt.Clauses["ON CONFLICT"]; ok { if onConflict, _ := 
c.Expression.(clause.OnConflict); onConflict.UpdateAll { if stmt.Schema != nil && len(values.Columns) > 1 { columns := make([]string, 0, len(values.Columns)-1) for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { columns = append(columns, column.Name) } } } onConflict := clause.OnConflict{ Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), DoUpdates: clause.AssignmentColumns(columns), } for idx, field := range stmt.Schema.PrimaryFields { onConflict.Columns[idx] = clause.Column{Name: field.DBName} } stmt.AddClause(onConflict) } } } return values }
ConvertToCreateValues从stmt.DB.NowFunc()获取curTime;对于AutoCreateTime或AutoUpdateTime大于0的字段,如果其值为零值且没有设置默认值(DefaultValueInterface),就会把curTime写回该字段,并将其作为插入的值
setupValuerAndSetter
gorm.io/gorm@v1.20.10/schema/field.go
// setupValuerAndSetter creates the valuer and setter for the field when the
// struct is parsed.
//
// This is an excerpt from gorm v1.20.10: the valuer part is omitted (see the
// "// ......" line) and only the construction of field.Set is shown.
//
// The setter is chosen by the kind of field.FieldType. It converts the
// incoming value v to that kind and stores it through
// field.ReflectValueOf(value). Values of unhandled types are passed to
// fallbackSetter, which is defined elsewhere.
//
// TimeType precision is applied only by the integer setters: when an Int*
// field receives a time.Time or non-nil *time.Time, or a Uint* field receives
// a time.Time, it stores UnixNano() for UnixNanosecond, UnixNano()/1e6 for
// UnixMillisecond, and Unix() otherwise. time.Time and *time.Time fields store
// the time value itself, without precision conversion.
func (field *Field) setupValuerAndSetter() {
	// ...... (valuer setup omitted in this excerpt)

	// Set
	switch field.FieldType.Kind() {
	case reflect.Bool:
		// Accepts:
		//   - bool;
		//   - *bool (nil stores false);
		//   - int64 (> 0 stores true);
		//   - string (parsed with strconv.ParseBool; a parse error is ignored
		//     and stores false).
		field.Set = func(value reflect.Value, v interface{}) error {
			switch data := v.(type) {
			case bool:
				field.ReflectValueOf(value).SetBool(data)
			case *bool:
				if data != nil {
					field.ReflectValueOf(value).SetBool(*data)
				} else {
					field.ReflectValueOf(value).SetBool(false)
				}
			case int64:
				if data > 0 {
					field.ReflectValueOf(value).SetBool(true)
				} else {
					field.ReflectValueOf(value).SetBool(false)
				}
			case string:
				b, _ := strconv.ParseBool(data)
				field.ReflectValueOf(value).SetBool(b)
			default:
				return fallbackSetter(value, v, field.Set)
			}
			return nil
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// Conversions:
		//   - numeric inputs use int64(...) (floats truncate toward zero; uint64
		//     values above the int64 range wrap);
		//   - []byte is retried as a string;
		//   - strings use strconv.ParseInt with base 0, so prefixes such as
		//     "0x" are accepted;
		//   - time values become Unix timestamps in the precision selected by
		//     AutoCreateTime/AutoUpdateTime; a nil *time.Time stores 0.
		// The `err` declared in the string case shadows the named result, so the
		// final `return err` returns nil on every path that reaches it.
		field.Set = func(value reflect.Value, v interface{}) (err error) {
			switch data := v.(type) {
			case int64:
				field.ReflectValueOf(value).SetInt(data)
			case int:
				field.ReflectValueOf(value).SetInt(int64(data))
			case int8:
				field.ReflectValueOf(value).SetInt(int64(data))
			case int16:
				field.ReflectValueOf(value).SetInt(int64(data))
			case int32:
				field.ReflectValueOf(value).SetInt(int64(data))
			case uint:
				field.ReflectValueOf(value).SetInt(int64(data))
			case uint8:
				field.ReflectValueOf(value).SetInt(int64(data))
			case uint16:
				field.ReflectValueOf(value).SetInt(int64(data))
			case uint32:
				field.ReflectValueOf(value).SetInt(int64(data))
			case uint64:
				field.ReflectValueOf(value).SetInt(int64(data))
			case float32:
				field.ReflectValueOf(value).SetInt(int64(data))
			case float64:
				field.ReflectValueOf(value).SetInt(int64(data))
			case []byte:
				return field.Set(value, string(data))
			case string:
				if i, err := strconv.ParseInt(data, 0, 64); err == nil {
					field.ReflectValueOf(value).SetInt(i)
				} else {
					return err
				}
			case time.Time:
				if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
					field.ReflectValueOf(value).SetInt(data.UnixNano())
				} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
					field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
				} else {
					field.ReflectValueOf(value).SetInt(data.Unix())
				}
			case *time.Time:
				if data != nil {
					if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
						field.ReflectValueOf(value).SetInt(data.UnixNano())
					} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
						field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
					} else {
						field.ReflectValueOf(value).SetInt(data.Unix())
					}
				} else {
					field.ReflectValueOf(value).SetInt(0)
				}
			default:
				return fallbackSetter(value, v, field.Set)
			}
			return err
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		// Same conversions as the signed case, but using uint64(...) (negative
		// inputs wrap) and strconv.ParseUint for strings.
		// NOTE(review): unlike the signed case there is no *time.Time branch, so
		// a *time.Time value is handed to fallbackSetter.
		field.Set = func(value reflect.Value, v interface{}) (err error) {
			switch data := v.(type) {
			case uint64:
				field.ReflectValueOf(value).SetUint(data)
			case uint:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case uint8:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case uint16:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case uint32:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case int64:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case int:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case int8:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case int16:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case int32:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case float32:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case float64:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case []byte:
				return field.Set(value, string(data))
			case time.Time:
				if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
					field.ReflectValueOf(value).SetUint(uint64(data.UnixNano()))
				} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
					field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6))
				} else {
					field.ReflectValueOf(value).SetUint(uint64(data.Unix()))
				}
			case string:
				if i, err := strconv.ParseUint(data, 0, 64); err == nil {
					field.ReflectValueOf(value).SetUint(i)
				} else {
					return err
				}
			default:
				return fallbackSetter(value, v, field.Set)
			}
			return err
		}
	case reflect.Float32, reflect.Float64:
		// Numeric inputs use float64(...); []byte is retried as a string;
		// strings are parsed with strconv.ParseFloat(data, 64).
		field.Set = func(value reflect.Value, v interface{}) (err error) {
			switch data := v.(type) {
			case float64:
				field.ReflectValueOf(value).SetFloat(data)
			case float32:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case int64:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case int:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case int8:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case int16:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case int32:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case uint:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case uint8:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case uint16:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case uint32:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case uint64:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case []byte:
				return field.Set(value, string(data))
			case string:
				if i, err := strconv.ParseFloat(data, 64); err == nil {
					field.ReflectValueOf(value).SetFloat(i)
				} else {
					return err
				}
			default:
				return fallbackSetter(value, v, field.Set)
			}
			return err
		}
	case reflect.String:
		// Integers are formatted with utils.ToString. Floats are formatted with
		// field.Precision digits after the decimal point ("%.<Precision>f").
		field.Set = func(value reflect.Value, v interface{}) (err error) {
			switch data := v.(type) {
			case string:
				field.ReflectValueOf(value).SetString(data)
			case []byte:
				field.ReflectValueOf(value).SetString(string(data))
			case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
				field.ReflectValueOf(value).SetString(utils.ToString(data))
			case float64, float32:
				field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
			default:
				return fallbackSetter(value, v, field.Set)
			}
			return err
		}
	default:
		// Other kinds: inspect the zero value of FieldType to pick a setter.
		fieldValue := reflect.New(field.FieldType)
		switch fieldValue.Elem().Interface().(type) {
		case time.Time:
			// Accepts:
			//   - time.Time;
			//   - *time.Time (nil stores the zero time);
			//   - strings parsed with now.Parse.
			// The time is stored as-is (no TimeType precision conversion).
			field.Set = func(value reflect.Value, v interface{}) error {
				switch data := v.(type) {
				case time.Time:
					field.ReflectValueOf(value).Set(reflect.ValueOf(v))
				case *time.Time:
					if data != nil {
						field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem())
					} else {
						field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{}))
					}
				case string:
					if t, err := now.Parse(data); err == nil {
						field.ReflectValueOf(value).Set(reflect.ValueOf(t))
					} else {
						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
					}
				default:
					return fallbackSetter(value, v, field.Set)
				}
				return nil
			}
		case *time.Time:
			// Accepts:
			//   - time.Time, stored through a pointer allocated when the field
			//     is nil;
			//   - *time.Time (including nil), assigned directly;
			//   - strings parsed with now.Parse.
			field.Set = func(value reflect.Value, v interface{}) error {
				switch data := v.(type) {
				case time.Time:
					fieldValue := field.ReflectValueOf(value)
					if fieldValue.IsNil() {
						fieldValue.Set(reflect.New(field.FieldType.Elem()))
					}
					fieldValue.Elem().Set(reflect.ValueOf(v))
				case *time.Time:
					field.ReflectValueOf(value).Set(reflect.ValueOf(v))
				case string:
					if t, err := now.Parse(data); err == nil {
						fieldValue := field.ReflectValueOf(value)
						if fieldValue.IsNil() {
							// NOTE(review): this empty-string check only runs after
							// now.Parse succeeded; whether now.Parse("") succeeds
							// is not visible here.
							if v == "" {
								return nil
							}
							fieldValue.Set(reflect.New(field.FieldType.Elem()))
						}
						fieldValue.Elem().Set(reflect.ValueOf(t))
					} else {
						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
					}
				default:
					return fallbackSetter(value, v, field.Set)
				}
				return nil
			}
		default:
			if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner: FieldType itself implements sql.Scanner
				// Order of checks on v:
				//   - nil/invalid stores the zero value;
				//   - assignable values are stored directly;
				//   - pointers are dereferenced and retried;
				//   - otherwise v is unwrapped through driver.Valuer (its error is
				//     ignored) and passed to Scan on the (allocated-if-nil) field.
				field.Set = func(value reflect.Value, v interface{}) (err error) {
					reflectV := reflect.ValueOf(v)
					if !reflectV.IsValid() {
						field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
					} else if reflectV.Type().AssignableTo(field.FieldType) {
						field.ReflectValueOf(value).Set(reflectV)
					} else if reflectV.Kind() == reflect.Ptr {
						// NOTE(review): !reflectV.IsValid() is always false here;
						// validity was checked by the first branch.
						if reflectV.IsNil() || !reflectV.IsValid() {
							field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
						} else {
							return field.Set(value, reflectV.Elem().Interface())
						}
					} else {
						fieldValue := field.ReflectValueOf(value)
						if fieldValue.IsNil() {
							fieldValue.Set(reflect.New(field.FieldType.Elem()))
						}

						if valuer, ok := v.(driver.Valuer); ok {
							v, _ = valuer.Value()
						}

						err = fieldValue.Interface().(sql.Scanner).Scan(v)
					}
					return
				}
			} else if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner: *FieldType implements sql.Scanner
				// Same order of checks as the pointer scanner, but Scan is called
				// on the address of the field value.
				field.Set = func(value reflect.Value, v interface{}) (err error) {
					reflectV := reflect.ValueOf(v)
					if !reflectV.IsValid() {
						field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
					} else if reflectV.Type().AssignableTo(field.FieldType) {
						field.ReflectValueOf(value).Set(reflectV)
					} else if reflectV.Kind() == reflect.Ptr {
						if reflectV.IsNil() || !reflectV.IsValid() {
							field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
						} else {
							return field.Set(value, reflectV.Elem().Interface())
						}
					} else {
						if valuer, ok := v.(driver.Valuer); ok {
							v, _ = valuer.Value()
						}

						err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
					}
					return
				}
			} else {
				field.Set = func(value reflect.Value, v interface{}) (err error) {
					return fallbackSetter(value, v, field.Set)
				}
			}
		}
	}
}
setupValuerAndSetter方法中,整数类型(Int、Uint)字段的setter在接收到time.Time值(Int类型还包括*time.Time)时,会根据TimeType将其转换为秒、毫秒或纳秒级的Unix时间戳;而time.Time或*time.Time类型字段的setter则直接保存时间值,不做精度转换
实例
// Product is the example model. Embedding gorm.Model gives it the ID,
// CreatedAt, UpdatedAt and DeletedAt fields in addition to its own columns.
type Product struct {
	gorm.Model
	Code  string
	Price uint
}
Product内嵌了gorm.Model,内置了ID、CreatedAt、UpdatedAt、DeletedAt属性,同时Create的时候会自动设置CreatedAt、UpdatedAt,Update的时候会自动更新UpdatedAt
小结
gorm.Model定义了ID、CreatedAt、UpdatedAt、DeletedAt属性;其中Create的时候会自动设置CreatedAt、UpdatedAt,Update的时候会自动更新UpdatedAt。当CreatedAt、UpdatedAt为整数类型(Int、Uint)时,可以通过tag选择UnixSecond、UnixMillisecond、UnixNanosecond三种时间戳精度;time.Time类型则直接保存完整的时间值。