diff --git a/src/server/poolserver/Database/ServerDatabaseEnv.h b/src/server/poolserver/Database/ServerDatabaseEnv.h index 8e6ea0e..a2abb9a 100644 --- a/src/server/poolserver/Database/ServerDatabaseEnv.h +++ b/src/server/poolserver/Database/ServerDatabaseEnv.h @@ -3,8 +3,19 @@ #include "DatabaseWorkerPool.h" +enum ServerSTMT +{ + STMT_QUERY_TEST_TABLE, + STMT_INSERT_SHIT +}; + class ServerDatabaseWorkerPoolMySQL : public MySQL::DatabaseWorkerPool { + void LoadSTMT() + { + PrepareStatement(STMT_QUERY_TEST_TABLE, "SELECT * FROM `test_table`", MySQL::STMT_BOTH); + PrepareStatement(STMT_INSERT_SHIT, "INSERT INTO `test_table` VALUES (?, ?, ?)", MySQL::STMT_BOTH); + } }; extern ServerDatabaseWorkerPoolMySQL sDatabase; diff --git a/src/server/poolserver/Server/Server.cpp b/src/server/poolserver/Server/Server.cpp index db6f9a0..ba8e4d2 100644 --- a/src/server/poolserver/Server/Server.cpp +++ b/src/server/poolserver/Server/Server.cpp @@ -20,10 +20,8 @@ Server::~Server() void AsyncQueryCallback(MySQL::QueryResult result) { sLog.Info(LOG_SERVER, "Metadata: F: %u R: %u", result->GetFieldCount(), result->GetRowCount()); - while (result->NextRow()) { - MySQL::Field* fields = result->Fetch(); - sLog.Info(LOG_SERVER, "Row: %i %s", fields[0].GetUInt32(), - fields[1].GetString().c_str()); + while (MySQL::Field* fields = result->FetchRow()) { + sLog.Info(LOG_SERVER, "Row: %i %s", fields[0].GetUInt32(), fields[1].GetString().c_str()); } } @@ -36,14 +34,25 @@ int Server::Run() //sDatabase.Execute("INSERT INTO `test_table` VALUES ('999', 'sync', '1.1')"); //sDatabase.ExecuteAsync("INSERT INTO `test_table` VALUES ('999', 'sync', '1.1')"); - sDatabase.QueryAsync("SELECT * FROM `test_table`", &AsyncQueryCallback); - MySQL::QueryResult result = sDatabase.Query("SELECT * FROM `test_table`"); - sLog.Info(LOG_SERVER, "Metadata: F: %u R: %u", result->GetFieldCount(), result->GetRowCount()); - while (result->NextRow()) { - MySQL::Field* fields = result->Fetch(); - sLog.Info(LOG_SERVER, "Row: %i %s", fields[0].GetUInt32(), - fields[1].GetString().c_str()); - } + /*MySQL::PreparedStatement* stmt = sDatabase.GetPreparedStatement(STMT_INSERT_SHIT); + stmt->SetUInt32(0, 10); + stmt->SetString(1, "hello"); + stmt->SetFloat(2, 5.987); + sDatabase.ExecuteAsync(stmt);*/ + + MySQL::PreparedStatement* stmt = sDatabase.GetPreparedStatement(STMT_QUERY_TEST_TABLE); + MySQL::QueryResult result = sDatabase.Query(stmt); + + + //sDatabase.QueryAsync("SELECT * FROM `test_table`", &AsyncQueryCallback); + //MySQL::QueryResult result = sDatabase.Query("SELECT * FROM `test_table`"); + if (result) { + sLog.Info(LOG_SERVER, "Metadata: F: %u R: %u", result->GetFieldCount(), result->GetRowCount()); + while (MySQL::Field* fields = result->FetchRow()) { + sLog.Info(LOG_SERVER, "Row: %i %s", fields[0].GetUInt32(), fields[1].GetString().c_str()); + } + } else + sLog.Info(LOG_SERVER, "Empty result"); // Start stratum server sLog.Info(LOG_SERVER, "Starting stratum"); diff --git a/src/server/shared/Common.h b/src/server/shared/Common.h new file mode 100644 index 0000000..93e6854 --- /dev/null +++ b/src/server/shared/Common.h @@ -0,0 +1,16 @@ +#ifndef COMMON_H_ +#define COMMON_H_ + +#include + +typedef uint8_t uint8; +typedef uint16_t uint16; +typedef uint32_t uint32; +typedef uint64_t uint64; + +typedef int8_t int8; +typedef int16_t int16; +typedef int32_t int32; +typedef int64_t int64; + +#endif diff --git a/src/server/shared/MySQL/DatabaseCallback.h b/src/server/shared/MySQL/DatabaseCallback.h index d1eea07..3a67cfc 100644 --- a/src/server/shared/MySQL/DatabaseCallback.h +++ b/src/server/shared/MySQL/DatabaseCallback.h @@ -10,4 +10,4 @@ namespace MySQL typedef boost::function DatabaseCallback; } -#endif \ No newline at end of file +#endif diff --git a/src/server/shared/MySQL/DatabaseConnection.cpp b/src/server/shared/MySQL/DatabaseConnection.cpp index 862af3c..aceb028 100644 --- a/src/server/shared/MySQL/DatabaseConnection.cpp +++ b/src/server/shared/MySQL/DatabaseConnection.cpp @@ -126,8 +126,7 @@ namespace MySQL if (!result) return false; - if (!rowCount) - { + if (!rowCount) { mysql_free_result(*result); return false; } @@ -136,17 +135,125 @@ namespace MySQL return true; } + + bool DatabaseConnection::_Query(PreparedStatement* stmt, MYSQL_RES** result, MYSQL_STMT** resultSTMT, uint64& rowCount, uint32& fieldCount) + { + if (!_mysql) + return false; + + ConnectionPreparedStatement* cstmt = GetPreparedStatement(stmt->_index); + + if (!cstmt) { + sLog.Error(LOG_DATABASE, "STMT id: %u not found!", stmt->_index); + return false; + } + + cstmt->BindParameters(stmt); + + MYSQL_STMT* mSTMT = cstmt->GetSTMT(); + MYSQL_BIND* mBIND = cstmt->GetBind(); + + if (mysql_stmt_bind_param(mSTMT, mBIND)) + { + uint32 lErrno = mysql_errno(_mysql); + sLog.Error(LOG_DATABASE, "STMT Execute Error[%u]: %s", lErrno, mysql_stmt_error(mSTMT)); + + if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection) + return Execute(stmt); // Try again + + cstmt->ClearParameters(); + return false; + } + + if (mysql_stmt_execute(mSTMT)) + { + uint32 lErrno = mysql_errno(_mysql); + sLog.Error(LOG_DATABASE, "STMT Execute Error[%u]: %s", lErrno, mysql_stmt_error(mSTMT)); + + if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection) + return _Query(stmt, result, resultSTMT, rowCount, fieldCount); // Try again + + cstmt->ClearParameters(); + return false; + } + + cstmt->ClearParameters(); + + *result = mysql_stmt_result_metadata(mSTMT); + rowCount = mysql_stmt_num_rows(mSTMT); + fieldCount = mysql_stmt_field_count(mSTMT); + *resultSTMT = mSTMT; + + return true; + } bool DatabaseConnection::Execute(PreparedStatement* stmt) { + if (!_mysql) + return false; + + ConnectionPreparedStatement* cstmt = GetPreparedStatement(stmt->_index); + + if (!cstmt) { + sLog.Error(LOG_DATABASE, "STMT id: %u not found!", stmt->_index); + return false; + } + + cstmt->BindParameters(stmt); + + MYSQL_STMT* mSTMT = cstmt->GetSTMT(); + MYSQL_BIND* mBIND = cstmt->GetBind(); + + if (mysql_stmt_bind_param(mSTMT, mBIND)) + { + uint32 lErrno = mysql_errno(_mysql); + sLog.Error(LOG_DATABASE, "STMT Execute Error[%u]: %s", lErrno, mysql_stmt_error(mSTMT)); + + if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection) + return Execute(stmt); // Try again + + cstmt->ClearParameters(); + return false; + } + + if (mysql_stmt_execute(mSTMT)) + { + uint32 lErrno = mysql_errno(_mysql); + sLog.Error(LOG_DATABASE, "STMT Execute Error[%u]: %s", lErrno, mysql_stmt_error(mSTMT)); + + if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection) + return Execute(stmt); // Try again + + cstmt->ClearParameters(); + return false; + } + + cstmt->ClearParameters(); + + return true; } ResultSet* DatabaseConnection::Query(PreparedStatement* stmt) { + MYSQL_RES* result = NULL; + MYSQL_STMT* resultSTMT = NULL; + uint64 rowCount = 0; + uint32 fieldCount = 0; + + if (!_Query(stmt, &result, &resultSTMT, rowCount, fieldCount)) + return NULL; + + if (mysql_more_results(_mysql)) + mysql_next_result(_mysql); + + return new ResultSet(result, resultSTMT, rowCount, fieldCount); } - void DatabaseConnection::PrepareStatement(uint32 index, const char* sql) + bool DatabaseConnection::PrepareStatement(uint32 index, const char* sql) { + if (!_mysql) + return false; + // For reconnection case //if (m_reconnecting) // delete m_stmts[index]; @@ -156,16 +263,30 @@ namespace MySQL if (!stmt) { sLog.Error(LOG_DATABASE, "In mysql_stmt_init() id: %u, sql: \"%s\"", index, sql); sLog.Error(LOG_DATABASE, "%s", mysql_error(_mysql)); - } else { - if (mysql_stmt_prepare(stmt, sql, strlen(sql))) { - sLog.Error(LOG_DATABASE, "In mysql_stmt_init() id: %u, sql: \"%s\"", index, sql); - sLog.Error(LOG_DATABASE, "%s", mysql_stmt_error(stmt)); - mysql_stmt_close(stmt); - } else { - PreparedStatement* mStmt = new PreparedStatement(stmt); - _stmts[index] = mStmt; - } + return false; + } + + if (mysql_stmt_prepare(stmt, sql, strlen(sql))) { + sLog.Error(LOG_DATABASE, "In mysql_stmt_init() id: %u, sql: \"%s\"", index, sql); + sLog.Error(LOG_DATABASE, "%s", mysql_stmt_error(stmt)); + mysql_stmt_close(stmt); + return false; } + + // Set flags to update max_length property + my_bool mysql_c_api_sucks = true; + mysql_stmt_attr_set(stmt, STMT_ATTR_UPDATE_MAX_LENGTH, (void*)&mysql_c_api_sucks); + + // Resize stmt vector + if (index >= _stmts.size()) + _stmts.resize(index+1); + + ConnectionPreparedStatement* mStmt = new ConnectionPreparedStatement(stmt); + _stmts[index] = mStmt; + + sLog.Debug(LOG_DATABASE, "Prepared STMT id: %u, sql: \"%s\"", index, sql); + + return true; } bool DatabaseConnection::_HandleMySQLErrno(uint32_t lErrno) diff --git a/src/server/shared/MySQL/DatabaseConnection.h b/src/server/shared/MySQL/DatabaseConnection.h index cf6f770..04844b5 100644 --- a/src/server/shared/MySQL/DatabaseConnection.h +++ b/src/server/shared/MySQL/DatabaseConnection.h @@ -5,6 +5,7 @@ #include "DatabaseWorker.h" #include "QueryResult.h" #include "PreparedStatement.h" +#include "Log.h" #include #include @@ -63,10 +64,18 @@ namespace MySQL ConnectionType Type; - void PrepareStatement(uint32_t index, const char* sql); + bool PrepareStatement(uint32_t index, const char* sql); + + ConnectionPreparedStatement* GetPreparedStatement(uint32 index) + { + if (index >= _stmts.size()) + return NULL; + return _stmts[index]; + } private: - bool _Query(const char *sql, MYSQL_RES** result, MYSQL_FIELD** fields, uint64_t& pRowCount, uint32_t& pFieldCount); + bool _Query(const char *sql, MYSQL_RES** result, MYSQL_FIELD** fields, uint64& rowCount, uint32& fieldCount); + bool _Query(PreparedStatement* stmt, MYSQL_RES** result, MYSQL_STMT** resultSTMT, uint64& rowCount, uint32& fieldCount); bool _HandleMySQLErrno(uint32_t lErrno); @@ -74,7 +83,7 @@ namespace MySQL MYSQL* _mysql; DatabaseWorkQueue* _asyncQueue; DatabaseWorker* _worker; - std::vector _stmts; + std::vector _stmts; ConnectionInfo _connectionInfo; }; } diff --git a/src/server/shared/MySQL/DatabaseWorkerPool.cpp b/src/server/shared/MySQL/DatabaseWorkerPool.cpp index 7c53358..5fe1cb0 100644 --- a/src/server/shared/MySQL/DatabaseWorkerPool.cpp +++ b/src/server/shared/MySQL/DatabaseWorkerPool.cpp @@ -28,6 +28,8 @@ namespace MySQL else sLog.Error(LOG_DATABASE, "Failed opening MySQL Database Pool to '%s'.", _connectionInfo.DB.c_str()); + LoadSTMT(); + return res; } @@ -53,7 +55,7 @@ namespace MySQL sLog.Info(LOG_DATABASE, "Closed all connections to MySQL Database Pool '%s'.", _connectionInfo.DB.c_str()); } - bool DatabaseWorkerPool::PrepareStatement(index, const char* sql, PreparedStatementFlags flags) + bool DatabaseWorkerPool::PrepareStatement(uint32 index, const char* sql, PreparedStatementFlags flags) { if (flags & STMT_SYNC) { for (uint8_t i = 0; i < _connections[MYSQL_CONN_SYNC].size(); ++i) { @@ -77,4 +79,4 @@ namespace MySQL return true; } -} \ No newline at end of file +} diff --git a/src/server/shared/MySQL/DatabaseWorkerPool.h b/src/server/shared/MySQL/DatabaseWorkerPool.h index 902efe2..584bc78 100644 --- a/src/server/shared/MySQL/DatabaseWorkerPool.h +++ b/src/server/shared/MySQL/DatabaseWorkerPool.h @@ -1,6 +1,7 @@ #ifndef DATABASE_WORKER_POOL_MYSQL_H_ #define DATABASE_WORKER_POOL_MYSQL_H_ +#include "Common.h" #include "DatabaseConnection.h" #include "PreparedStatement.h" #include "QueryResult.h" @@ -19,7 +20,7 @@ namespace MySQL delete _asyncQueue; } - bool Open(ConnectionInfo connInfo, uint8_t syncThreads, uint8_t asyncThreads); + bool Open(ConnectionInfo connInfo, uint8 syncThreads, uint8 asyncThreads); void Close(); @@ -81,19 +82,20 @@ namespace MySQL return true; } - bool PrepareStatement(index, const char* sql, PreparedStatementFlags flags); + virtual void LoadSTMT() = 0; + bool PrepareStatement(uint32 index, const char* sql, PreparedStatementFlags flags); // Prepared Statements - PreparedStatement* GetPreparedStatement(uint32_t stmtid) + PreparedStatement* GetPreparedStatement(uint32 stmtid) { - return NULL;//new PreparedStatement(stmtid); + return new PreparedStatement(stmtid); } private: DatabaseConnection* GetSyncConnection() { - uint32_t i; - uint8_t conn_size = _connections[MYSQL_CONN_SYNC].size(); + uint32 i; + uint8 conn_size = _connections[MYSQL_CONN_SYNC].size(); DatabaseConnection* conn = NULL; // Block until we find a free connection diff --git a/src/server/shared/MySQL/Field.h b/src/server/shared/MySQL/Field.h index 102fc82..b0a91df 100644 --- a/src/server/shared/MySQL/Field.h +++ b/src/server/shared/MySQL/Field.h @@ -1,6 +1,7 @@ #ifndef FIELD_MYSQL_H_ #define FIELD_MYSQL_H_ +#include "Log.h" #include #include @@ -31,29 +32,97 @@ namespace MySQL } data.type = type; + data.raw = false; + } + + void SetByteValue(const void* value, const size_t size, enum_field_types type, uint32 length) + { + if (data.value) + CleanUp(); + + if (value) + { + data.value = new char[size]; + memcpy(data.value, value, size); + data.length = length; + } + + data.type = type; + data.raw = true; } uint32_t GetUInt32() { + if (data.raw) + return *reinterpret_cast(data.value); return boost::lexical_cast(data.value); } + char const* GetCString() + { + return static_cast(data.value); + } + std::string GetString() { + if (data.raw) + return std::string(GetCString(), data.length); return boost::lexical_cast(data.value); } double GetDouble() { + if (data.raw) + return *reinterpret_cast(data.value); return boost::lexical_cast(data.value); } + + static size_t SizeForType(MYSQL_FIELD* field) + { + switch (field->type) + { + case MYSQL_TYPE_NULL: + return 0; + case MYSQL_TYPE_TINY: + return 1; + case MYSQL_TYPE_YEAR: + case MYSQL_TYPE_SHORT: + return 2; + case MYSQL_TYPE_INT24: + case MYSQL_TYPE_LONG: + case MYSQL_TYPE_FLOAT: + return 4; + case MYSQL_TYPE_DOUBLE: + case MYSQL_TYPE_LONGLONG: + case MYSQL_TYPE_BIT: + return 8; + case MYSQL_TYPE_TIMESTAMP: + case MYSQL_TYPE_DATE: + case MYSQL_TYPE_TIME: + case MYSQL_TYPE_DATETIME: + return sizeof(MYSQL_TIME); + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_VAR_STRING: + return field->max_length + 1; + case MYSQL_TYPE_DECIMAL: + case MYSQL_TYPE_NEWDECIMAL: + return 64; + default: + return 0; + } + } private: struct { - uint32_t length; + uint32 length; char* value; enum_field_types type; + bool raw; } data; void CleanUp() diff --git a/src/server/shared/MySQL/PreparedStatement.cpp b/src/server/shared/MySQL/PreparedStatement.cpp index ecab60d..ff757c7 100644 --- a/src/server/shared/MySQL/PreparedStatement.cpp +++ b/src/server/shared/MySQL/PreparedStatement.cpp @@ -3,7 +3,7 @@ namespace MySQL { - PreparedStatement::PreparedStatement(MYSQL_STMT* stmt) : + ConnectionPreparedStatement::ConnectionPreparedStatement(MYSQL_STMT* stmt) : _stmt(stmt), _bind(NULL) { _paramCount = mysql_stmt_param_count(stmt); @@ -11,7 +11,7 @@ namespace MySQL memset(_bind, 0, sizeof(MYSQL_BIND)*_paramCount); } - PreparedStatement::~PreparedStatement() + ConnectionPreparedStatement::~ConnectionPreparedStatement() { ClearParameters(); if (_stmt->bind_result_done) @@ -21,12 +21,101 @@ namespace MySQL } mysql_stmt_close(_stmt); delete[] _bind; - this->~PreparedStatement(); + } + + void ConnectionPreparedStatement::BindParameters(PreparedStatement* stmt) + { + for (uint8 i = 0; i < stmt->data.size(); ++i) + { + switch (stmt->data[i].type) + { + case MYSQL_UINT8: + SetValue(i, MYSQL_TYPE_TINY, &boost::get(stmt->data[i].value), sizeof(uint8), true); + break; + case MYSQL_UINT16: + SetValue(i, MYSQL_TYPE_SHORT, &boost::get(stmt->data[i].value), sizeof(uint16), true); + break; + case MYSQL_UINT32: + SetValue(i, MYSQL_TYPE_LONG, &boost::get(stmt->data[i].value), sizeof(uint32), true); + break; + case MYSQL_UINT64: + SetValue(i, MYSQL_TYPE_LONGLONG, &boost::get(stmt->data[i].value), sizeof(uint64), true); + break; + case MYSQL_INT8: + SetValue(i, MYSQL_TYPE_TINY, &boost::get(stmt->data[i].value), sizeof(int8), false); + break; + case MYSQL_INT16: + SetValue(i, MYSQL_TYPE_SHORT, &boost::get(stmt->data[i].value), sizeof(int16), false); + break; + case MYSQL_INT32: + SetValue(i, MYSQL_TYPE_LONG, &boost::get(stmt->data[i].value), sizeof(int32), false); + break; + case MYSQL_INT64: + SetValue(i, MYSQL_TYPE_LONGLONG, &boost::get(stmt->data[i].value), sizeof(int64), false); + break; + case MYSQL_FLOAT: + SetValue(i, MYSQL_TYPE_FLOAT, &boost::get(stmt->data[i].value), sizeof(float), false); + break; + case MYSQL_DOUBLE: + SetValue(i, MYSQL_TYPE_DOUBLE, &boost::get(stmt->data[i].value), sizeof(double), false); + break; + case MYSQL_STRING: + SetString(i, boost::get(stmt->data[i].value)); + break; + case MYSQL_NULL: + SetNull(i); + break; + default: + // need assert? + break; + } + } + } + + void ConnectionPreparedStatement::SetValue(uint8 index, enum_field_types type, const void* value, uint32 len, bool isUnsigned) + { + MYSQL_BIND* param = &_bind[index]; + + param->buffer_type = type; + delete [] static_cast(param->buffer); + param->buffer = new char[len]; + param->buffer_length = 0; + param->is_null_value = 0; + param->length = NULL; // Only != NULL for strings + param->is_unsigned = isUnsigned; + + memcpy(param->buffer, value, len); + } + + void ConnectionPreparedStatement::SetString(uint8 index, std::string str) + { + MYSQL_BIND* param = &_bind[index]; + size_t len = str.size() + 1; + param->buffer_type = MYSQL_TYPE_VAR_STRING; + delete [] static_cast(param->buffer); + param->buffer = new char[len]; + param->buffer_length = len; + param->is_null_value = 0; + delete param->length; + param->length = new unsigned long(len-1); + memcpy(param->buffer, str.c_str(), len); + } + + void ConnectionPreparedStatement::SetNull(uint8 index) + { + MYSQL_BIND* param = &_bind[index]; + param->buffer_type = MYSQL_TYPE_NULL; + delete [] static_cast(param->buffer); + param->buffer = NULL; + param->buffer_length = 0; + param->is_null_value = 1; + delete param->length; + param->length = NULL; } - void PreparedStatement::ClearParameters() + void ConnectionPreparedStatement::ClearParameters() { - for (uint8_t i = 0; i < _paramCount; ++i) + for (uint8 i = 0; i < _paramCount; ++i) { delete _bind[i].length; _bind[i].length = NULL; diff --git a/src/server/shared/MySQL/PreparedStatement.h b/src/server/shared/MySQL/PreparedStatement.h index f9a8cb1..191f76a 100644 --- a/src/server/shared/MySQL/PreparedStatement.h +++ b/src/server/shared/MySQL/PreparedStatement.h @@ -1,26 +1,136 @@ #ifndef PREPARED_STATEMENT_MYSQL_H_ #define PREPARED_STATEMENT_MYSQL_H_ -#include +#include + +#include #include +#include namespace MySQL { enum PreparedStatementFlags { - STMT_SYNC, - STMT_ASYNC, - STMT_BOTH = STMT_SYNC | STMT_ASYNC + STMT_SYNC = 1, + STMT_ASYNC = 2, + STMT_BOTH = STMT_SYNC | STMT_ASYNC + }; + + typedef boost::variant MySQLValue; + + enum MySQLValueTypes + { + MYSQL_UINT8, + MYSQL_UINT16, + MYSQL_UINT32, + MYSQL_UINT64, + MYSQL_INT8, + MYSQL_INT16, + MYSQL_INT32, + MYSQL_INT64, + MYSQL_FLOAT, + MYSQL_DOUBLE, + MYSQL_STRING, + MYSQL_NULL }; + struct PreparedStatementData + { + MySQLValue value; + MySQLValueTypes type; + }; + + // High level stmt class PreparedStatement { + friend class ConnectionPreparedStatement; + friend class DatabaseConnection; + public: - PreparedStatement(MYSQL_STMT* stmt); - ~PreparedStatement(); + PreparedStatement(uint32 index) + { + _index = index; + } template - void Set(const uint8_t index, const T value); + void Set(const uint8 index, const T value, const MySQLValueTypes type) + { + if (index >= data.size()) + data.resize(index+1); + + data[index].value = value; + data[index].type = type; + } + + void SetBool(const uint8 index, const bool value) { + SetUInt8(index, value ? 1 : 0); + } + void SetUInt8(const uint8 index, const uint8 value) { + Set(index, value, MYSQL_UINT8); + } + void SetUInt16(const uint8 index, const uint16 value) { + Set(index, value, MYSQL_UINT16); + } + void SetUInt32(const uint8 index, const uint32 value) { + Set(index, value, MYSQL_UINT32); + } + void SetUInt64(const uint8 index, const uint64 value) { + Set(index, value, MYSQL_UINT64); + } + void SetInt8(const uint8 index, const int8 value) { + Set(index, value, MYSQL_INT8); + } + void SetInt16(const uint8 index, const int16 value) { + Set(index, value, MYSQL_INT16); + } + void SetInt32(const uint8 index, const int32 value) { + Set(index, value, MYSQL_INT32); + } + void SetInt64(const uint8 index, const int64 value) { + Set(index, value, MYSQL_INT64); + } + void SetFloat(const uint8 index, const float value) { + Set(index, value, MYSQL_FLOAT); + } + void SetDouble(const uint8 index, const double value) { + Set(index, value, MYSQL_DOUBLE); + } + void SetString(const uint8 index, const std::string& value) { + Set(index, value, MYSQL_STRING); + } + void SetNull(const uint8 index) { + Set(index, 0, MYSQL_NULL); + } + + protected: + std::vector data; + uint32 _index; + }; + + // Connection specific stmt + class ConnectionPreparedStatement + { + public: + ConnectionPreparedStatement(MYSQL_STMT* stmt); + ~ConnectionPreparedStatement(); + + /*template + void Set(const uint8_t index, const T value);*/ + void BindParameters(PreparedStatement* stmt); + + void SetValue(uint8 index, enum_field_types type, const void* value, uint32 len, bool isUnsigned); + void SetString(uint8 index, std::string str); + void SetNull(uint8 index); + + MYSQL_STMT* GetSTMT() + { + return _stmt; + } + + MYSQL_BIND* GetBind() + { + return _bind; + } void ClearParameters(); private: diff --git a/src/server/shared/MySQL/QueryResult.cpp b/src/server/shared/MySQL/QueryResult.cpp index 7cca749..371610d 100644 --- a/src/server/shared/MySQL/QueryResult.cpp +++ b/src/server/shared/MySQL/QueryResult.cpp @@ -3,51 +3,127 @@ namespace MySQL { - ResultSet::ResultSet(MYSQL_RES* result, MYSQL_FIELD* fields, uint64_t rowCount, uint32_t fieldCount) : - _result(result), _fields(fields), _rowCount(rowCount), _fieldCount(fieldCount) + // Normal Query + ResultSet::ResultSet(MYSQL_RES* result, MYSQL_FIELD* resultFields, uint64 rowCount, uint32 fieldCount) : + _rowCount(rowCount), _fieldCount(fieldCount), _currentRow(0) { - _currentRow = new Field[_fieldCount]; + MYSQL_ROW row; + + while (row = mysql_fetch_row(result)) + { + Field* fields = new Field[_fieldCount]; + + for (uint32 i = 0; i < _fieldCount; ++i) + fields[i].SetValue(row[i], resultFields[i].type); + + _rows.push_back(fields); + } + + // We have it locally now! + mysql_free_result(result); } - - ResultSet::~ResultSet() + + // Prepared statement query + ResultSet::ResultSet(MYSQL_RES* result, MYSQL_STMT* stmt, uint64 rowCount, uint32 fieldCount) : + _rowCount(rowCount), _fieldCount(fieldCount), _currentRow(0) { - } - - bool ResultSet::NextRow() - { - MYSQL_ROW row; - - if (!_result) { - sLog.Debug(LOG_DATABASE, "QueryResultMySQL::NextRow(): Empty result"); - return false; + if (stmt->bind_result_done) { + delete[] stmt->bind->length; + delete[] stmt->bind->is_null; } - - row = mysql_fetch_row(_result); - if (!row) - { - sLog.Debug(LOG_DATABASE, "QueryResultMySQL::NextRow(): End of result"); - CleanUp(); - return false; + + // Store entire result set locally from server + if (mysql_stmt_store_result(stmt)) { + sLog.Error(LOG_DATABASE, "mysql_stmt_store_result, cannot bind result from MySQL server. Error: %s", mysql_stmt_error(stmt)); + return; } + + // This is where we will store data + MYSQL_BIND* bind = new MYSQL_BIND[fieldCount]; + my_bool* isNull = new my_bool[fieldCount]; + unsigned long* length = new unsigned long[fieldCount]; - for (uint32_t i = 0; i < _fieldCount; i++) - _currentRow[i].SetValue(row[i], _fields[i].type); + // Reset + memset(bind, 0, sizeof(MYSQL_BIND) * fieldCount); + memset(isNull, 0, sizeof(my_bool) * fieldCount); + memset(length, 0, sizeof(unsigned long) * fieldCount); - return true; - } + // Prepare result buffer based on metadata + uint32 i = 0; + while (MYSQL_FIELD* field = mysql_fetch_field(result)) { + size_t size = Field::SizeForType(field); + + bind[i].buffer_type = field->type; + bind[i].buffer = malloc(size); + memset(bind[i].buffer, 0, size); + bind[i].buffer_length = size; + bind[i].length = &length[i]; + bind[i].is_null = &isNull[i]; + bind[i].error = NULL; + bind[i].is_unsigned = field->flags & UNSIGNED_FLAG; - void ResultSet::CleanUp() - { - if (_currentRow) - { - delete [] _currentRow; - _currentRow = NULL; + ++i; } - if (_result) + // Bind result buffer to the statement + if (mysql_stmt_bind_result(stmt, bind)) { + sLog.Error(LOG_DATABASE, "mysql_stmt_bind_result, cannot bind result from MySQL server. Error: %s", mysql_stmt_error(stmt)); + delete[] bind; + delete[] isNull; + delete[] length; + return; + } + + _rowCount = mysql_stmt_num_rows(stmt); + + while (_NextSTMTRow(stmt)) { - mysql_free_result(_result); - _result = NULL; + Field* fields = new Field[fieldCount]; + + for (uint64 fIndex = 0; fIndex < fieldCount; ++fIndex) + { + if (!*bind[fIndex].is_null) + fields[fIndex].SetByteValue(bind[fIndex].buffer, bind[fIndex].buffer_length, bind[fIndex].buffer_type, *bind[fIndex].length); + else { + switch (bind[fIndex].buffer_type) + { + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_VAR_STRING: + fields[fIndex].SetByteValue("", bind[fIndex].buffer_length, bind[fIndex].buffer_type, *bind[fIndex].length); + break; + default: + fields[fIndex].SetByteValue(0, bind[fIndex].buffer_length, bind[fIndex].buffer_type, *bind[fIndex].length); + } + } + } + + _rows.push_back(fields); } + + mysql_free_result(result); + } + + bool ResultSet::_NextSTMTRow(MYSQL_STMT* stmt) + { + uint8 ret = mysql_stmt_fetch(stmt); + + if (!ret || ret == MYSQL_DATA_TRUNCATED) + return true; + + if (ret == MYSQL_NO_DATA) + return false; + + sLog.Error(LOG_DATABASE, "mysql_stmt_fetch, cannot fetch result row. Error: %s", mysql_stmt_error(stmt)); + return false; + } + + ResultSet::~ResultSet() + { + for (uint32 i = 0; i < _rowCount; ++i) + delete[] _rows[i]; } } diff --git a/src/server/shared/MySQL/QueryResult.h b/src/server/shared/MySQL/QueryResult.h index bcd41db..4599a92 100644 --- a/src/server/shared/MySQL/QueryResult.h +++ b/src/server/shared/MySQL/QueryResult.h @@ -1,9 +1,12 @@ #ifndef QUERY_RESULT_MYSQL_H_ #define QUERY_RESULT_MYSQL_H_ +#include "Common.h" #include "Field.h" + #include #include +#include #include namespace MySQL @@ -12,34 +15,37 @@ namespace MySQL class ResultSet { public: - ResultSet(MYSQL_RES* result, MYSQL_FIELD* fields, uint64_t rowCount, uint32_t fieldCount); + // Normal query + ResultSet(MYSQL_RES* result, MYSQL_FIELD* resultFields, uint64 rowCount, uint32 fieldCount); + // Prepared statement query + ResultSet(MYSQL_RES* result, MYSQL_STMT* stmt, uint64 rowCount, uint32 fieldCount); ~ResultSet(); // Metadata - uint64_t GetRowCount() + uint64 GetRowCount() { return _rowCount; } - uint32_t GetFieldCount() + uint32 GetFieldCount() { return _fieldCount; } - bool NextRow(); - - Field* Fetch() + Field* FetchRow() { - return _currentRow; + if (_currentRow >= _rowCount) + return NULL; + + return _rows[_currentRow++]; } private: - uint64_t _rowCount; - Field* _currentRow; - uint32_t _fieldCount; - void CleanUp(); - MYSQL_RES* _result; - MYSQL_FIELD* _fields; + bool _NextSTMTRow(MYSQL_STMT* stmt); + uint64 _rowCount; + uint64 _currentRow; + uint32 _fieldCount; + std::vector _rows; }; typedef boost::shared_ptr QueryResult; diff --git a/src/server/shared/Util.h b/src/server/shared/Util.h index 216d476..1de5dd0 100644 --- a/src/server/shared/Util.h +++ b/src/server/shared/Util.h @@ -1,13 +1,14 @@ #ifndef UTIL_H_ #define UTIL_H_ +#include + #include #include #include #include #include #include -#include #include #include #include