Amazon SageMakerの基本的な使い方を理解する(4)

【連載】

AWSではじめる機械学習 ~サービスを知り、実装を学ぶ~

【第8回】Amazon SageMakerの基本的な使い方を理解する(4)

[2020/09/17 10:00]菊地貴彰 ブックマーク ブックマーク

前回は、Amazon SageMakerのチュートリアルをベースに、学習を行って機械学習モデルを構築するところまで説明しました。今回は、構築した機械学習モデルを使って、推論を実行してみましょう。

今回の行程は、下記の表の「ステップ 5:モデルのデプロイ」「ステップ 6:モデルの性能評価」「ステップ 7:リソースを終了する」に当たります。4回に渡ってAmazon SageMakerの使い方を解説してきましたが、今回で完了です。

工程 ステップ 枝番 実施内容 連載回
開発 1 Amazon SageMakerコンソールにログインする 第5回
2 Amazon SageMakernotebook instance を作成する 第5回
3 データの準備 第6回
3a, 3b ノートブックを起動する
3c ノートブックの利用準備をする
3d S3 バケットを作成する
3e 学習/推論に利用するデータをダウンロードする
3f データを分割する
学習 4 データからのモデルのトレーニング 第7回
4a 学習データを S3 バケットにアップロードする
4b 学習の設定をする
4c 学習を行う
推論 5 モデルのデプロイ 第8回
5a 推論エンドポイントを作成して、モデルをデプロイする
5b 推論を行う
6 モデルの性能評価 第8回
後片付け 7 リソースを終了する 第8回

ステップ 5:モデルのデプロイ

学習を実行し、機械学習モデルの構築まで完了しました。ここでは、機械学習モデルのデプロイと推論データを使って推論(見込み顧客の予測)を行います。

概要図

ステップ 5a:推論エンドポイントを作成して、モデルをデプロイする

ここで言う「デプロイ」とは、APIサーバの構築を指します。deployメソッドを利用することで、下記に示す一連の処理を実行できます。

  • 推論用インスタンスの作成
  • S3バケットからの機械学習モデルのダウンロード
  • 推論用コンテナの起動
  • APIの公開

では、下記のコードをノートブックのセルにコピー&ペーストして実行してください。

xgb_predictor = xgb.deploy(initial_instance_count=1,instance_type='ml.m4.xlarge')

なお、デプロイが開始されると、下記の画像のように処理の進行と共に「- (ハイフン)」が出力されていきます。

デプロイ開始

デプロイが完了すると、完了を示す「!(エクスクラメーションマーク)」が出力されます。デプロイ処理は数分で完了します。

デプロイ完了

先述のコードを見やすく改行したものを見ながら解説していきましょう。

xgb_predictor = xgb.deploy(
  initial_instance_count=1,
  instance_type='ml.m4.xlarge'
)

このコードでは、Amazon SageMaker SDK for Pythonのdeployメソッドを使って、機械学習モデルのデプロイを実行しています。ここで設定するパラメータ(引数)とそれぞれの意味は下表の通りです。

パラメータ(引数) パラメータの意味
initial_instance_count 推論用インスタンス数を設定する
instance_type 推論用インスタンスのインスタンスタイプを設定する。MLインスタンスとして様々なタイプが用意されており、学習の特性に応じたものを選択する。ここでは、汎用(M4)の「xlarge」を設定している。その他に利用できるインスタンスタイプと料金は「Amazon SageMaker ML インスタンスタイプ」と「Amazon SageMaker の料金」を参照されたい

具体的には、推論用インスタンスを作成して、S3バケットから学習済みの機械学習モデルを、Amazon ECRから推論用コンテナイメージをダウンロードして、推論用コンテナを起動しています。

また、アプリなどからアクセスするためのエンドポイントを作成しています。エンドポイントの接続には、HTTPSを利用します。

エンドポイントはAmazon SageMakerのコンソール画面から確認することができます。左側の折りたたみメニューから「推論」をクリックして開き、「エンドポイント」をクリックします。すると、以下のようにそのアカウントとリージョンで稼働しているエンドポイントを確認することができます。

