public class Dl4jMlpClassifier extends RandomizableClassifier implements BatchPredictor, CapabilitiesHandler, IterativeClassifier
Modifier and Type | Field and Description |
---|---|
static int |
FILTER_NONE
filter: No normalization/standardization
|
static int |
FILTER_NORMALIZE
filter: Normalize training data
|
static int |
FILTER_STANDARDIZE
filter: Standardize training data
|
static Tag[] |
TAGS_FILTER
The filter to apply to the training data
|
BATCH_SIZE_DEFAULT, NUM_DECIMAL_PLACES_DEFAULT
Constructor and Description |
---|
Dl4jMlpClassifier() |
Modifier and Type | Method and Description |
---|---|
void |
buildClassifier(Instances data)
The method used to train the classifier.
|
double[] |
distributionForInstance(Instance inst)
The method to use when making a prediction for a test instance.
|
double[][] |
distributionsForInstances(Instances insts)
The method to use when making predictions for test instances.
|
void |
done()
Clean up after learning.
|
Capabilities |
getCapabilities()
Returns default capabilities of the classifier.
|
AbstractDataSetIterator |
getDataSetIterator() |
SelectedTag |
getFilterType() |
org.deeplearning4j.nn.conf.layers.Layer[] |
getLayers() |
java.io.File |
getLogFile()
Get the log file
|
NeuralNetConfiguration |
getNeuralNetConfiguration() |
int |
getNumEpochs() |
int |
getQueueSize() |
java.lang.String |
globalInfo() |
boolean |
implementsMoreEfficientBatchPrediction()
Performs efficient batch prediction
|
void |
initializeClassifier(Instances data)
The method used to initialize the classifier.
|
static void |
main(java.lang.String[] argv)
The main method for running this class.
|
boolean |
next()
Perform another epoch.
|
void |
setDataSetIterator(AbstractDataSetIterator iterator) |
void |
setFilterType(SelectedTag newType) |
void |
setLayers(org.deeplearning4j.nn.conf.layers.Layer[] layers) |
void |
setLogFile(java.io.File logFile)
Set the log file
|
void |
setNeuralNetConfiguration(NeuralNetConfiguration config) |
void |
setNumEpochs(int numEpochs) |
void |
setQueueSize(int QueueSize) |
java.lang.String |
toString()
Returns a string describing the model.
|
getOptions, getSeed, listOptions, seedTipText, setOptions, setSeed
batchSizeTipText, classifyInstance, debugTipText, doNotCheckCapabilitiesTipText, forName, getBatchSize, getDebug, getDoNotCheckCapabilities, getNumDecimalPlaces, getRevision, makeCopies, makeCopy, numDecimalPlacesTipText, postExecution, preExecution, run, runClassifier, setBatchSize, setDebug, setDoNotCheckCapabilities, setNumDecimalPlaces
equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
getBatchSize, setBatchSize
classifyInstance
public static final int FILTER_NORMALIZE
public static final int FILTER_STANDARDIZE
public static final int FILTER_NONE
public static final Tag[] TAGS_FILTER
public static void main(java.lang.String[] argv)
argv
- the command-line arguments
public java.lang.String globalInfo()
public Capabilities getCapabilities()
getCapabilities
in interface Classifier
getCapabilities
in interface CapabilitiesHandler
getCapabilities
in class AbstractClassifier
public java.io.File getLogFile()
@OptionMetadata(displayName="log file", description="The name of the log file to write loss information to (default = no log file).", commandLineParamName="logFile", commandLineParamSynopsis="-logFile <string>", displayOrder=1) public void setLogFile(java.io.File logFile)
logFile
- the log file
public org.deeplearning4j.nn.conf.layers.Layer[] getLayers()
@OptionMetadata(displayName="layer specification.", description="The specification of a layer. This option can be used multiple times.", commandLineParamName="layer", commandLineParamSynopsis="-layer <string>", displayOrder=2) public void setLayers(org.deeplearning4j.nn.conf.layers.Layer[] layers)
public int getNumEpochs()
@OptionMetadata(description="The number of epochs to perform.", displayName="number of epochs", commandLineParamName="numEpochs", commandLineParamSynopsis="-numEpochs <int>", displayOrder=4) public void setNumEpochs(int numEpochs)
@OptionMetadata(description="The dataset iterator to use.", displayName="dataset iterator", commandLineParamName="iterator", commandLineParamSynopsis="-iterator <string>", displayOrder=6) public AbstractDataSetIterator getDataSetIterator()
public void setDataSetIterator(AbstractDataSetIterator iterator)
@OptionMetadata(description="The neural network configuration to use.", displayName="network configuration", commandLineParamName="config", commandLineParamSynopsis="-config <string>", displayOrder=7) public NeuralNetConfiguration getNeuralNetConfiguration()
public void setNeuralNetConfiguration(NeuralNetConfiguration config)
@OptionMetadata(description="The type of normalization to perform.", displayName="attribute normalization", commandLineParamName="normalization", commandLineParamSynopsis="-normalization <int>", displayOrder=8) public SelectedTag getFilterType()
public void setFilterType(SelectedTag newType)
public int getQueueSize()
@OptionMetadata(description="The queue size for asynchronous data transfer (default: 0, synchronous transfer).", displayName="queue size for asynchronous data transfer", commandLineParamName="queueSize", commandLineParamSynopsis="-queueSize <int>", displayOrder=9) public void setQueueSize(int QueueSize)
public void buildClassifier(Instances data) throws java.lang.Exception
buildClassifier
in interface Classifier
data
- set of instances serving as training data
java.lang.Exception
- if something goes wrong in the training process
public void initializeClassifier(Instances data) throws java.lang.Exception
initializeClassifier
in interface IterativeClassifier
data
- set of instances serving as training data
java.lang.Exception
- if something goes wrong in the training process
public boolean next() throws java.lang.Exception
next
in interface IterativeClassifier
java.lang.Exception
public void done()
done
in interface IterativeClassifier
public boolean implementsMoreEfficientBatchPrediction()
implementsMoreEfficientBatchPrediction
in interface BatchPredictor
implementsMoreEfficientBatchPrediction
in class AbstractClassifier
public double[] distributionForInstance(Instance inst) throws java.lang.Exception
distributionForInstance
in interface Classifier
distributionForInstance
in class AbstractClassifier
inst
- the instance to get a prediction for
java.lang.Exception
- if something goes wrong at prediction time
public double[][] distributionsForInstances(Instances insts) throws java.lang.Exception
distributionsForInstances
in interface BatchPredictor
distributionsForInstances
in class AbstractClassifier
insts
- the instances to get predictions for
java.lang.Exception
- if something goes wrong at prediction time
public java.lang.String toString()
toString
in class java.lang.Object