由于业务需求,需要将数据从mysql上迁移到mongo上,我从网络上查了一下资料,网上的方案大多是使用工具导出,但我觉得很麻烦,于是打算自己用golang写个小程序来完成这个功能

迁移思路

1、一般方式

一般的方式是先根据数据表的定义,在golang中定义一个结构体,打上对应的tag,然后使用sqlx来将读取到的数据反序列化到结构体中来,然后再序列化成mongo的document插入数据库。过程如下:

mysql -> 读取 -> []byte(这个是反序列化之前的数据) -> 结构体 -> 插入数据库

但其实这样做有一个很麻烦的地方,我需要先根据每一个表的结构来定义一个结构体,每多一个表,我就需要手动来添加一个结构体,程序的灵活性非常堪忧。

2、另辟蹊径

其实在sql包查询的时候,其实是可以获得表里面每一列的类型信息和列名,我们可以根据这些类型信息,将上面过程中的[]byte 转化为相应的类型。不过中间比较麻烦的地方在于,这些类型并不仅仅是golang的基本类型,比如int、string等,而且也包括sql包包装过的类型,如sql.NullInt64等(参见 database/sql/sql.go),所以中间会经历两次类型的转化。过程如下:

mysql -> 读取 -> []byte -> sql.NullInt64等类型 ->int,string 等基本类型 -> 列名和值组成的map -> 插入数据库

3、会遇到的问题

如何将[]byte转化为 sql.NullInt64类型?rows.Scan(dest …interface{})方法的作用是将[]byte 给转化为dest 对应的类型,它会先查看dest的类型,如果是一般的类型都可以很简单的转化,但如果是比较复杂的类型该怎么办呢? 此时它会看dest是否实现了Scanner 接口(接口定义:Scan(src interface{}) error,注意区别这个Scan跟rows.Scan是不一样的)。说白了它的逻辑就是rows.Scan不知道如何将一个[]byte 给转化为一个复杂的结构体,但是这个结构体实现了Scanner接口,说明他自己知道如何把[]byte转化为他自己,那rows.Scan就把转化的任务交给Scanner吧!而刚好 sql.NullInt64就是一个Scanner,所以调用我们只需要给 rows.Scan一个sql.NullInt64对象的interface{}就好啦

如何获得一个sql.NullInt64的对象的interface{} ?读者可能会想,可以直接在代码里面写一个空的sql.NullInt64的对象然后再转成interface{}就好了。但是,由于这个小程序并不是针对某一个确定的数据表的,所以事先我们并不能知道某一个字段具体信息,那么就只能在运行的时候根据sql查询获得的类型信息来生成一个sql.NullInt64的对象。 上文中说过,我们可以获得每一列的类型信息,其实这个类型信息就是反射的Type对象,根据这个Type对象我们就可以生成一个sql.NullInt64的对象:

reflect.New(coltypes[i].ScanType()).Interface()

最后一个问题:如何将sql.NullInt64 插入到mongo数据库中?读者可能由会想,上一个问题中我们不是已经获得了sql.NullInt64的对象的interface{} 了吗,我们可以将这个interface{} 再转回sql.NullInt64的对象,然后再提取出这个对象里面包含的那个int64不就好了吗?

如果读者愿意,当然可以去挨个判断你获得的interface{}对象是sql.NullInt64还是sql.NullFloat64,但其实有更简单的方式:sql.NullInt64实现了Valuer接口(Value() (Value, error)) 直接调用这个接口就可以返回获得他包含的值。

代码

下面是我写的小程序源码,只有200行代码,中间的一些敏感信息已经去掉了,读者可以根据自己需求添加上去。package main