エンドポイント

エンドポイントを選択すると、詳細画面に移動します。エンドポイントに関するさまざまな情報が提示されますが、「エンドポイント設定」の「URL」でデプロイした機械学習モデルに接続するためのエンドポイントを確認することができます。

エンドポイント

ステップ 5b:推論を行う

下記のコードをセルにコピー&ペーストして実行してください。

test_data_array = test_data.drop(['y_no', 'y_yes'], axis=1).values #load the data into an array
xgb_predictor.content_type = 'text/csv' # set the data type for an inference
xgb_predictor.serializer = csv_serializer # set the serializer type
predictions = xgb_predictor.predict(test_data_array).decode('utf-8') # predict!
predictions_array = np.fromstring(predictions[1:], sep=',') # and turn the prediction into an array
print(predictions_array.shape)

実行した結果出力される (12357,) は、予測した顧客の数 (12,357人) です。

では、順番に見ていきましょう。

test_data_array = test_data.drop(['y_no', 'y_yes'], axis=1).values #load the data into an array

ステップ 3eでデータをダウンロードし、ステップ 3fでそのデータを学習データと推論データ(テストデータ)に分割しました。推論データは「test_data」に格納されていますが、ステップ 3fでは単純に分割しただけで、正解(「y_yes」と「y_no」)が含まれています。顧客の属性情報から見込み顧客であるか否かを予測したいので、正解の値を削除(drop)しています。

続く以下の部分では、推論データの「Content type」とシリアライザのタイプにCSV用のものを指定しています。

xgb_predictor.content_type = 'text/csv' # set the data type for an inference
xgb_predictor.serializer = csv_serializer # set the serializer type

次に、predictメソッドを使って推論データを推論エンドポイントに送信し、推論結果を得ているのが以下です。

predictions = xgb_predictor.predict(test_data_array).decode('utf-8') # predict!

最後に、推論結果がカンマ区切りのテキストデータで返されるので、後続の精度の評価をするために「Numpy」のfromstringメソッドを使ってArrayに変換しています。

predictions_array = np.fromstring(predictions[1:], sep=',') # and turn the prediction into an array

ステップ 6:モデルの性能評価

推論結果を得ることができたので、最後に精度の評価を行います。ここでは、二値分類の精度評価においてよく利用される「混同行列(confusion matrix)」を使います。下記のコードをセルにコピー&ペーストして実行してください。

cm = pd.crosstab(index=test_data['y_yes'], columns=np.round(predictions_array), rownames=['Observed'], colnames=['Predicted'])
tn = cm.iloc[0,0]; fn = cm.iloc[1,0]; tp = cm.iloc[1,1]; fp = cm.iloc[0,1]; p = (tp+tn)/(tp+tn+fp+fn)*100
print("\n{0:<20}{1:<4.1f}%\n".format("Overall Classification Rate: ", p))
print("{0:<15}{1:<15}{2:>8}".format("Predicted", "No Purchase", "Purchase"))
print("Observed")
print("{0:<15}{1:<2.0f}% ({2:<}){3:>6.0f}% ({4:<})".format("No Purchase", tn/(tn+fn)*100,tn, fp/(tp+fp)*100, fp))
print("{0:<16}{1:<1.0f}% ({2:<}){3:>7.0f}% ({4:<}) \n".format("Purchase", fn/(tn+fn)*100,fn, tp/(tp+fp)*100, tp))

このコードでは、Pandasのcrosstab関数を使って、推論結果のクロス集計を行っています。ここで、精度評価で利用している混同行列について補足しておきましょう。

混同行列

行方向(縦軸)が「正解(観測; Observed)」、列方向(横軸)が「予測(Predicted)」を表します。「正解(2種類) * 予測(2種類)」で、以下の4つの指標があります。

  • True Positive(TP):定期預金を申し込むと予測して、実際に申し込んだ顧客の数
  • False Poritive(FP):定期預金を申し込むと予測したが、実際には申し込まなかった顧客の数
  • True Negative(TN):定期預金を申し込まないと予測して実際に申し込まなかった顧客の数
  • False Negative(FN):定期預金を申し込まないと予測したが、実際には申し込んだ顧客の数

