ML.NET 示例:二元分類之垃圾簡訊檢測
寫在前面
準備近期將微軟的machinelearning-samples翻譯成中文,水平有限,如有錯漏,請大家多多指正。
如果有朋友對此感興趣,可以加入我:ofollow,noindex" target="_blank">https://github.com/feiyun0112/machinelearning-samples.zh-cn
垃圾簡訊檢測
ML.NET 版本 | API 型別 | 狀態 | 應用程式型別 | 資料型別 | 場景 | 機器學習任務 | 演算法 |
---|---|---|---|---|---|---|---|
v0.7 | 動態API | 可能需要更新專案結構以匹配模板 | 控制檯應用程式 | .tsv 檔案 | 垃圾資訊檢測 | 二元分類 | SDCA(線性學習器),還展示了CustomMapping評估器,它可以將自定義程式碼新增到ML.NET管道 |
在這個示例中,您將看到如何使用ML.NET 來預測簡訊是否是垃圾資訊。在機器學習領域中,這種型別的預測被稱為二元分類 。
問題
我們的目標是預測一個簡訊是否是垃圾資訊(一個不相關的/不想要的訊息)。我們將使用UCI的SMS+Spam+Collection" rel="nofollow,noindex" target="_blank">SMS Spam Collection Data Set ,其中包含近6000條被分類為“垃圾資訊”或“ham”(不是垃圾資訊)的訊息。我們將使用這個資料集來訓練一個模型,該模型可以接收新訊息並預測它們是否是垃圾資訊。
這是一個二元分類的示例,因為我們將簡訊分類為兩個類別。
解決方案
要解決這個問題,首先我們將建立一個評估器來定義我們想要使用的機器學習管道。 然後,我們將在現有資料上訓練這個評估器,評估其有多好,最後我們將使用該模型來預測一些示例訊息是否是垃圾資訊。
1. 建立評估器
為了建立評估器,我們將:
-
定義如何讀取從https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection 下載的垃圾資訊資料集。
-
應用多個數據轉換:
- 將標籤(“spam”或“ham”)轉換為布林值(“true”表示垃圾資訊),這樣我們就可以在二元分類器中使用它。
- 將簡訊轉換為數字向量,以便機器學習訓練器可以使用它
-
新增一個訓練器(如
StochasticDualCoordinateAscent
)。
初始程式碼類似以下內容:
// Set up the MLContext, which is a catalog of components in ML.NET. var mlContext = new MLContext(); // Create the reader and define which columns from the file should be read. var reader = new TextLoader(mlContext, new TextLoader.Arguments() { Separator = "tab", HasHeader = true, Column = new[] { new TextLoader.Column("Label", DataKind.Text, 0), new TextLoader.Column("Message", DataKind.Text, 1) } }); var data = reader.Read(new MultiFileSource(TrainDataPath)); // Create the estimator which converts the text label to boolean, featurizes the text, and adds a linear trainer. var estimator = mlContext.Transforms.CustomMapping<MyInput, MyOutput>(MyLambda.MyAction, "MyLambda") .Append(mlContext.Transforms.Text.FeaturizeText("Message", "Features")) .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent());
2. 評估模型
對於這個資料集,我們將使用交叉驗證 來評估我們的模型。將資料集劃分成5個不相交的子集,訓練5個模型(每個模型使用其中4個子集),並在訓練中沒有使用的資料子集上測試模型。
var cvResults = mlContext.BinaryClassification.CrossValidate(data, estimator, numFolds: 5); var aucs = cvResults.Select(r => r.metrics.Auc); Console.WriteLine("The AUC is {0}", aucs.Average());
請注意,通常我們在訓練後評估模型。 但是,交叉驗證包括模型訓練部分,因此我們不需要先執行Fit()
。 但是,我們稍後將在完整資料集上訓練模型以利用其他資料。
3. 訓練模型
為了訓練模型,我們將呼叫評估器的Fit()
方法,同時提供完整的訓練資料。
var model = estimator.Fit(data);
4. 使用模型
訓練完模型後,您可以使用Predict()
API來預測新文字是否垃圾資訊。 在這種情況下,我們更改模型的閾值以獲得更好的預測。 我們這樣做是因為我們的資料有偏差,大多數訊息都不是垃圾資訊。
// The dataset we have is skewed, as there are many more non-spam messages than spam messages. // While our model is relatively good at detecting the difference, this skewness leads it to always // say the message is not spam. We deal with this by lowering the threshold of the predictor. In reality, // it is useful to look at the precision-recall curve to identify the best possible threshold. var inPipe = new TransformerChain<ITransformer>(model.Take(model.Count() - 1).ToArray()); var lastTransformer = new BinaryPredictionTransformer<IPredictorProducing<float>>(mlContext, model.LastTransformer.Model, inPipe.GetOutputSchema(data.Schema), model.LastTransformer.FeatureColumn, threshold: 0.15f, thresholdColumn: DefaultColumnNames.Probability); ITransformer[] parts = model.ToArray(); parts[parts.Length - 1] = lastTransformer; var newModel = new TransformerChain<ITransformer>(parts); // Create a PredictionFunction from our model var predictor = newModel.MakePredictionFunction<SpamInput, SpamPrediction>(mlContext); var input = new SpamInput { Message = "free medicine winner! congratulations" }; Console.WriteLine("The message '{0}' is {1}", input.Message, predictor.Predict(input).isSpam ? "spam" : "not spam");