database/sql 一点深入理解

我们主要来分析一下查询的实现,涉及的是database/sqlgo-sql-driver/mysql的部分源码。database/sql是go对于db抽象出的一个标准库,go-sql-driver/mysql是实现了database/sql驱动接口的mysql驱动。

什么是数据库驱动?

简单来讲,数据库驱动实现了mysql协议,比如连接数据库,驱动会给数据库服务器发送握手初始化报文登陆认证报文,拼出报文头,将username,password放在报文的固定位置,将[]byte数据写入到socket,这就是驱动的主要功能,他给我们(这里指database/sql)完成了底层的操作

查询的接口主要有两个:

1
2
3
4
5
6
7
8
9
10
/*
执行一个查询并返回多个数据行, 这个查询通常是一个 SELECT 。 方法的 arg 部分用于填写查询语句中包含的占位符的实际参数。
*/
func (db *DB) Query(query string, args ...interface{}) (*Rows, error)
/*
执行一个预期最多只会返回一个数据行的查询。
这个方法总是会返回一个非空的值, 而它引起的错误则会被推延到数据行的 Scan 方法被调用为止。
*/
func (db *DB) QueryRow(query string, args ...interface{}) *Row

再来看看返回值的结构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
type Rows struct {
dc *driverConn // owned; must call releaseConn when closed to release
releaseConn func(error)
rowsi driver.Rows
cancel func() // called when Rows is closed, may be nil.
closeStmt *driverStmt // if non-nil, statement to Close on close
closemu sync.RWMutex
closed bool
lasterr error // non-nil only if closed is true
// lastcols is only used in Scan, Next, and NextResultSet which are expected
// not to be called concurrently.
lastcols []driver.Value
}
// Row is the result of calling QueryRow to select a single row.
type Row struct {
// One of these two will be non-nil:
err error // deferred error for easy chaining
rows *Rows
}

可以明显地看到Row将Rows包了一层,加了一个err字段,上面注释中说的而它引起的错误则会被推延到数据行的 Scan 方法被调用为止也就可以理解了,错误在error上而没有直接返回。

个人认为看源码最好带着问题看,有目的往往更能坚持。我们先来看一段最基本的调用代码,看看能不能找出几个我们感兴趣的问题:

1
2
3
4
5
6
7
8
9
rows, err := db.Query("SELECT * FROM User WHERE id=?", id_param)
...
defer rows.Close()
for rows.Next() {
var id int
var name string
err = rows.Scan(&id, &name)
...
}

我的问题来了:

  1. 这里query用的sql格式看上去是prepared statement,我们知道prepared statement是可以防注入的。我们也知道不能够直接拼接字符串,直接拼接定会引入注入问题。那么问题来了,prepared statement会增加与服务器的交互,影响性能,是否能不使用prepared statement,如果不使用,像上面这样的query能否防注入呢?另外上面我们也只是说了看上去是,底层是否真的在使用prepared statement

    要想使用prepared statement在数据库提前编译,复用的特性需要在客户端做很多事,比如JDBC就实现了一整套方案。目前看来database/sql没有这样的实现,所以这里可以忽略编译的优势。

  2. 这个for的写法真是奇怪,简直看不懂到底在遍历什么,每次都是rows.Scan, rows里面到底存了些什么可以这么玩

下面我们从这些问题出发来研究一下实现,我们会忽略掉与问题无关的部分(比如连接池),只关注我们需要的链路。

Prepared Statement & 防注入

Query之后还有几个函数调用才能到下面的函数,包括了取连接和错误处理的逻辑,这些对我们的问题没有影响,暂且忽略它们。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
// 我们的driver实现了这个接口,这里可以走进去
if queryer, ok := dc.ci.(driver.Queryer); ok {
dargs, err := driverArgs(dc.ci, nil, args)
if err != nil {
releaseConn(err)
return nil, err
}
var rowsi driver.Rows
withLock(dc, func() {
rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs) // 尝试不使用`Prepared Statement`来执行
})
// 这个错误非常重要,不想使用`Prepared Statement`,上面就不能返回这个错误
if err != driver.ErrSkip {
if err != nil {
releaseConn(err)
return nil, err
}
// Note: ownership of dc passes to the *Rows, to be freed
// with releaseConn.
rows := &Rows{
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
}
rows.initContextClose(ctx, txctx)
return rows, nil
}
}
// 下面将会使用Prepared Statement,如果不想使用就要保证上面的代码能顺利return
var si driver.Stmt
var err error
withLock(dc, func() {
si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
.....
}

从上面的注释可以看到,Query默认不使用prepared statement,但是要保证正常查询不返回driver.ErrSkip,一旦返回这个错误,则会使用prepared statement继续查询。再往下跟我们将会走到driver中,我们来看看真实的query,在go-sql-driver/mysqlconnection.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// 考虑到注入问题所以一定有args,将会走进if
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip // 我们就是不想返回这个错误,所以要保证mc.cfg.InterpolateParams为true
}
// try client-side prepare to reduce roundtrip
prepared, err := mc.interpolateParams(query, args) // 这里将进行插值,能不能防注入就看这里了
if err != nil {
return nil, err
}
query = prepared
}
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
......
}