混同行列という名前通り混乱しやすいですが、Positive/Negativeはあくまで「予測」に対してかかっている言葉です。その予測の正解と不正解に応じて、True/Negativeが付いていると考えると理解しやすいと思います。

推論結果についての考察

上記のコードを実行すると、次のような出力結果を得られます。

出力結果

「Overall Classification Rate」で示されている数値が正解率です。今回は、89.5%でした。「True Negative」が90%である一方で、「True Poritive」は65%とかなり低いように見えます。なぜでしょうか?

学習データ(train_data)には、28,831人分の顧客データを利用しました。このうち、定期預金を申し込んだ顧客数(「y_yes」が「1」の合計)と申し込まなかった(「y_no」が「1」の合計)を確認してみます。

確認

定期預金を申し込んだ顧客数が3,219 人であることに対して、定期預金を申し込まなかった顧客数は25,612人と5倍近い差があることがわかります。定期預金を申し込んだ顧客のデータ数が不足している可能性があり、後2.2万人程度のデータを用意して再度学習をすると、同程度の精度を出せる可能性があります。

二値分類における各ラベルの学習データは、十分な量を均等に用意することが望ましいと言われています。今回はUCIが公開しているオープンデータを利用しているので、データを増やして再検証をすることはできません。実際の業務で同様の問題が発生した場合は、顧客の申し込み履歴がデータベースに格納されているのであれば、さらにデータを抽出するか、そもそもこれ以上データが存在しない場合は新たにデータを取得して増やさなければなりません。

また、「True Negative」が90%と比較的高い数値を示しているように見えますが、この数値で十分な精度を確保できているのかどうかは業務要件次第です。100%に近い数字でなければならない場合は、定期預金を申し込まなかった顧客のデータも増やさなければならない可能性があります。

「機械学習にかかる時間の約8割はデータの準備や前処理である」と言われることがありますが、それが決して大げさではないことが今回の結果からも垣間うかがえるかと思います。

データを増やす以外の方法として、ステップ 4bで固定値としたハイパーパラメータをチューニングすることで精度を改善できるかもしれません。いずれにしても機械学習はこのような仮説立案と検証の繰り返しが必要であり、ビジネスにおいて活用していくには根気や労力が必要です。

ステップ 7:リソースを終了する

ここまでで「開発」「学習」「推論」を一通り実施することができました。課金対象のリソースがあるので、課金を防ぐためにリソースを削除します。課金される可能性があるリソースは、下表の通りです。

サービス名 リソース名
Amazon SageMaker ノートブックインスタンス。推論エンドポイント(推論用インスタンス/推論用コンテナ)
Amazon S3 S3バケットに格納されているオブジェクト
Amazon CloudWatch CloudWatch Logsに格納されているログ

チュートリアルのステップ 7の手順では「推論エンドポイント」と「S3バケットに格納されているオブジェクト」しかありませんので、その他のリソースの削除漏れに注意してください。また、請求ダッシュボードやCost Explorerでも課金されていないことを必ず確認してください。

リソースの削除を行うには、以下のコードをセルにコピー&ペーストして実行します。

sagemaker.Session().delete_endpoint(xgb_predictor.endpoint)
bucket_to_delete = boto3.resource('s3').Bucket(bucket_name)
bucket_to_delete.objects.all().delete()

すると、以下の処理が実行されます。

  • 推論用エンドポイントの削除
  • S3バケット内のオブジェクトの削除

また、先述の通り、「ノートブックインスタンス」と「CloudWatch Logs」も課金対象となるので、こちらも忘れずに削除していきます。

ノートブックインスタンスの削除

ノートブックインスタンスを削除するには、「停止」と「削除」の2段階での対応が必要です。

JupyterLabのタブは、Webブラウザの「×」ボタンなどで閉じてください。その上で、AWSマネジメントコンソールのノートブックインスタンスの画面に移動します。削除対象のノートブックインスタンスをラジオボタンで選択し、「アクション」で「停止」を選択します。

