機械学習で利用できる「トレーニング」「推論」「説明」のSQL関数を理解しよう
はじめに
今回は、MySQL HeatWave MLで実現できる機械学習について、「トレーニング」「推論」「説明」を実行できるSQL関数を紹介します。
トレーニング
ML_TRAINは分類モデルと回帰モデルをサポートし、指定されたテーブルに対する機械学習モデルを「トレーニング」します。指定するテーブルは secondary_engineとしてRAPIDを指定する必要があります。トレーニングに必要な時間は、データセット内の行と列の数、指定されたパラメーター、およびHeatWaveクラスターのサイズに応じて、数分から数時間かかる場合があります。
ML_TRAIN(テーブル名,対象列名,[オプション],モデル名) テーブル名: ラベル付けされたトレーニングデータセットを含むテーブルの名前。テーブル名は有効で完全修飾されている必要があり、スキーマ名.テーブル名の形でスキーマ名を含める 対象列名: ターゲット列名の名前 オプション: 分類: NULL または JSON_OBJECT('task', 'classification') 回帰: JSON_OBJECT('task', 'regression') モデル名: 接続中に機械学習モデルのハンドルを保存するユーザー定義のセッション変数の名前 【実行例】 ml_data.iris_trainテーブルのclass列に対する分類 mysql> CALL sys.ML_TRAIN(‘ml_data.iris_train’,’class’, JSON_OBJECT('task', 'classification'),@iria_model); Query OK, 0 rows affected (24.3940 sec)
https://dev.mysql.com/doc/heatwave/en/hwml-ml-train.html
推論
MySQL HeatWave MLでは対象データに応じて2つの関数が用意されています。なお、いずれも場合も使用する前に「トレーニング」済みモデルをロードする必要があります。
ML_MODEL_LOAD(「トレーニング」済みモデル名,ユーザー) 「 トレーニング」済みモデル名: モデルハンドルまたはモデルハンドルを含むセッション変数を指定 ユーザー: モデル所有者のMySQLユーザー名。モデルの所有者が現在のユーザーである場合は、NULL指定可能 【実行例】 mysql> CALL sys.ML_MODEL_LOAD(@iria_model,NULL); Query OK, 0 rows affected (0.7262 sec)
https://dev.mysql.com/doc/heatwave/en/hwml-ml-model-load.html
データ行の推論
ML_PREDICT_ROWはJSON形式で指定されたデータの1つ以上の行の推論を生成します。ML_PREDICT_ROWはSELECTステートメントを使用して呼び出されます。
ML_PREDICT_ROW(対象データ,「トレーニング」済みモデル名) 対象データ: 予測を生成するデータを指定 *モデルの「トレーニング」に利用したテーブルの列名をキーとしたKey-Value型のデータ構造を JSONオブジェクトで指定 *列を形式のキーと値のペアとして指定しJSONテーブルから選択することにより、複数行の設定が可能 「 トレーニング」済みモデル名: モデルハンドルまたはモデルハンドルを含むセッション変数を指定 【実行例】 mysql> SET @row_input = JSON_OBJECT( "sepal length", 7.3, "sepal width", 2.9, "petal length", 6.3, "petal width", 1.8); mysql> SELECT sys.ML_PREDICT_ROW(@row_input, @iris_model); ----------------------------------------------------------------------------+ | sys.ML_PREDICT_ROW(@row_input, @iris_model) | +---------------------------------------------------------------------------+ | {"Prediction": "Iris-virginica", "petal width": 1.8, "sepal width": 2.9, | | "petal length": 6.3, "sepal length": 7.3} | +---------------------------------------------------------------------------+
https://dev.mysql.com/doc/heatwave/en/hwml-ml-predict-row.html
テーブルの推論
ML_PREDICT_TABLEはラベルなしのデータのデーブル全体の推論を生成し、結果を出力テーブルに保存します。この処理は計算集約型のプロセスで行われます。そのため、大きなテーブルを指定する場合は小さいテーブルに分割し、操作を10〜100行のバッチに制限することをお勧めします。
ML_PREDICT_TABLE(入力テーブル,「トレーニング」済みモデル名,出力テーブル) 入力テーブル: 入力テーブルの完全修飾名を指定(schema_name.table_name)。入力テーブルには、トレーニングデータセットと同じ特徴列が含まれている必要がある 「 トレーニング」済みモデル名: モデルハンドルまたはモデルハンドルを含むセッション変数を指定 出力テーブル: 予測が保存されるテーブルを指定。テーブルが存在しない場合は作成される。出力テーブルも完全修飾テーブル名で指定する必要がある 【実行例】 mysql> CALL sys.ML_PREDICT_TABLE('ml_data.iris_test', @iris_model, 'ml_data.iris_predictions'); Query OK, 0 rows affected (5.3438 sec) mysql> SELECT * FROM iris_predictions; *************************** 1. row *************************** sepal length: 5.9 sepal width: 3 petal length: 4.2 petal width: 1.5 Prediction: Iris-setosa *************************** 2. row *************************** sepal length: 6.9 sepal width: 3.1 petal length: 5.4 petal width: 2.1 Prediction: Iris-virginica
https://dev.mysql.com/doc/heatwave/en/hwml-ml-predict-table.html
説明
説明は、ラベルのないデータに対してML_EXPLAIN_ROWまたはML_EXPLAIN_TABLEを実行することで生成されます。モデルのトレーニングに使用されるデータと同じ特徴列が必要ですが、ターゲット列は必要ありません。ML_EXPLAIN_ROWは1行以上のデータの説明を生成します。ML_EXPLAIN_TABLEはデータのテーブル全体に関する説明を生成し、結果を出力テーブルに保存します。
特徴の重要度は、-1から1の範囲の値として表されます。正の値は特徴が予測に貢献したことを示し、負の値は機能が異なる予測に貢献したことを示します。例えば、2つの可能な予測(「承認」と「拒否」)を持つローン承認モデルの機能が「承認」予測に対して負の値を持っている場合、その機能は「拒否」予測に対して正の値を持ちます。値0または0に近い値は、特徴値が適用される予測に影響を与えないことを示しています。
ML_EXPLAIN_ROW(対象データ,「トレーニング」済みモデル名) ML_EXPLAIN_TABLE(入力テーブル,「トレーニング」済みモデル名,出力テーブル) 対象データ: 予測を生成するデータを指定 *モデルの「トレーニング」に利用したテーブルの列名をキーとしたKey-Value型のデータ構造を JSONオブジェクトで指定 *列を形式のキーと値のペアとして指定しJSONテーブルから選択することにより、複数行の設定が可能 入力テーブル: 入力テーブルの完全修飾名を指定(schema_name.table_name)。入力テーブルには、トレーニングデータセットと同じ特徴列が含まれている必要がある 出力テーブル: 予測が保存されるテーブルを指定。テーブルが存在しない場合は作成される。出力テーブルも完全修飾テーブル名で指定する必要がある 「 トレーニング」済みモデル名: モデルハンドルまたはモデルハンドルを含むセッション変数を指定 【実行例】 mysql> SET @row_input = JSON_OBJECT( "sepal length", 7.3, "sepal width", 2.9, "petal length", 6.3, "petal width", 1.8); mysql> SELECT sys.ML_EXPLAIN_ROW(@row_input,@iria_model); / mysql> CALL sys.ML_EXPLAIN_TABLE(‘heatwaveml_bench.census_test_subset’, @iria_model, ’heateaveml_bench.census_explanations’); mysql> SELECT * FROM heatwaveml_bench.census_explanations’;
https://dev.mysql.com/doc/heatwave/en/hwml-ml-explain-row.html
https://dev.mysql.com/doc/heatwave/en/hwml-ml-explain-table.html