InterpolateParams从字面上看意思是是否可以插值,它其实是dsn的一个参数,见dsn.go

1
2
3
4
5
6
7
// Enable client side placeholder substitution
case "interpolateParams":
var isBool bool
cfg.InterpolateParams, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}

这里我们就清楚了,如果不想使用prepared statement,就要在dsn中加入interpolateParams=true,允许驱动进行对sql进行插值。下一个问题,这里的插值能不能防注入?我们来看interpolateParams方法,我们只看参数为string的情况:

1
2
3
4
5
6
7
8
case string:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash(buf, v) // 将会走到这里
} else {
buf = escapeStringQuotes(buf, v)
}
buf = append(buf, '\'')

从函数名我相信大家已经知道了,我们将会对字符串参数进行转义,而转义特殊字符是可以防止注入的,不妨再看看escapeStringBackslash

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
func escapeBytesBackslash(buf, v []byte) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
for _, c := range v {
switch c {
case '\x00':
buf[pos] = '\\'
buf[pos+1] = '0'
pos += 2
case '\n':
buf[pos] = '\\'
buf[pos+1] = 'n'
pos += 2
case '\r':
buf[pos] = '\\'
buf[pos+1] = 'r'
pos += 2
case '\x1a':
buf[pos] = '\\'
buf[pos+1] = 'Z'
pos += 2
case '\'':
buf[pos] = '\\'
buf[pos+1] = '\''
pos += 2
case '"':
buf[pos] = '\\'
buf[pos+1] = '"'
pos += 2
case '\\':
buf[pos] = '\\'
buf[pos+1] = '\\'
pos += 2
default:
buf[pos] = c
pos++
}
}
return buf[:pos]
}

一目了然。上面的分析基本解决了我们的第一个问题:dsn中加入interpolateParams=true可以不使用prepared statement,将sql的参数传入Query方法可以防止注入,防注入是驱动通过转义特殊字符来实现的。

第二个问题我们可以接着第一个问题的代码往下看,上面我们看了拼接好sql,下面就是真正的查询了,代码在statement.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if stmt.mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
// 这个方法背后就是实现了`客户端命令请求报文`,非常值得一看,但是我们这里关心的是返回值,暂且跳过
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, stmt.mc.markBadConn(err)
}
mc := stmt.mc
// Read Result
// 这里虽然说是read result,但是只是根据协议读出了列的数量,用于下面的readColumns
// 注意,即使数据库中没有找到对应的记录,数据库仍然会将字段的信息返回,只是在返回的报文中
// Row Data没有数据
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
rows := new(binaryRows)
if resLen > 0 {
rows.mc = mc
rows.rs.columns, err = mc.readColumns(resLen) // 这里会将字段名(列名)读出然后存起来
} else {
rows.rs.done = true
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
// 竟然就这么返回了
return rows, err
}

看到了吗,query最后返回的时候,只是把字段名读出来了而已,根本没有读到我们需要的数据库中的记录。不妨大胆猜测:for Next将会每次从buffer中读出一条记录,将其赋值给某个变量,Scan就是在解析这个变量。让我们回到database/sql,在sql.go中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
func (rs *Rows) nextLocked() (doClose, ok bool) {
if rs.closed {
return false, false
}
if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) // 哇哈哈,我想就是存在这里了
}
// Lock the driver connection before calling the driver interface
// rowsi to prevent a Tx from rolling back the connection at the same time.
rs.dc.Lock()
defer rs.dc.Unlock()
rs.lasterr = rs.rowsi.Next(rs.lastcols) // 这里应该就是在读数据然后存到rs.lastcols
if rs.lasterr != nil {
......
}
return false, true
}

再回到go-sql-driver/mysql

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// rows.go
func (rows *binaryRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if err := mc.error(); err != nil {
return err
}
// Fetch next row from stream
// 读数据并存到dest中,readRow中实现了针对Row Data 结构的解析
// 复杂但真的值得一看
return rows.readRow(dest)
}
return io.EOF
}

来来回回啊,再看database/sql的sql.go

1
2
3
4
5
6
7
8
9
10
11
func (rs *Rows) Scan(dest ...interface{}) error {
......
// 看到了吧,就是lastcols,遍历lastcols,将字段值放入到dest中
for i, sv := range rs.lastcols {
err := convertAssign(dest[i], sv)
if err != nil {
return fmt.Errorf("sql: Scan error on column index %d: %v", i, err)
}
}
return nil
}

至此我们的第二个问题也解决了,Query返回的时候只是取出了字段信息,真实的数据库记录还留在buffer中,for循环Next,每次从buffer中读取一条记录,存在rows结构体的lastcols字段中,调用Scan的时候就是从lastcols取出值。

上面只是对大致流程的分析,里面还有大量的细节没有涉及到,特别是对于协议的实现,非常值得一看。

本文如有错误,欢迎联系O(∩_∩)O