こんにちは。データサイエンス部の石川です。
弊社では広告配信の最適化のために CTR・CVR*1 を推定する機械学習モデルを開発していて、定期的な学習とモデルの更新を行っています。
このようなシステムにおいて、学習済みモデルが推論システムで問題なく動作することを保証するために、デプロイされるモデルの挙動を検証する仕組みが必要です。 特に、学習時と推論時で同一の広告リクエストに対して同じ推論値を出力するかを確認する仕組みを「差分検知」と呼んでいます。
この記事では、弊社の広告システムにおける機械学習モデルの差分検知について紹介します。
背景
弊社の CTR・CVR を推定する機械学習システムでは、ワークフローエンジンが定期的にモデルの学習を実行し、その後学習済みモデルを S3 にアップロードします。 広告スコアリングサーバは S3 上のモデル変更を検知し、学習済みモデルをロードして CTR・CVR 推定を行います。
学習ロジックは Python で記述されていますが、スコアリングサーバは Rust (以前は Go )で記述されているため、学習時の処理の一部(例えば、データ前処理や特徴量生成)はスコアリングサーバでも同様に行われる必要があります*2。
そのため、同一の広告リクエストに対して学習時と推論時に同一の推定値を返すか(適切な前処理や特徴量生成が行われ、モデルの入出力値が同じか)を確認する必要があります。そうでなければ、意図しない推定値に基づいて最適化され、収益性や広告効果などのビジネス KPI に悪影響を及ぼす可能性があります。
課題
最近 Go で書かれていた推論 API を Rust でリプレイスをしました。 Go サーバの際は推論用の実行バイナリを作成し、モデル学習時に Python の subprocess ライブラリを利用して実行バイナリを呼び出すことで、推論 API の推論値を取得していました。
しかし Rust 移行の際にアーキテクチャを変更し、複数のコンテナによって実行されるサービスにしたため、今までのように単一の実行バイナリを呼び出すことができなくなりました。 そのため、別の方法でスコアリングサーバの CTR・CVR 推定の差分検知を行える仕組みが必要になりました。
移行時の詳細の話は以前のブログにまとめられているので、ご覧ください。
解決策
要件を整理した上で、以下の 3 つの解決策を考えました。
- 差分検知用サーバの構築
- 差分検知用 API エンドポイントの追加
- 推論処理の Python バインディングの作成
上記の 3 つの解決策について詳細と、メリット・デメリットを整理します。
1. 差分検知用サーバの構築
詳細
- 差分検知用のサーバを構築し、学習ワークフローはそのサーバに対して gRPC リクエストを送る
メリット
- アプリケーションコードへの変更が少ない
- ローカル環境でもテストが容易である
デメリット
- 別のインスタンスが起動するため運用のコストがかかる
2. 差分検知用 API エンドポイントの追加
詳細
- 差分検知用の API エンドポイントを追加し、学習ワークフローはステージング環境の API エンドポイントをコールする
メリット
- 新しい API の追加のみで、実装が容易である
デメリット
- 動作確認のために staging 環境へのデプロイが都度発生し、開発効率が悪い
3. 推論処理の Python バインディングの作成
詳細
- 推論処理の Python バインディングを作成し、学習ワークフローは作成したパッケージをインストールして用いる
メリット
- サーバの起動や staging 環境への修正が不要である
- ポータビリティが高く、CI での動作確認も容易である
デメリット
- モデル管理システムを含めた統合テストにはならない
- Python バインディングの実装が必要になる
以上のメリット・デメリットを検討した結果、3 番の案を採用することになりました。
3 番の案を選んだ理由は、サーバの起動が不要で個人の開発環境をすぐ利用できることを重視したためです。実際、新しい特徴量の追加時には推論 API 側の修正が必要になりますが、その際のトライアンドエラーが容易になりました。
また、デメリットとして挙がった統合テストにならない点については単体テストを拡充することで許容することにしました。 もう一つのデメリットだった Python のバインディングの実装については一度実装してしまえばその後大きな修正が入ることは稀だと判断し、頑張って実装することにしました。
PyO3 の実装
Rust で実装した CTR・CVR 推定モジュールを Python から実行できるように PyO3 を利用しました。
Rust で実装された機能を Python から実行できるようにする PyO3 のイメージについて、 PyO3 のドキュメントにあるサンプルコードを用いて説明します。
まず以下のような Rust コードがあるとします*3。 sum_as_string
という 2 つの引数を足し合わせて文字列にして返す関数があり、これを string_sum
というモジュール*4に追加することで、Python 側から import して参照することができます。
use pyo3::prelude::*; #[pyfunction] fn sum_as_string(a: usize, b: usize) -> PyResult<String> { Ok((a + b).to_string()) } #[pymodule] fn string_sum(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; Ok(()) }
maturin develop
でパッケージをビルド・インストールすることで準備完了です。Python では以下のようにして Rust で実装した機能を利用することが出来ます。
$ python >>> import string_sum >>> string_sum.sum_as_string(5, 20) '25'
同じ要領で Rust で実装された CTR・CVR 推定モジュールを Python から実行できるように整備し、ここで作られたパッケージを含んだ wheel を作成し、学習システムに差分検知の仕組みを取り込むことが出来ました。
PyO3 の導入については以前にブログが書かれているので参照してください。
PyO3 導入によって Python に公開されている関数や構造体を変更する場合、追従するための実装コストが発生する苦労がありました。また、API 側で必要なリクエストのデータクラスを学習側に公開して利用するようにしているので、リクエストのデータクラスのスキーマが変更した場合に学習側でも対応する必要があり、バージョン管理の煩雑さが発生しました。
しかし、ローカルでの開発で実際のスコアリングサーバに依存することなく、API での推論結果を確認可能になりました。
まとめ
この記事では、学習時と推論時で同一の広告リクエストに対して同じ推論値を出力するかを確認する機構について紹介しました。
今回の差分検知の方法によって、API の推論結果の開発時の確認が容易になり、トライアンドエラーが簡単になりました。
現在は推論結果の差分の有無だけを確認する機構ですが、デバッグをより容易にするために、どの特徴量の変換が間違っているかを把握できるように対応する予定です。