MySQL Communication Protocol (3): Connection Phase

痴心易碎, posted 2020-04-06 06:08:54


MySQL Connection Lifecycle

graph TD
A[Start] --> |connect|B(ConnectionState)
B --> |authentication succeeded|C(CommandState)
C --> |replication command|D(ReplicationMode)
B --> |replication command|D
B --> |error or disconnect|End
C --> |close connection|End
D --> |close connection|End[End]

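A MySQL connection is stateful. Once the connection is established, it first enters the connection phase, where the client and server exchange information and the server checks the account and password. After authentication succeeds, the connection enters the command phase, where the client sends commands and receives responses. A replication command received in either the connection phase or the command phase switches the connection into replication mode.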
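The lifecycle diagram can be read as a small state machine. The sketch below encodes exactly the edges drawn above; the class, enum, and method names are hypothetical and are not part of MySQL or Connector/J.

```java
/**
 * Hypothetical model of the connection lifecycle diagram; not a real driver API.
 */
public class ConnectionLifecycle {

    /** The nodes of the diagram. */
    enum State { CONNECTION, COMMAND, REPLICATION, END }

    /** The edge labels of the diagram. */
    enum Event { AUTH_OK, REPLICATION_COMMAND, ERROR_OR_DISCONNECT, CLOSE }

    /** Returns the next state, or throws if the diagram has no edge for this event. */
    static State next(State current, Event event) {
        switch (current) {
            case CONNECTION:
                if (event == Event.AUTH_OK) return State.COMMAND;
                if (event == Event.REPLICATION_COMMAND) return State.REPLICATION;
                if (event == Event.ERROR_OR_DISCONNECT) return State.END;
                break;
            case COMMAND:
                if (event == Event.REPLICATION_COMMAND) return State.REPLICATION;
                if (event == Event.CLOSE) return State.END;
                break;
            case REPLICATION:
                if (event == Event.CLOSE) return State.END;
                break;
            default: // END has no outgoing edges
                break;
        }
        throw new IllegalStateException(event + " is not allowed in state " + current);
    }

    public static void main(String[] args) {
        State s = State.CONNECTION;
        s = next(s, Event.AUTH_OK); // CONNECTION -> COMMAND
        s = next(s, Event.CLOSE);   // COMMAND -> END
        System.out.println(s);      // END
    }
}
```

Throwing on a missing edge makes protocol misuse (for example, sending a command before authentication finishes) fail loudly instead of being silently ignored.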

Connection Phase

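The connection phase does three things:

  • exchanges the capabilities supported by the client and the server

  • sets up an SSL channel, if required

  • lets the server authenticate the client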
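Every packet exchanged in this phase uses the standard MySQL framing: a 4-byte header made of a 3-byte little-endian payload length and a 1-byte sequence id, followed by the payload. A minimal decoding sketch; the class and method names are hypothetical:

```java
/** Illustrative decoder for the 4-byte MySQL packet header (names are hypothetical). */
public class PacketHeader {

    /** Payload length: 3 bytes, little-endian; each byte is masked to avoid sign extension. */
    static int payloadLength(byte[] header) {
        return (header[0] & 0xff) | ((header[1] & 0xff) << 8) | ((header[2] & 0xff) << 16);
    }

    /** Sequence id: the 4th byte, read as unsigned. */
    static int sequenceId(byte[] header) {
        return header[3] & 0xff;
    }

    public static void main(String[] args) {
        // header of a 74-byte initial handshake with sequence id 0, as in the sample run later on
        byte[] header = {0x4a, 0x00, 0x00, 0x00};
        System.out.println(payloadLength(header) + " " + sequenceId(header)); // 74 0
    }
}
```

The `& 0xff` masks matter because Java's `byte` is signed: without them, a length byte such as 0x80 would be sign-extended into a negative `int`.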

Plain Handshake

1. Initial Handshake Packet

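Once the connection is established, the server first sends the initial handshake packet. Taking the current version, HandshakeV10, as the example: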

Type Name Description
int<1> protocol version protocol version: 10
string[NUL] server version human-readable server version
int<4> thread id connection id
string[8] auth-plugin-data-part-1 first part of the auth plugin data
int<1> filler filler, always 0x00
int<2> capability_flags_1 lower 2 bytes of the capability flags
int<1> character_set server default character set
int<2> status_flags server status
int<2> capability_flags_2 upper 2 bytes of the capability flags
if capabilities & CLIENT_PLUGIN_AUTH {
int<1> auth_plugin_data_len length of the auth plugin data
} else {
int<1> 00 constant 0x00
}
string[10] reserved reserved, all 0x00
string[$len] auth-plugin-data-part-2 second part of the auth plugin data, $len=MAX(13, length of auth-plugin-data - 8)
if capabilities & CLIENT_PLUGIN_AUTH {
string[NUL] auth_plugin_name auth plugin name
}

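The fields mean the following:

  • protocol_version: protocol version; currently 0x0a, i.e. 10

  • server_version: server version, e.g. 8.0.19

  • connection_id: connection id

  • auth_plugin_data_part_1: first part of the auth plugin data (the scramble, i.e. the random salt)

  • filler_1: filler, 0x00

  • capability_flag_1: lower 2 bytes of the server capability flags. Each bit represents one capability, so 2 bytes hold 16 capabilities. See Protocol::CapabilityFlags

  • character_set: the server's default character set, given as its character set (collation) id. See Protocol::CharacterSet

  • status_flags: server status. See Protocol::StatusFlags

  • capability_flags_2: upper 2 bytes of the server capability flags. See Protocol::CapabilityFlags

  • auth_plugin_data_len: combined length of both parts of the auth plugin data.

  • auth_plugin_name: name of the authentication plugin.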
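Since capability_flags_2 holds the upper 16 bits, the full 32-bit capability value is `flags1 | (flags2 << 16)`. The sketch below combines the halves, tests individual flags, and computes the auth-plugin-data-part-2 length from the table. The flag values come from the protocol's capability flag list; the class and method names are hypothetical, and the inputs (65535 and 0xc7ff) are the values from the sample run later in this article.

```java
/** Hypothetical helper for the capability fields of HandshakeV10. */
public class Capabilities {
    // flag values from the protocol's capability flag list
    static final int CLIENT_SECURE_CONNECTION = 0x00008000;
    static final int CLIENT_PLUGIN_AUTH = 0x00080000;

    /** capability_flags_1 supplies bits 0-15, capability_flags_2 supplies bits 16-31. */
    static int combine(int flags1, int flags2) {
        return (flags1 & 0xffff) | ((flags2 & 0xffff) << 16);
    }

    static boolean has(int capabilities, int flag) {
        return (capabilities & flag) != 0;
    }

    /** Length of auth-plugin-data-part-2: MAX(13, auth_plugin_data_len - 8). */
    static int part2Length(int authPluginDataLen) {
        return Math.max(13, authPluginDataLen - 8);
    }

    public static void main(String[] args) {
        int caps = combine(65535, 0xc7ff);
        System.out.println(Integer.toHexString(caps));           // c7ffffff
        System.out.println(has(caps, CLIENT_PLUGIN_AUTH));       // true
        System.out.println(has(caps, CLIENT_SECURE_CONNECTION)); // true
        System.out.println(part2Length(21));                     // 13
    }
}
```

Combining with `|` rather than `+` makes it explicit that the two halves occupy disjoint bits.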

With the packet format and field meanings in hand, we can write the code.

import java.io.DataInputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;

public class HandshakeV10Console {
    public static void main(String[] args) throws IOException {
        Socket socket = new Socket();
        socket.connect(new InetSocketAddress("127.0.0.1", 3306));
        // readFully blocks until the array is filled (or throws EOFException);
        // a bare InputStream.read may return fewer bytes than requested
        DataInputStream in = new DataInputStream(socket.getInputStream());

        // packet header: 3-byte little-endian payload length + 1-byte sequence id
        byte[] head = new byte[4];
        in.readFully(head);

        // mask each byte first; without parentheses `x & 0xff << 8` parses as `x & (0xff << 8)`
        final int length = (head[0] & 0xff) | ((head[1] & 0xff) << 8) | ((head[2] & 0xff) << 16);
        System.out.println("length:" + length);
        final int seq = head[3] & 0xff;
        System.out.println("seq:" + seq);

        byte[] body = new byte[length];
        in.readFully(body);

        final int protocolVersion = body[0];
        System.out.println("protocolVersion:" + protocolVersion);

        int position = 1, p = 0;
        for (; ; ) {
            if (body[position + p] == 0) {
                break;
            }
            p++;
        }
        byte[] ssa = new byte[p];
        System.arraycopy(body, position, ssa, 0, ssa.length);
        final String serverVersion = new String(ssa);
        System.out.println("serverVersion:" + serverVersion);

        position = position + p + 1;

        byte[] cida = new byte[4];
        System.arraycopy(body, position, cida, 0, cida.length);
        final int connectionId = (cida[0] & 0xff) + ((cida[1] & 0xff) << 8) + ((cida[2] & 0xff) << 16) + ((cida[3] & 0xff) << 24);
        System.out.println("connectionId:" + connectionId);

        position += 4;

        // auth-plugin-data-part-1
        byte[] apdpa1 = new byte[8];
        System.arraycopy(body, position, apdpa1, 0, apdpa1.length);
        final String authPluginDataPart1 = new String(apdpa1);
        System.out.println("authPluginDataPart1:" + authPluginDataPart1);

        position += 9; //filler(1) == 0x00


        //capability_flag_1 (2)
        byte[] cfa = new byte[2];
        System.arraycopy(body, position, cfa, 0, cfa.length);
        final int capabilityFlag = (cfa[0] & 0xff) + ((cfa[1] & 0xff) << 8);
        System.out.println("capabilityFlag:" + capabilityFlag);//65535 = ffff

        position += 2;

        //character_set (1)
        final int characterSet = (body[position] & 0xff);
        System.out.println("characterSet:" + characterSet);//33 = utf8_general_ci

        position += 1;

        //status_flags (2)
        final int statusFlags = (body[position] & 0xff) + ((body[position + 1] & 0xff) << 8);
        System.out.println("statusFlags:" + statusFlags);//2 = auto-commit is enabled

        position += 2;

        //capability_flag_2 (2)
        byte[] cfa2 = new byte[2];
        System.arraycopy(body, position, cfa2, 0, cfa2.length);
        final int capabilityFlag2 = ((cfa2[0] & 0xff) << 16) + ((cfa2[1] & 0xff) << 24);
        System.out.println("capabilityFlag2:" + capabilityFlag2);// upper 16 bits already shifted into place; negative if bit 31 is set

        position += 2;

        //auth_plugin_data_len (1)
        final int authPluginDataLen = (body[position] & 0xff);
        System.out.println("authPluginDataLen:" + authPluginDataLen);// e.g. 21 = 8 (part 1) + 13 (part 2)

        position += 1;

        position += 10;//reserved (all [00])

        int capabilities = capabilityFlag | capabilityFlag2;
        if ((capabilities & 0x00008000) != 0) { // CLIENT_SECURE_CONNECTION: part 2 is present
            // auth-plugin-data-part-2
            int len = Math.max(13, authPluginDataLen - 8);
            System.out.println("auth-plugin-data-part-2 length:" + len);
            byte[] apdpa2 = new byte[len];
            System.arraycopy(body, position, apdpa2, 0, apdpa2.length);
            String authPluginDataPart2 = new String(apdpa2);
            System.out.println("authPluginDataPart2:" + authPluginDataPart2);
            position += len;
        }


        if ((capabilities & 0x00080000) != 0) { // CLIENT_PLUGIN_AUTH
            //auth-plugin name
            int p2 = 0;
            for (; ; ) {
                if (body[position + p2] == 0) {
                    break;
                }
                p2++;
            }
            byte[] apna = new byte[p2];
            System.arraycopy(body, position, apna, 0, apna.length);
            final String authPluginName = new String(apna);
            System.out.println("authPluginName:" + authPluginName);
            position = position + p2 + 1;
        }
        System.out.println(position); // total bytes parsed; should equal length

        socket.close();
    }
}

Output:

length:74
seq:0
protocolVersion:10
serverVersion:8.0.19
connectionId:9
authPluginDataPart1:%N1omd4
capabilityFlag:65535
characterSet:33
statusFlags:2
capabilityFlag2:-939589632
authPluginDataLen:21
auth-plugin-data-part-2 length:13
authPluginDataPart2:X>Yt%#b#y9 
authPluginName:mysql_native_password
74
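In this run auth_plugin_data_len is 21: 8 bytes in part 1 plus 13 in part 2, where the last byte of part 2 is a 0x00 terminator. mysql_native_password uses the remaining 20 bytes as its seed. A sketch of assembling it, assuming both parts are kept as raw bytes; the helper name is hypothetical:

```java
import java.util.Arrays;

/** Hypothetical helper: joins the two auth-plugin-data parts into the 20-byte seed. */
public class ScrambleSeed {

    static byte[] seed(byte[] part1, byte[] part2) {
        byte[] all = new byte[part1.length + part2.length];
        System.arraycopy(part1, 0, all, 0, part1.length);
        System.arraycopy(part2, 0, all, part1.length, part2.length);
        int end = all.length;
        // part 2 normally ends with a 0x00 terminator that is not part of the scramble
        if (end > 0 && all[end - 1] == 0) {
            end--;
        }
        return Arrays.copyOf(all, end);
    }

    public static void main(String[] args) {
        byte[] part1 = new byte[8];
        byte[] part2 = new byte[13];
        Arrays.fill(part1, (byte) 'a');
        Arrays.fill(part2, 0, 12, (byte) 'b'); // 12 data bytes, then the terminating 0x00
        System.out.println(seed(part1, part2).length); // 20
    }
}
```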

2. Handshake Response

After receiving the initial handshake, the client must reply with a response. Taking HandshakeResponse41, used by 4.1+ clients, as the example:

Type Name Description
int<4> client_flag capability flags; must include CLIENT_PROTOCOL_41
int<4> max_packet_size maximum packet size
int<1> character_set client character set
string[23] filler filler, 23 bytes of 0x00
string<NUL> username user name
if capabilities & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA {
string<lenenc> auth_response authentication data generated by the auth plugin
} else {
int<1> auth_response_length length of the authentication data
string<length> auth_response authentication data generated by the auth plugin
}
if capabilities & CLIENT_CONNECT_WITH_DB {
string<NUL> database default database (schema) for the connection
}
if capabilities & CLIENT_PLUGIN_AUTH {
string<NUL> client_plugin_name name of the plugin the client used to generate the auth data, UTF-8 encoded
}
if capabilities & CLIENT_CONNECT_ATTRS {
int<lenenc> length of all key-values total length of all attributes
string<lenenc> key1 name of the first attribute
string<lenenc> value1 value of the first attribute
.. (further attributes follow as key-value pairs)
}
int<1> zstd_compression_level compression level for the zstd compression algorithm

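The fields mean the following:

  • capability_flags: capability flags supported by the client. See Protocol::CapabilityFlags

  • max_packet_size: maximum size of a command packet the client will send to the server.

  • character_set: the connection's default character set. See Protocol::CharacterSet.

  • username: the account used to log in, encoded in the connection character set (given by the character_set field).

  • auth-response: the authentication data scrambled by the authentication plugin.

  • database: the connection's default database (schema), encoded in the connection character set (given by the character_set field).

  • auth plugin name: name of the plugin the client actually used to generate auth-response; this field must be UTF-8 encoded.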
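client_flag tells the server which capabilities the client wants, and it should be a subset of what the server advertised. The client code in this article sends the fixed value 20881935 (0x013EA20F). Decomposed into named flags (values from the protocol's capability flag list; the class and the negotiate helper are illustrative, not a driver API):

```java
/** Illustrative breakdown of the client_flag value 20881935 used in this article. */
public class ClientFlags {
    static final int CLIENT_LONG_PASSWORD = 0x00000001;
    static final int CLIENT_FOUND_ROWS = 0x00000002;
    static final int CLIENT_LONG_FLAG = 0x00000004;
    static final int CLIENT_CONNECT_WITH_DB = 0x00000008;
    static final int CLIENT_PROTOCOL_41 = 0x00000200;
    static final int CLIENT_TRANSACTIONS = 0x00002000;
    static final int CLIENT_SECURE_CONNECTION = 0x00008000;
    static final int CLIENT_MULTI_RESULTS = 0x00020000;
    static final int CLIENT_PS_MULTI_RESULTS = 0x00040000;
    static final int CLIENT_PLUGIN_AUTH = 0x00080000;
    static final int CLIENT_CONNECT_ATTRS = 0x00100000;
    static final int CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000;
    static final int CLIENT_DEPRECATE_EOF = 0x01000000;

    static final int REQUESTED = CLIENT_LONG_PASSWORD | CLIENT_FOUND_ROWS | CLIENT_LONG_FLAG
            | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 | CLIENT_TRANSACTIONS
            | CLIENT_SECURE_CONNECTION | CLIENT_MULTI_RESULTS | CLIENT_PS_MULTI_RESULTS
            | CLIENT_PLUGIN_AUTH | CLIENT_CONNECT_ATTRS | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
            | CLIENT_DEPRECATE_EOF;

    /** Keeps only the requested flags that the server also advertised. */
    static int negotiate(int serverCapabilities) {
        return REQUESTED & serverCapabilities;
    }

    public static void main(String[] args) {
        System.out.println(REQUESTED);                          // 20881935
        // the sample server (capabilities 0xc7ffffff) offers every requested flag
        System.out.println(negotiate(0xc7ffffff) == REQUESTED); // true
    }
}
```

The optional sections of HandshakeResponse41 (auth_response encoding, database, plugin name, attributes) should then be driven by the negotiated value, so that the client never writes a field the server does not expect.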

To implement this, first wrap the initial-handshake parsing code in a class:

import com.mysql.cj.protocol.a.NativeServerSession;

import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;

public class HandshakeV10Parser {


    public HandshakeV10Parser() {
    }

    public InitialHandshakePayload parse(SocketChannel socket) throws IOException {
        InitialHandshakePayload packet = new InitialHandshakePayload();

        // header: 3-byte payload length + 1-byte sequence id
        ByteBuffer header = ByteBuffer.allocate(4);
        readFully(socket, header);
        header.flip();
        int bodyLength = readFixInt(header, 3);

        // payload: read exactly bodyLength bytes, then switch the buffer to read mode
        ByteBuffer buffer = ByteBuffer.allocate(bodyLength);
        readFully(socket, buffer);
        buffer.flip();

        packet.setProtocolVersion(readFixInt(buffer, 1));
        packet.setServerVersion(readNullString(buffer));
        packet.setThreadId(readFixInt(buffer, 4));
        packet.setAuthPluginDataPart1(readFixString(buffer, 8));
        buffer.get();//filler
        packet.setCapabilityFlags(readFixInt(buffer, 2));
        packet.setCharacterSet(readFixInt(buffer, 1));
        packet.setStatusFlags(readFixInt(buffer, 2));
        // capability_flags_2 supplies the upper 16 bits
        int capabilities = packet.getCapabilityFlags() | (readFixInt(buffer, 2) << 16);
        packet.setCapabilityFlags(capabilities);
        if ((capabilities & NativeServerSession.CLIENT_PLUGIN_AUTH) != 0) {
            int i = readFixInt(buffer, 1);
            packet.setAuthPluginDataLen(i);
        } else {
            buffer.get();
            packet.setAuthPluginDataLen(0);
        }
        buffer.position(buffer.position() + 10);//reserved
        int apdp2len = Math.max(13, packet.getAuthPluginDataLen() - 8);
        packet.setAuthPluginDataPart2(readFixString(buffer, apdp2len));

        if ((capabilities & NativeServerSession.CLIENT_PLUGIN_AUTH) != 0) {
            packet.setAuthPluginName(readNullString(buffer));
        }
        return packet;
    }

    private String readFixString(ByteBuffer buffer, int len) {
        byte[] data = new byte[len];
        buffer.get(data);
        // ISO-8859-1 maps every byte to one char, so the scramble bytes can be recovered
        // unchanged later with getBytes(StandardCharsets.ISO_8859_1)
        return new String(data, StandardCharsets.ISO_8859_1);
    }

    private String readNullString(ByteBuffer buffer) {
        int position = buffer.position();
        int end = position;
        while (buffer.get() != 0) {
            end++;
        }
        buffer.position(position);
        byte[] data = new byte[end - position];
        buffer.get(data);
        buffer.get();//skip 00
        return new String(data, StandardCharsets.UTF_8);
    }

    private int readFixInt(ByteBuffer buffer, int len) {
        int data = 0;
        for (int i = 0; i < len; i++) {
            // mask with 0xff so bytes >= 0x80 are not sign-extended
            data |= (buffer.get() & 0xff) << (i * 8);
        }
        return data;
    }

    // reads until the buffer is full; a blocking channel returns -1 only at end of stream
    private void readFully(SocketChannel socket, ByteBuffer buffer) throws IOException {
        while (buffer.hasRemaining()) {
            if (socket.read(buffer) < 0) {
                throw new EOFException("connection closed by server");
            }
        }
    }
}

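Next, take the scramble (seed) from the handshake, scramble the password with it, add the other required fields, and encode and send the response: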

import com.mysql.cj.protocol.Security;
import com.mysql.cj.protocol.a.NativeConstants;
import com.mysql.cj.protocol.a.NativeServerSession;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

public class MySqlClient {

    private String username = "root";
    private String password = "root";
    private String database = "test";

    public static void main(String[] args) throws Exception {
        MySqlClient client = new MySqlClient();
        client.run();
    }

    public void run() {
        try (SocketChannel socket = SocketChannel.open(new InetSocketAddress("127.0.0.1", 3306))) {
            InitialHandshakePayload packet = init(socket);

            HandshakeResponse response = new HandshakeResponse();
            response.setCapabilityFlags(packet.getCapabilityFlags());
            response.setMaxPacketSize(NativeConstants.MAX_PACKET_SIZE);
            response.setCharacterSet(packet.getCharacterSet());
            response.setUsername(username);
            // scramble the password with the server's seed
            response.setAuthResponse(auth(packet, password));
            response.setDatabase(database);
            response.setClientPluginName(packet.getAuthPluginName());

            Map<String, String> attrs = new HashMap<>();
//            attrs.put("_runtime_version", "1.8.0_181");
//            attrs.put("_client_version", "8.0.19");
//            attrs.put("_client_license", "GPL");
//            attrs.put("_runtime_vendor", "Oracle Corporation");
//            attrs.put("_client_name", "MySQL Connector/J");
            response.setAttributes(attrs);

            // send the HandshakeResponse41 to the server
            response(socket, response);

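            // server response: read until the 4-byte header and the whole payload are buffered
            ByteBuffer buffer = ByteBuffer.allocate(2048);
            int n = 0;
            while (n < 4) {
                int count = socket.read(buffer);
                if (count < 0) {
                    throw new IOException("connection closed by server");
                }
                n += count;
            }
            // payload length from the header; absolute get(int) leaves the buffer position alone
            int bodyLength = (buffer.get(0) & 0xff) | ((buffer.get(1) & 0xff) << 8) | ((buffer.get(2) & 0xff) << 16);
            while (n < bodyLength + 4) {
                int count = socket.read(buffer);
                if (count < 0) {
                    throw new IOException("connection closed by server");
                }
                n += count;
            }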

            byte type = buffer.get(4);
            if (type == 0) {
                //ok
                System.out.println("ok");
            } else if (type == (byte) 0xff) {
                //err
                int code = (buffer.get(5) & 0xff) | ((buffer.get(6) & 0xff) << 8);
                System.out.println("error code:" + code);
                System.out.println("marker:" + (buffer.get(7) & 0xff));
                byte[] ssa = new byte[5];
                ssa[0] = buffer.get(8);
                ssa[1] = buffer.get(9);
                ssa[2] = buffer.get(10);
                ssa[3] = buffer.get(11);
                ssa[4] = buffer.get(12);
                System.out.println("code:" + new String(ssa));

                int s = 13;
                while (true) {
                    if (buffer.get(s) == 0) {
                        break;
                    }
                    s++;
                }
                byte[] msga = new byte[s - 13];
                for (int i = 0; i < msga.length; i++) {
                    msga[i] = buffer.get(13 + i);
                }
                System.out.println("error:" + new String(msga));
            }

        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public void response(SocketChannel socket, HandshakeResponse response) throws IOException {
        ByteBuffer buffer = ByteBuffer.allocate(1024);

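        // fixed client flags 20881935 = 0x013EA20F, masked with what the server advertised
        // so the optional sections below follow the flags actually sent
        int capabilityFlags = 20881935 & response.getCapabilityFlags();
        writeFixInt(buffer, capabilityFlags, 4);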
        int maxPacketSize = response.getMaxPacketSize();
        writeFixInt(buffer, maxPacketSize, 4);
        int characterSet = response.getCharacterSet();
        writeFixInt(buffer, characterSet, 1);
        //filler [00]*23
        writeFixInt(buffer, 0, 23);
        String username = response.getUsername();
        writeNullString(buffer, username.getBytes(StandardCharsets.UTF_8));

        if ((capabilityFlags & NativeServerSession.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) != 0) {
            byte[] authResponse = response.getAuthResponse();
            writeLengthString(buffer, authResponse);
        } else {
            byte[] authResponse = response.getAuthResponse();
            writeFixInt(buffer, authResponse.length, 1); // int<1> auth_response_length
            writeFixString(buffer, authResponse);
        }

        if ((capabilityFlags & NativeServerSession.CLIENT_CONNECT_WITH_DB) != 0) {
            String database = response.getDatabase();
            writeNullString(buffer, database.getBytes(StandardCharsets.UTF_8));
        }

        if ((capabilityFlags & NativeServerSession.CLIENT_PLUGIN_AUTH) != 0) {
            String clientPluginName = response.getClientPluginName();
            writeNullString(buffer, clientPluginName.getBytes(StandardCharsets.UTF_8));
        }

        if ((capabilityFlags & NativeServerSession.CLIENT_CONNECT_ATTRS) != 0) {
            Map<String, String> attributes = response.getAttributes();
            ByteBuffer attrBuffer = ByteBuffer.allocate(1024);
            for (Map.Entry<String, String> entry : attributes.entrySet()) {
                writeLengthString(attrBuffer, entry.getKey().getBytes(StandardCharsets.UTF_8));
                writeLengthString(attrBuffer, entry.getValue().getBytes(StandardCharsets.UTF_8));
            }
            attrBuffer.flip();

            writeLengthInt(buffer, attrBuffer.limit());
            grow(buffer, attrBuffer.limit());
            buffer.put(attrBuffer);
        }

        buffer.flip();

        int bodySize = buffer.limit();
        ByteBuffer packet = ByteBuffer.allocate(bodySize + 4);
        writeFixInt(packet, bodySize, 3);
        packet.put((byte) 1);
        packet.put(buffer);
        packet.flip();
        socket.write(packet);

    }

    private void writeFixString(ByteBuffer buffer, byte[] data) {
        grow(buffer, data.length);
        for (byte b : data) {
            buffer.put(b);
        }
    }

    private void writeLengthString(ByteBuffer buffer, byte[] data) {
        grow(buffer, data.length + 9);
        writeLengthInt(buffer, data.length);
        writeFixString(buffer, data);
    }

    /**
     * If the value is < 251, it is stored as a 1-byte integer.
     * If the value is ≥ 251 and < (2^16), it is stored as fc + 2-byte integer.
     * If the value is ≥ (2^16) and < (2^24), it is stored as fd + 3-byte integer.
     * If the value is ≥ (2^24) and < (2^64) it is stored as fe + 8-byte integer.
     */
    private void writeLengthInt(ByteBuffer buffer, int v) {
        if (v < 251) {
            grow(buffer, 1);
            writeFixInt(buffer, v, 1);
        } else if (v < 65536L) {
            grow(buffer, 3);
            writeFixInt(buffer, 0xfc, 1);
            writeFixInt(buffer, v, 2);
        } else if (v < 16777216L) {
            grow(buffer, 4);
            writeFixInt(buffer, 0xfd, 1);
            writeFixInt(buffer, v, 3);

        } else {
            grow(buffer, 9);
            writeFixInt(buffer, 0xfe, 1);
            writeFixInt(buffer, v, 8);
        }
    }

    private void writeNullString(ByteBuffer buffer, byte[] data) {
        grow(buffer, data.length + 1);
        for (byte b : data) {
            buffer.put(b);
        }
        buffer.put((byte) 0);
    }

    public void writeFixInt(ByteBuffer buffer, int v, int len) {
        grow(buffer, len);
        for (int i = 0; i < len; i++) {
            buffer.put((byte) (v >>> (i * 8)));
        }
    }


    private int readFixInt(ByteBuffer buffer, int len) {
        int data = 0;
        for (int i = 0; i < len; i++) {
            // mask with 0xff so bytes >= 0x80 are not sign-extended
            data |= (buffer.get() & 0xff) << (i * 8);
        }
        return data;
    }

    public InitialHandshakePayload init(SocketChannel socket) throws IOException {
        HandshakeV10Parser parser = new HandshakeV10Parser();
        InitialHandshakePayload packet = parser.parse(socket);
        System.out.println(packet);
        return packet;
    }


    public byte[] auth(InitialHandshakePayload packet, String password) {
        final String authPluginName = packet.getAuthPluginName();

        if ("mysql_native_password".equals(authPluginName)) {
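            String data = packet.getAuthPluginDataPart1() + packet.getAuthPluginDataPart2();
            // ISO-8859-1 turns each char back into exactly one byte
            byte[] bytes = data.getBytes(StandardCharsets.ISO_8859_1);
            byte[] seed = new byte[20];
            // mysql_native_password uses a 20-byte seed: drop the trailing 0x00 of part 2
            System.arraycopy(bytes, 0, seed, 0, 20);
            return Security.scramble411(password.getBytes(StandardCharsets.UTF_8), seed);
        } else {
            // other plugins (e.g. caching_sha2_password) are omitted here
            return new byte[0];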
        }
    }

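    /**
     * Returns a buffer with at least len bytes remaining: the same buffer if it already has room,
     * otherwise a larger copy. A ByteBuffer cannot be resized in place, so callers that ignore the
     * return value (as the methods above do) rely on the initial 1024-byte allocation;
     * ByteBuffer.put throws BufferOverflowException if that turns out to be too small.
     */
    public ByteBuffer grow(ByteBuffer buffer, int len) {
        if (buffer.remaining() >= len) {
            return buffer;
        }
        ByteBuffer nb = ByteBuffer.allocate(Math.max(buffer.capacity() << 1, buffer.position() + len));
        // copy through a duplicate so the caller's buffer position and limit stay untouched
        ByteBuffer src = buffer.duplicate();
        src.flip();
        nb.put(src);
        return nb;
    }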
}
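The client delegates the scrambling to Connector/J's `Security.scramble411`. For reference, mysql_native_password computes SHA1(password) XOR SHA1(seed + SHA1(SHA1(password))), where seed is the 20-byte scramble. The standalone sketch below uses only `java.security.MessageDigest`; it illustrates the algorithm and is not Connector/J's code. The server stores only SHA1(SHA1(password)) and can still check the token, as `verify` shows. (Real clients send an empty auth_response when the password is empty; that case is not handled here.)

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/** Illustrative mysql_native_password scramble; a sketch, not a driver implementation. */
public class NativePassword {

    static byte[] sha1(byte[]... parts) {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-1");
            for (byte[] p : parts) {
                md.update(p);
            }
            return md.digest();
        } catch (NoSuchAlgorithmException e) {
            // every Java SE runtime is required to provide SHA-1
            throw new IllegalStateException(e);
        }
    }

    static byte[] xor(byte[] a, byte[] b) {
        byte[] r = new byte[a.length];
        for (int i = 0; i < a.length; i++) {
            r[i] = (byte) (a[i] ^ b[i]);
        }
        return r;
    }

    /** Client side: SHA1(password) XOR SHA1(seed + SHA1(SHA1(password))). */
    static byte[] scramble(byte[] password, byte[] seed) {
        byte[] stage1 = sha1(password);
        byte[] stage2 = sha1(stage1);
        return xor(stage1, sha1(seed, stage2));
    }

    /** Server side: recover SHA1(password) from the token and compare its hash with the stored value. */
    static boolean verify(byte[] token, byte[] seed, byte[] storedStage2) {
        byte[] stage1 = xor(token, sha1(seed, storedStage2));
        return MessageDigest.isEqual(sha1(stage1), storedStage2);
    }

    public static void main(String[] args) {
        byte[] password = "root".getBytes(StandardCharsets.UTF_8);
        byte[] seed = "0123456789abcdefghij".getBytes(StandardCharsets.US_ASCII); // 20-byte demo seed
        byte[] token = scramble(password, seed);
        System.out.println(token.length);                              // 20
        System.out.println(verify(token, seed, sha1(sha1(password)))); // true
    }
}
```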

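If nothing goes wrong, the server replies with an OK packet whose bytes begin [7] [0] [0] [2] [0] [0] [0] [2]: the header (payload length 7, sequence id 2), then the OK marker 0x00, affected_rows = 0, last_insert_id = 0, and 0x02, the low byte of status_flags (SERVER_STATUS_AUTOCOMMIT).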
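A sketch of decoding the fixed part of an OK payload (the bytes after the 4-byte header), assuming CLIENT_PROTOCOL_41 was negotiated: a 0x00 marker, affected_rows and last_insert_id as length-encoded integers, then status_flags and warnings as int<2>. Trailing info and session-state fields are ignored. The class is hypothetical; the demo payload assumes the three bytes not shown above are zero.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Hypothetical decoder for the fixed part of an OK payload (assumes CLIENT_PROTOCOL_41). */
public class OkPacket {
    static final int SERVER_STATUS_AUTOCOMMIT = 0x0002;

    final long affectedRows;
    final long lastInsertId;
    final int statusFlags;
    final int warnings;

    private OkPacket(long affectedRows, long lastInsertId, int statusFlags, int warnings) {
        this.affectedRows = affectedRows;
        this.lastInsertId = lastInsertId;
        this.statusFlags = statusFlags;
        this.warnings = warnings;
    }

    /** payload excludes the 4-byte packet header. */
    static OkPacket parse(byte[] payload) {
        ByteBuffer in = ByteBuffer.wrap(payload).order(ByteOrder.LITTLE_ENDIAN);
        int header = in.get() & 0xff;
        if (header != 0x00) {
            throw new IllegalArgumentException("not an OK packet: 0x" + Integer.toHexString(header));
        }
        long affectedRows = readLenenc(in);
        long lastInsertId = readLenenc(in);
        int statusFlags = in.getShort() & 0xffff;
        int warnings = in.getShort() & 0xffff;
        return new OkPacket(affectedRows, lastInsertId, statusFlags, warnings);
    }

    /** int<lenenc>: < 0xfb is the value itself; 0xfc/0xfd/0xfe prefix 2/3/8 little-endian bytes. */
    private static long readLenenc(ByteBuffer in) {
        int first = in.get() & 0xff;
        if (first < 0xfb) return first;
        if (first == 0xfc) return in.getShort() & 0xffffL;
        if (first == 0xfd) {
            long b0 = in.get() & 0xffL, b1 = in.get() & 0xffL, b2 = in.get() & 0xffL;
            return b0 | (b1 << 8) | (b2 << 16);
        }
        if (first == 0xfe) return in.getLong();
        throw new IllegalArgumentException("bad int<lenenc> prefix: 0x" + Integer.toHexString(first));
    }

    public static void main(String[] args) {
        OkPacket ok = parse(new byte[]{0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00});
        System.out.println(ok.affectedRows + " " + ok.lastInsertId + " " + ok.warnings); // 0 0 0
        System.out.println((ok.statusFlags & SERVER_STATUS_AUTOCOMMIT) != 0);            // true
    }
}
```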

Summary

This completes the connection phase. Next comes the command phase, in which the client can send commands to the server.

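Note also that MySQL 8 changed the default authentication plugin to caching_sha2_password; to keep testing simple, the server here was switched back to mysql_native_password. The connection phase also involves other steps, such as authentication method switching and SSL connections, which are skipped for now.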

Finally, the cleaned-up code: https://github.com/dingfugui/mysql-protocol/tree/master/src/main/java/prrety

Configuration

MySQL server: 8.0.19

JDBC: 8.0.19

MySQL configuration file:

[mysqld]
port=3306
basedir=...
datadir=...
character_set_server=utf8
default-storage-engine=INNODB
sql_mode=NO_ENGINE_SUBSTITUTION,STRICT_TRANS_TABLES
default_authentication_plugin=mysql_native_password
[mysql]
default-character-set=utf8

References:

https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse

https://dev.mysql.com/doc/dev/mysql-server/8.0.19/page_protocol_connection_phase_packets.html

mysql:mysql-connector-java:8.0.19
