Add support for "auth switch" and "query attrs"

Also fix the issue where Resultset could not correctly distinguish between EOF_Packet and OK_Packet.
This commit is contained in:
Fupeng Zhao 2024-06-30 21:52:31 +08:00 committed by Arne Welzel
parent 9cb618c718
commit e8bdf149f2
12 changed files with 272 additions and 46 deletions

View file

@ -96,11 +96,6 @@ type LengthEncodedStringArg(first_byte: uint8) = record {
};
%}
%code{
const char* PLUGIN_CACHING_SHA2_PASSWORD = "caching_sha2_password";
%}
extern type PLUGIN_CACHING_SHA2_PASSWORD;
extern type to_int;
# Enums
@ -141,11 +136,12 @@ enum command_consts {
};
enum state {
CONNECTION_PHASE = 0,
COMMAND_PHASE = 1,
SHA2_AUTH_PHASE = 2,
PUB_KEY_PHASE = 3,
SHA2_AUTH_RESP_PHASE = 4,
CONNECTION_PHASE = 0,
COMMAND_PHASE = 1,
SHA2_AUTH_PHASE = 2,
PUB_KEY_PHASE = 3,
SHA2_AUTH_RESP_PHASE = 4,
AUTH_SWITCH_RESP_PHASE = 5,
};
enum Expected {
@ -173,9 +169,10 @@ enum Client_Capabilities {
# Expects an OK (instead of EOF) after the resultset rows of a Text Resultset.
CLIENT_DEPRECATE_EOF = 0x01000000,
CLIENT_ZSTD_COMPRESSION_ALGORITHM = 0x04000000,
CLIENT_QUERY_ATTRIBUTES = 0x08000000,
};
enum SHA2_Atuh_State {
enum SHA2_Auth_State {
REQUEST_PUBLIC_KEY = 2,
FAST_AUTH_SUCCESS = 3,
PERFORM_FULL_AUTHENTICATION = 4,
@ -217,6 +214,7 @@ type Client_Message(state: int) = case state of {
SHA2_AUTH_PHASE -> sha2_auth_phase : SHA2_Auth_Packet;
PUB_KEY_PHASE -> pub_key_phase : Public_Key_Packet;
SHA2_AUTH_RESP_PHASE -> sha2_auth_resp_phase : SHA2_Auth_Response_Packet;
AUTH_SWITCH_RESP_PHASE -> auth_switch_resp_phase : Auth_Switch_Response_Packet;
};
# Handshake Request
@ -244,10 +242,14 @@ type Handshake_v10 = record {
auth_plugin_data_len : uint8;
reserved : padding[10];
auth_plugin_data_part_2: bytestring &length=13;
have_plugin : case ( ( capability_flags_2 << 4 ) & CLIENT_PLUGIN_AUTH ) of {
CLIENT_PLUGIN_AUTH -> auth_plugin_name: NUL_String;
have_plugin : case ( ( capability_flags_2 << 16 ) & CLIENT_PLUGIN_AUTH ) of {
CLIENT_PLUGIN_AUTH -> auth_plugin: NUL_String;
0x0 -> none : empty;
};
} &let {
update_auth_plugin: bool = $context.connection.set_auth_plugin(auth_plugin)
&if( ( capability_flags_2 << 16 ) & CLIENT_PLUGIN_AUTH );
server_query_attrs: bool = $context.connection.set_server_query_attrs(( capability_flags_2 << 16 ) & CLIENT_QUERY_ATTRIBUTES);
};
type Handshake_v9 = record {
@ -287,7 +289,7 @@ type Handshake_Plain_v10(cap_flags: uint32) = record {
0x0 -> none_1 : empty;
};
have_plugin : case ( cap_flags & CLIENT_PLUGIN_AUTH ) of {
CLIENT_PLUGIN_AUTH -> auth_plugin_name: EmptyOrNUL_String;
CLIENT_PLUGIN_AUTH -> auth_plugin: EmptyOrNUL_String;
0x0 -> none_2 : empty;
};
have_attrs : case ( cap_flags & CLIENT_CONNECT_ATTRS ) of {
@ -299,8 +301,10 @@ type Handshake_Plain_v10(cap_flags: uint32) = record {
0x0 -> none_4 : empty;
};
} &let {
update_state: bool = $context.connection.update_state(SHA2_AUTH_PHASE)
&if(( cap_flags & CLIENT_PLUGIN_AUTH ) && auth_plugin_name==PLUGIN_CACHING_SHA2_PASSWORD);
update_auth_plugin: bool = $context.connection.set_auth_plugin(auth_plugin)
&if( cap_flags & CLIENT_PLUGIN_AUTH );
update_state: bool = $context.connection.update_state_from_auth()
&if( cap_flags & CLIENT_PLUGIN_AUTH );
};
type Handshake_Response_Packet_v10 = record {
@ -314,6 +318,7 @@ type Handshake_Response_Packet_v10 = record {
};
} &let {
deprecate_eof: bool = $context.connection.set_deprecate_eof(cap_flags & CLIENT_DEPRECATE_EOF);
client_query_attrs: bool = $context.connection.set_client_query_attrs(cap_flags & CLIENT_QUERY_ATTRIBUTES);
};
type Handshake_Response_Packet_v9 = record {
@ -352,10 +357,43 @@ type SHA2_Auth_Response_Packet = record {
update_state: bool = $context.connection.update_state(COMMAND_PHASE);
};
# Auth Switch
type Auth_Switch_Response_Packet = record {
data : bytestring &restofdata;
} &let {
update_state: bool = $context.connection.update_state_from_auth();
};
# Command Request
type AttributeTypeAndName = record {
type: uint16;
name: LengthEncodedString;
};
type Attributes(count: uint8) = record {
unused : uint8;
send_types_to_server: uint8; # Always 1.
names : AttributeTypeAndName[count];
values : LengthEncodedString[count];
};
type Query_Attributes = record {
count : uint8;
set_coun : uint8;
have_attr : case ( count > 0 ) of {
true -> attrs: Attributes(count);
false -> none: empty;
};
};
type Command_Request_Packet = record {
command: uint8;
attrs : case ( command == COM_QUERY && $context.connection.get_client_query_attrs() && $context.connection.get_server_query_attrs() ) of {
true -> query_attrs: Query_Attributes;
false -> none: empty;
};
arg : bytestring &restofdata;
} &let {
update_expectation: bool = $context.connection.set_next_expected_from_command(command);
@ -413,22 +451,22 @@ type ColumnDefinition = record {
};
# Only used to indicate the end of a result, no intermediate eofs here.
type EOFOrOK = case $context.connection.get_deprecate_eof() of {
# MySQL spec says "You must check whether the packet length is less than 9
# to make sure that it is a EOF_Packet packet" so the value of 13 here
# comes from that 9, plus a 4-byte header.
type EOFOrOK(pkt_len: uint32) = case ( $context.connection.get_deprecate_eof() || pkt_len > 13 ) of {
false -> eof: EOF_Packet(EOF_END);
true -> ok: OK_Packet;
};
type ColumnDefinitionOrEOF(pkt_len: uint32) = record {
marker : uint8;
def_or_eof: case is_eof of {
true -> eof: EOFOrOK;
def_or_eof: case is_eof_or_ok of {
true -> eof: EOFOrOK(pkt_len);
false -> def: ColumnDefinition41(marker);
} &requires(is_eof);
} &requires(is_eof_or_ok);
} &let {
# MySQL spec says "You must check whether the packet length is less than 9
# to make sure that it is a EOF_Packet packet" so the value of 13 here
# comes from that 9, plus a 4-byte header.
is_eof: bool = (marker == 0xfe && pkt_len < 13);
is_eof_or_ok: bool = (marker == 0xfe);
};
@ -442,17 +480,14 @@ type EOFIfLegacyThenResultset(pkt_len: uint32) = case $context.connection.get_de
type Resultset(pkt_len: uint32) = record {
marker : uint8;
row_or_eof: case is_eof of {
true -> eof: EOFOrOK;
row_or_eof: case is_eof_or_ok of {
true -> eof: EOFOrOK(pkt_len);
false -> row: ResultsetRow(marker);
} &requires(is_eof);
} &requires(is_eof_or_ok);
} &let {
# MySQL spec says "You must check whether the packet length is less than 9
# to make sure that it is a EOF_Packet packet" so the value of 13 here
# comes from that 9, plus a 4-byte header.
is_eof : bool = (marker == 0xfe && pkt_len < 13);
is_eof_or_ok : bool = (marker == 0xfe);
update_result_seen: bool = $context.connection.inc_results_seen();
update_expectation: bool = $context.connection.set_next_expected(is_eof ? NO_EXPECTATION : EXPECT_RESULTSET);
update_expectation: bool = $context.connection.set_next_expected(is_eof_or_ok ? NO_EXPECTATION : EXPECT_RESULTSET);
};
type ResultsetRow(first_byte: uint8) = record {
@ -480,6 +515,9 @@ type AuthSwitchRequest = record {
status: uint8;
name : NUL_String;
data : bytestring &restofdata;
} &let {
update_auth_plugin: bool = $context.connection.set_auth_plugin(name);
update_state: bool = $context.connection.update_state(AUTH_SWITCH_RESP_PHASE);
};
type ColumnDefinition320 = record {
@ -531,6 +569,9 @@ refine connection MySQL_Conn += {
uint32 remaining_cols_;
uint32 results_seen_;
bool deprecate_eof_;
bool server_query_attrs_;
bool client_query_attrs_;
bytestring auth_plugin_;
%}
%init{
@ -542,6 +583,13 @@ refine connection MySQL_Conn += {
remaining_cols_ = 0;
results_seen_ = 0;
deprecate_eof_ = false;
server_query_attrs_ = false;
client_query_attrs_ = false;
auth_plugin_ = bytestring();
%}
%cleanup{
auth_plugin_.free();
%}
function get_version(): uint8
@ -577,6 +625,18 @@ refine connection MySQL_Conn += {
return true;
%}
function update_state_from_auth(): bool
%{
if ( auth_plugin_ == "caching_sha2_password" )
{
state_ = SHA2_AUTH_PHASE;
if ( expected_ == EXPECT_AUTH_SWITCH )
expected_ = EXPECT_STATUS;
}
return true;
%}
function get_deprecate_eof(): bool
%{
return deprecate_eof_;
@ -588,6 +648,46 @@ refine connection MySQL_Conn += {
return true;
%}
function get_server_query_attrs(): bool
%{
return server_query_attrs_;
%}
function set_server_query_attrs(q: bool): bool
%{
server_query_attrs_ = q;
return true;
%}
function get_client_query_attrs(): bool
%{
return client_query_attrs_;
%}
function set_client_query_attrs(q: bool): bool
%{
client_query_attrs_ = q;
return true;
%}
function get_auth_plugin(): bytestring
%{
return auth_plugin_;
%}
function set_auth_plugin(a: bytestring): bool
%{
if ( auth_plugin_.length() > 0 &&
strncmp(c_str(auth_plugin_), c_str(a), auth_plugin_.length()) != 0 )
{
expected_ = EXPECT_AUTH_SWITCH;
}
auth_plugin_.free();
auth_plugin_.init(a.data(), a.length());
return true;
%}
function get_expectation(): Expected
%{
return expected_;