How to predict multiple labels with ML.NET using a regression task?

 ̄綄美尐妖づ 提交于 2019-12-02 01:38:36

It is probably not working because the model produced by Fit() only outputs the "Label" and "Score" columns.

Look here: here

The "Score" produced for "TripTime" is overwritten by the "Score" produced for "FareAmount".

I guess you have to build two models.

Edit: you can try this instead. After each trainer, copy "Score" into its own output column so the next trainer does not overwrite it.

/// <summary>
/// Holds the result of a prediction: one value per predicted target.
/// Each property is bound by name to a column of the model output via
/// <c>[ColumnName]</c>.
/// </summary>
public class TripFarePrediction
{
    /// <summary>Value read from the <c>fareAmount</c> output column.</summary>
    [ColumnName("fareAmount")] public float FareAmount { get; set; }

    /// <summary>Value read from the <c>tripTime</c> output column.</summary>
    [ColumnName("tripTime")] public float TripTime { get; set; }
}


/// <summary>
/// Trains two FastTree regressors (one for <c>TripTime</c>, one for
/// <c>FareAmount</c>) inside a single pipeline, saves the fitted model and
/// returns it.
/// </summary>
/// <param name="mlContext">ML.NET context used to build the transforms and trainers.</param>
/// <param name="trainDataPath">Path of the training data handed to the text loader.</param>
/// <returns>
/// The fitted transformer. Its output carries the two predictions in the
/// <c>tripTime</c> and <c>fareAmount</c> columns.
/// </returns>
private static ITransformer Train(MLContext mlContext, string trainDataPath)
{
    IDataView dataView = _textLoader.Read(trainDataPath);

    // Featurize exactly once. Both regressors consume the same "Features"
    // column; repeating the one-hot encoders for the second trainer would
    // re-encode columns that have already been replaced by their encoded form.
    var featurization = mlContext.Transforms.Categorical.OneHotEncoding("VendorId")
        .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
        .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
        .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripDistance", "PaymentType"));

    // Each trainer reads "Label" and writes "Score". Point "Label" at the
    // current target before training, then copy "Score" to a dedicated column
    // right away so the next trainer cannot overwrite it.
    // CopyColumns arguments are positional: output column first, input second.
    var pipeline = featurization
        .Append(mlContext.Transforms.CopyColumns("Label", "TripTime"))
        .Append(mlContext.Regression.Trainers.FastTree())
        .Append(mlContext.Transforms.CopyColumns("tripTime", "Score"))
        .Append(mlContext.Transforms.CopyColumns("Label", "FareAmount"))
        .Append(mlContext.Regression.Trainers.FastTree())
        .Append(mlContext.Transforms.CopyColumns("fareAmount", "Score"));

    ITransformer model = pipeline.Fit(dataView);
    SaveModelAsFile(mlContext, model);
    return model;
}
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!