import (
"database/sql"
"database/sql/driver"
"fmt"
_ "github.com/go-sql-driver/mysql"
"gopkg.in/mgo.v2"
"reflect"
"sort"
"strings"
"sync"
"time"
)
var wg sync.WaitGroup
const DATABASE = ""
const NUMQUERYDATA = 30
type toPrettyFunc func(interface{}) interface{}
type statProc struct {
Name string
Finished int
Total int
}
var chStat chan statProc
func main() {
chStat = make(chan statProc, 256)
db, err := sql.Open("mysql", "")
if err != nil {
panic(err)
}
sess, err := mgo.Dial("")
if err != nil {
panic(err)
}
wg.Add(1)
go migrate(db, sess, DATABASE, "xxx")
go statistic()
wg.Wait()
close(chStat)
}
func migrate(db *sql.DB, sess *mgo.Session, dbname, table string) {
defer func() {
fmt.Printf("migrate table %s done.n", table)
chStat 
Name: table,
Total: -1,
}
if x := recover(); x != nil {
fmt.Println(x)
}
wg.Done()
}()
mdb := sess.DB(dbname)
//条目总数
var count int
rw := db.QueryRow(fmt.Sprintf(`SELECT COUNT(*) FROM %s`, table))
err := rw.Scan(&count)
if err != nil {
panic(err)
}
//读取一条数据,并获得表结构
tmp, err := db.Query(fmt.Sprintf(`SELECT * FROM %s LIMIT 1`, table))
if err != nil {
panic(err)
}
//列名
colnames, err := tmp.Columns()
if err != nil {
panic(err)
}
//列类型
coltypes, err := tmp.ColumnTypes()
if err != nil {
panic(err)
}
results := make([]interface{}, 0, len(colnames)) //储存查询结果,interface的类型是根据coltypes来定义的,一般为sql.NullInt64之类的
resultsPretty := make([]interface{}, len(colnames)) //将查询的结果转化为一般类型,如int,string 等
resultsPrettyFunc := make([]toPrettyFunc, 0, len(colnames)) //转化函数
// 根据mysql类型来初始化results的类型
for i := range coltypes {
newObj := reflect.New(coltypes[i].ScanType()).Interface()
results = append(results, newObj)
switch newObj.(type) {
case driver.Valuer:
resultsPrettyFunc = append(resultsPrettyFunc, func(i interface{}) interface{} {
valuer, _ := i.(driver.Valuer)
v, _ := valuer.Value()
return v
})
case *sql.RawBytes:
resultsPrettyFunc = append(resultsPrettyFunc, func(i interface{}) interface{} {
data, _ := i.(*sql.RawBytes)
return string(*data)
})
default:
resultsPrettyFunc = append(resultsPrettyFunc, func(i interface{}) interface{} {
return reflect.ValueOf(i).Elem().Interface()
})
}
}
i := 0
for ; i < count; i = i + NUMQUERYDATA {
//批量查询数据库
rws, err := db.Query(fmt.Sprintf(`SELECT * FROM %s LIMIT ? OFFSET ?`, table), NUMQUERYDATA, i)
if err != nil {
panic(err)
}
for rws.Next() {
err := rws.Scan(results...)
if err != nil {
panic(err)
}
//转化为一般类型
for j := range results {
resultsPretty[j] = resultsPrettyFunc[j](results[j])
}
//生成mongo的文档
mdata := make(map[string]interface{})
for j := range resultsPretty {
mdata[colnames[j]] = resultsPretty[j]
}
err = mdb.C(table).Insert(mdata)
if err != nil {
fmt.Println("insert mongo error: ", err)
}
}
chStat 
Name: table,
Finished: i,
Total: count,
}
}
}
//统计函数
func statistic() {
procs := make(map[string]statProc)
t := time.NewTicker(time.Millisecond * 300)
for {
select {
case sp, ok := 
if !ok {
return
}
if sp.Total == -1 {
delete(procs, sp.Name)
continue
}
procs[sp.Name] = sp
case 
printStat(procs)
}
}
}
//显示统计结果
func printStat(procs map[string]statProc) {
var procslice []statProc
fmt.Print("r")
for _, proc := range procs {
procslice = append(procslice, proc)
}
sort.Slice(procslice, func(i, j int) bool {
return strings.Compare(procslice[i].Name, procslice[j].Name) < 0
})
for i := range procslice {
fmt.Printf("%10s %2.2f%%(%d)",
procslice[i].Name,
float32(procslice[i].Finished)/ float32(procslice[i].Total) * 100,
procslice[i].Total)
}
}