public class LogRegressionMultiClassTrainer<P extends Serializable> extends SingleLabelDatasetTrainer<LogRegressionMultiClassModel>
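This page does not say how the multi-class model combines its classes. A common construction for multi-class logistic regression is one-vs-rest: train one binary logistic model per class, then predict the class whose model scores highest. The sketch below assumes that scheme. `BinaryLogitSketch`, `OneVsRestSketch`, `add` and `predict` are hypothetical names used only for illustration; they are not part of Ignite's API.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical binary logistic model: sigmoid(w . x + b). Not an Ignite class. */
final class BinaryLogitSketch {
    private final double[] weights;
    private final double intercept;

    BinaryLogitSketch(double[] weights, double intercept) {
        this.weights = weights.clone();
        this.intercept = intercept;
    }

    /** Returns a score in (0, 1) for "x belongs to this model's class". */
    double score(double[] x) {
        double z = intercept;
        for (int i = 0; i < weights.length; i++)
            z += weights[i] * x[i];
        return 1.0 / (1.0 + Math.exp(-z));
    }
}

/** Hypothetical one-vs-rest combiner: predicts the label whose binary model scores highest. */
final class OneVsRestSketch {
    // Insertion order is kept so that ties resolve to the first label added.
    private final Map<Double, BinaryLogitSketch> modelsByLabel = new LinkedHashMap<>();

    OneVsRestSketch add(double label, BinaryLogitSketch mdl) {
        modelsByLabel.put(label, mdl);
        return this;
    }

    double predict(double[] x) {
        if (modelsByLabel.isEmpty())
            throw new IllegalStateException("no per-class models");
        double bestLabel = Double.NaN;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Map.Entry<Double, BinaryLogitSketch> e : modelsByLabel.entrySet()) {
            double s = e.getValue().score(x);
            if (s > bestScore) { // strict '>' keeps the earliest label on ties
                bestScore = s;
                bestLabel = e.getKey();
            }
        }
        return bestLabel;
    }
}
```

Labels are kept as `double` because the trainer's label extractor in the method summary returns `Double`.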
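Nested classes inherited from class DatasetTrainer: DatasetTrainer.EmptyDatasetException
Fields inherited from class DatasetTrainer: environment

| Constructor and Description |
|---|
| LogRegressionMultiClassTrainer() |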

| Modifier and Type | Method | Description |
|---|---|---|
| protected boolean | checkState(LogRegressionMultiClassModel mdl) | |
| <K,V> LogRegressionMultiClassModel | fit(DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor) | Trains a model on the specified data. |
| int | getAmountOfIterations() | Gets the number of outer iterations of the SGD algorithm. |
| int | getAmountOfLocIterations() | Gets the number of local iterations. |
| double | getBatchSize() | Gets the batch size. |
| UpdatesStrategy | getUpdatesStgy() | Gets the updates strategy. |
| long | seed() | Gets the seed for the random number generator. |
| <K,V> LogRegressionMultiClassModel | updateModel(LogRegressionMultiClassModel newMdl, DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor) | Takes the model passed as an argument, updates it according to the new data, and returns the new model. |
| LogRegressionMultiClassTrainer | withAmountOfIterations(int amountOfIterations) | Sets the number of outer iterations. |
| LogRegressionMultiClassTrainer | withAmountOfLocIterations(int amountOfLocIterations) | Sets the number of local iterations of the SGD algorithm. |
| LogRegressionMultiClassTrainer | withBatchSize(int batchSize) | Sets the batch size. |
| LogRegressionMultiClassTrainer | withSeed(long seed) | Sets the random seed. |
| LogRegressionMultiClassTrainer | withUpdatesStgy(UpdatesStrategy updatesStgy) | Sets the updates strategy. |
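The with*/get* pairs above follow a fluent-setter pattern, where each setter returns a trainer so calls can be chained. The sketch below mirrors that pattern in a hypothetical `SgdSettingsSketch` class; it is not Ignite code. Its initial values are placeholders rather than Ignite's defaults, the positivity checks are the sketch's own choice, and its `getBatchSize()` returns `int`, whereas the table lists `double` for the real method.

```java
/**
 * Hypothetical stand-in that mirrors the trainer's with*/get* pairs.
 * Initial values are placeholders, not Ignite's defaults, and the
 * argument checks are this sketch's own choice.
 */
final class SgdSettingsSketch {
    private int amountOfIterations = 100;
    private int amountOfLocIterations = 10;
    private int batchSize = 100;
    private long seed = 1234L;

    SgdSettingsSketch withAmountOfIterations(int amountOfIterations) {
        this.amountOfIterations = requirePositive(amountOfIterations, "amountOfIterations");
        return this; // returning this enables call chaining
    }

    SgdSettingsSketch withAmountOfLocIterations(int amountOfLocIterations) {
        this.amountOfLocIterations = requirePositive(amountOfLocIterations, "amountOfLocIterations");
        return this;
    }

    SgdSettingsSketch withBatchSize(int batchSize) {
        this.batchSize = requirePositive(batchSize, "batchSize");
        return this;
    }

    SgdSettingsSketch withSeed(long seed) {
        this.seed = seed;
        return this;
    }

    int getAmountOfIterations() { return amountOfIterations; }
    int getAmountOfLocIterations() { return amountOfLocIterations; }
    int getBatchSize() { return batchSize; }
    long seed() { return seed; }

    private static int requirePositive(int val, String name) {
        if (val <= 0)
            throw new IllegalArgumentException(name + " must be positive, got " + val);
        return val;
    }
}
```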

Methods inherited from class DatasetTrainer: fit, fit, fit, fit, getLastTrainedModelOrThrowEmptyDatasetException, setEnvironment, update, update, update, update, update

public <K,V> LogRegressionMultiClassModel fit(DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor)
Specified by: fit in class DatasetTrainer<LogRegressionMultiClassModel,Double>
Type Parameters:
K - Type of a key in upstream data.
V - Type of a value in upstream data.
Parameters:
datasetBuilder - Dataset builder.
featureExtractor - Feature extractor.
lbExtractor - Label extractor.

public <K,V> LogRegressionMultiClassModel updateModel(LogRegressionMultiClassModel newMdl, DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor)
Specified by: updateModel in class DatasetTrainer<LogRegressionMultiClassModel,Double>
Type Parameters:
K - Type of a key in upstream data.
V - Type of a value in upstream data.
Parameters:
newMdl - Learned model.
datasetBuilder - Dataset builder.
featureExtractor - Feature extractor.
lbExtractor - Label extractor.

protected boolean checkState(LogRegressionMultiClassModel mdl)
Specified by: checkState in class DatasetTrainer<LogRegressionMultiClassModel,Double>
Parameters:
mdl - Model.

public LogRegressionMultiClassTrainer withBatchSize(int batchSize)
Parameters:
batchSize - The size of the learning batch.

public double getBatchSize()
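The batch size sets how many rows contribute to each gradient estimate. This page does not describe how batches are formed, so the sketch below assumes the simplest option: a row order is cut into consecutive slices of `batchSize` rows, and each slice yields the mean gradient of the binary log loss. `MiniBatchSketch`, `toBatches` and `batchGradient` are hypothetical names, not Ignite methods.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helpers illustrating batchSize for binary logistic regression (labels 0 or 1). */
final class MiniBatchSketch {
    private MiniBatchSketch() {}

    /** Splits a row order into consecutive batches of batchSize rows; the last batch may be shorter. */
    static List<int[]> toBatches(int[] order, int batchSize) {
        if (batchSize <= 0)
            throw new IllegalArgumentException("batchSize must be positive");
        List<int[]> batches = new ArrayList<>();
        for (int from = 0; from < order.length; from += batchSize)
            batches.add(Arrays.copyOfRange(order, from, Math.min(from + batchSize, order.length)));
        return batches;
    }

    /**
     * Mean gradient of the log loss over the rows in one batch.
     * The result has w.length + 1 entries; the last one is the intercept gradient.
     */
    static double[] batchGradient(double[][] x, double[] y, double[] w, double b, int[] batch) {
        double[] grad = new double[w.length + 1];
        for (int r : batch) {
            double z = b;
            for (int j = 0; j < w.length; j++)
                z += w[j] * x[r][j];
            double err = 1.0 / (1.0 + Math.exp(-z)) - y[r]; // d(log loss)/dz
            for (int j = 0; j < w.length; j++)
                grad[j] += err * x[r][j];
            grad[w.length] += err;
        }
        for (int j = 0; j < grad.length; j++)
            grad[j] /= batch.length;
        return grad;
    }
}
```

A larger batch gives a less noisy gradient per step at a higher cost per step.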
public int getAmountOfIterations()

public LogRegressionMultiClassTrainer withAmountOfIterations(int amountOfIterations)
Parameters:
amountOfIterations - The number of outer iterations.

public int getAmountOfLocIterations()
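The page distinguishes outer iterations from local iterations but does not say how they interact. A common pattern that matches these names is local SGD with parameter averaging. In each outer iteration, every data partition starts from the shared weights, runs `amountOfLocIterations` gradient steps on its own rows, and then the partitions' weights are averaged. The sketch below assumes that pattern for a binary logistic model and, for brevity, uses full-partition gradients instead of mini-batches. `LocalSgdSketch` and its methods are hypothetical.

```java
/**
 * Hypothetical local-SGD loop for binary logistic regression (labels 0 or 1).
 * Weight layout: w[0..d-1] are feature weights, w[d] is the intercept.
 */
final class LocalSgdSketch {
    private LocalSgdSketch() {}

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static double[] train(double[][][] partsX, double[][] partsY,
                          int amountOfIterations, int amountOfLocIterations, double learningRate) {
        int d = partsX[0][0].length;
        double[] global = new double[d + 1];
        for (int outer = 0; outer < amountOfIterations; outer++) {
            double[] sum = new double[d + 1];
            for (int p = 0; p < partsX.length; p++) {
                double[] local = global.clone(); // every partition starts from the shared weights
                for (int it = 0; it < amountOfLocIterations; it++)
                    localStep(local, partsX[p], partsY[p], learningRate);
                for (int j = 0; j <= d; j++)
                    sum[j] += local[j];
            }
            for (int j = 0; j <= d; j++)
                global[j] = sum[j] / partsX.length; // average the partitions' results
        }
        return global;
    }

    /** One full-partition gradient-descent step on the mean log loss. */
    private static void localStep(double[] w, double[][] x, double[] y, double lr) {
        int d = w.length - 1;
        double[] grad = new double[d + 1];
        for (int r = 0; r < x.length; r++) {
            double err = sigmoid(dot(w, x[r])) - y[r];
            for (int j = 0; j < d; j++)
                grad[j] += err * x[r][j];
            grad[d] += err;
        }
        for (int j = 0; j <= d; j++)
            w[j] -= lr * grad[j] / x.length;
    }

    /** Linear score w . x + intercept. */
    static double dot(double[] w, double[] x) {
        int d = w.length - 1;
        double z = w[d];
        for (int j = 0; j < d; j++)
            z += w[j] * x[j];
        return z;
    }
}
```

More local iterations reduce how often partitions must be synchronized, but partitions with very different data can drift apart between averages.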

public LogRegressionMultiClassTrainer withAmountOfLocIterations(int amountOfLocIterations)
Parameters:
amountOfLocIterations - The number of local iterations.

public LogRegressionMultiClassTrainer withSeed(long seed)
Parameters:
seed - Seed for the random number generator.

public long seed()
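A fixed seed makes the trainer's random choices repeatable. This page does not list which choices those are. As an assumed example, the sketch below shows one typical use: a seeded shuffle of row indices. It uses the standard `Collections.shuffle(List, Random)`. `SeededOrderSketch` and `rowOrder` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: a seeded permutation of row indices. */
final class SeededOrderSketch {
    private SeededOrderSketch() {}

    /** The same row count and seed always give the same order, so a training run can be repeated. */
    static List<Integer> rowOrder(int rows, long seed) {
        List<Integer> order = new ArrayList<>(rows);
        for (int i = 0; i < rows; i++)
            order.add(i);
        Collections.shuffle(order, new Random(seed));
        return order;
    }
}
```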

public LogRegressionMultiClassTrainer withUpdatesStgy(UpdatesStrategy updatesStgy)
Parameters:
updatesStgy - Update strategy.

public UpdatesStrategy getUpdatesStgy()
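This page names the UpdatesStrategy type but does not describe it. The sketch below assumes that an update strategy bundles three things: a rule that turns a gradient into a parameter update, a way to combine updates within one partition, and a way to combine the per-partition results. `UpdateStrategySketch` is a hypothetical class, not Ignite's UpdatesStrategy. Its `simpleGd` factory uses plain gradient descent (update = -learningRate * gradient), sums updates locally, and averages them globally.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.function.Function;

/**
 * Hypothetical update strategy: a rule turning a gradient into a parameter
 * update, plus two ways of combining updates. Not Ignite's UpdatesStrategy.
 */
final class UpdateStrategySketch {
    final Function<double[], double[]> gradToUpdate;        // gradient -> parameter update
    final BinaryOperator<double[]> combineLocal;            // merges updates inside one partition
    final Function<List<double[]>, double[]> combineGlobal; // merges per-partition updates

    UpdateStrategySketch(Function<double[], double[]> gradToUpdate,
                         BinaryOperator<double[]> combineLocal,
                         Function<List<double[]>, double[]> combineGlobal) {
        this.gradToUpdate = gradToUpdate;
        this.combineLocal = combineLocal;
        this.combineGlobal = combineGlobal;
    }

    /** Plain gradient descent: update = -lr * grad; sum locally, average globally. */
    static UpdateStrategySketch simpleGd(double lr) {
        return new UpdateStrategySketch(g -> scale(g, -lr), UpdateStrategySketch::add, UpdateStrategySketch::average);
    }

    static double[] scale(double[] v, double k) {
        double[] out = new double[v.length];
        for (int i = 0; i < v.length; i++)
            out[i] = v[i] * k;
        return out;
    }

    static double[] add(double[] a, double[] b) {
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; i++)
            out[i] = a[i] + b[i];
        return out;
    }

    static double[] average(List<double[]> updates) {
        double[] out = new double[updates.get(0).length];
        for (double[] u : updates)
            for (int i = 0; i < out.length; i++)
                out[i] += u[i] / updates.size();
        return out;
    }
}
```

Keeping the three parts separate allows a different update rule (for example, one with momentum) to be swapped in without changing how partitions are combined.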
Ignite Database and Caching Platform : ver. 2.7.2 Release Date : February 6 2019