停止処理に移ると、「ステータス」が「InService」から「Stopping」に移行しますので、数分待ちます。

停止処理

「ステータス」が「Stopped」になったことを確認して、「アクション」で「削除」をクリックします。

削除を選択

ノートブックインスタンスの削除を確認するウィンドウが表示されますので、「削除」をクリックします。

削除

「ステータス」が「Deleting」に移行し、数分待つとノートブックインスタンスが画面からなくなります。これでノートブックインスタンスの削除は完了です。

CloudWatch Logsの削除

次に、CloudWatch Logs の削除を行います。AWSマネジメントコンソールのトップ画面などから「CloudWatch」を検索して移動します 。

「CloudWatch」を検索

CloudWatchの画面に移動できたら、左側のメニューから「ロググループ」を選択します。

「ロググループ」を選択

ロググループの画面に移動できたら、検索バーに「sagemaker」と入力してログを検索します。削除するログを全て選択して、「アクション」で「ロググループの削除」をクリックします。

ログを削除

ロググループの削除を確認するウィンドウが表示されるので、「削除」をクリックします。

削除

ロググループが画面からなくなったら削除は完了です。

なお、チュートリアルではS3バケットとIAMロールも作成しましたが、これらは存在するだけでは課金されません。それぞれのサービスの画面から必要に応じて削除してください。

* * *

今回まで数回に渡り、Amazon SageMakerを実際に使って開発/学習/推論をする方法について説明しました。次回は、機械学習用の統合開発環境「Amazon SageMaker Studio」の使い方について見ていきたいと思います。

著者紹介


菊地 貴彰 (KIKUCHI Takaaki) - 株式会社NTTデータ
システム技術本部 デジタル技術部
Agile プロフェッショナル担当

大学・大学院では、機械学習を専攻。ベイズ的枠組みを用いて、複数の遺伝子のデータから遺伝子どうしの相互作用ネットワークの推定に関する研究を行った。

株式会社NTTデータに入社後は、法人や金融のシステム開発のシステム基盤担当としてキャリアを積み、現在はデジタル技術やAgile開発を専門に扱う組織でシステム開発全般を担当する。
2019, 2020 APN AWS Top Engineers, Japan APN Ambassador 2020に選出に選出。

本連載の内容に対するご意見・ご質問はtwitter: @kikuchitk7まで。

※ 本記事は掲載時点の情報であり、最新のものとは異なる場合がございます。予めご了承ください。

一覧はこちら

連載目次

もっと知りたい!こちらもオススメ

三菱UFJ信託銀行はAI導入の「落とし穴」をどうやって乗り越えたのか?

三菱UFJ信託銀行はAI導入の「落とし穴」をどうやって乗り越えたのか?

DataRobot社は11月20日、プライベートカンファレンス「AI Experience 2019 Tokyo」を都内にて開催した。「AIで成功した先進企業ではビジネスにどんな変化が起こっているのか」をテーマに掲げた同カンファレンスでは、さまざまな分野から機械学習を活用したビジネス変革を推し進める企業が登壇し、その知見が語られた。

関連リンク

この記事に興味を持ったら"いいね!"を Click
Facebook で IT Search+ の人気記事をお届けします
注目の特集/連載
[解説動画] Googleアナリティクス分析&活用講座 - Webサイト改善の正しい考え方
[解説動画] 個人の業務効率化術 - 短時間集中はこうして作る
ミッションステートメント
教えてカナコさん! これならわかるAI入門
AWSではじめる機械学習 ~サービスを知り、実装を学ぶ~
対話システムをつくろう! Python超入門
Kubernetes入門
SAFeでつくる「DXに強い組織」~企業の課題を解決する13のアプローチ~
PowerShell Core入門
AWSで作るマイクロサービス
マイナビニュース スペシャルセミナー 講演レポート/当日講演資料 まとめ
セキュリティアワード特設ページ

一覧はこちら

今注目のIT用語の意味を事典でチェック!

一覧はこちら

会員登録(無料)

ページの先頭に戻る