Java源码示例:org.deeplearning4j.nn.api.Model
示例1
/**
 * Collects gradient statistics from the model after a gradient calculation.
 * <p>
 * Nothing is collected unless {@code calcFromGradients()} is true, the configured reporting
 * frequency is positive, and the current iteration count is either zero or a multiple of that
 * frequency. Each statistic (histograms, mean, standard deviation, mean magnitude) is only
 * computed when the update configuration enables it for {@code StatsType.Gradients}.
 *
 * @param model the model whose gradient is inspected
 */
@Override
public void onGradientCalculation(Model model) {
    final int iteration = getModelInfo(model).iterCount;
    if (!calcFromGradients() || updateConfig.reportingFrequency() <= 0) {
        return;
    }
    // Report only on iteration 0 and on every multiple of the reporting frequency.
    if (iteration != 0 && iteration % updateConfig.reportingFrequency() != 0) {
        return;
    }

    final Gradient gradient = model.gradient();

    if (updateConfig.collectHistograms(StatsType.Gradients)) {
        gradientHistograms = getHistograms(gradient.gradientForVariable(),
                updateConfig.numHistogramBins(StatsType.Gradients));
    }

    if (updateConfig.collectMean(StatsType.Gradients)) {
        meanGradients = calculateSummaryStats(gradient.gradientForVariable(), StatType.Mean);
    }

    if (updateConfig.collectStdev(StatsType.Gradients)) {
        stdevGradient = calculateSummaryStats(gradient.gradientForVariable(), StatType.Stdev);
    }

    if (updateConfig.collectMeanMagnitudes(StatsType.Gradients)) {
        meanMagGradients = calculateSummaryStats(gradient.gradientForVariable(), StatType.MeanMagnitude);
    }
}
示例2
/**
 * Advances the epoch counter and, on every {@code n}-th epoch, logs an epoch report.
 * <p>
 * The report always contains the epoch header. When intermediate evaluations are enabled it
 * also contains the training-set evaluation and, if a validation iterator is present, the
 * validation-set evaluation.
 *
 * @param model the model to evaluate
 */
@Override
public void onEpochEnd(Model model) {
    currentEpoch++;

    // Only every n-th epoch is an evaluation epoch.
    final boolean evaluationEpoch = (currentEpoch % n == 0);
    if (!evaluationEpoch) {
        return;
    }

    String report = "Epoch [" + currentEpoch + "/" + numEpochs + "]\n";
    if (isIntermediateEvaluationsEnabled) {
        report = report + "Train Set: \n" + evaluateDataSetIterator(model, trainIterator, true);
        if (validationIterator != null) {
            report = report + "Validation Set: \n" + evaluateDataSetIterator(model, validationIterator, false);
        }
    }
    log(report);
}
示例3
/**
 * This method does a forward pass and returns the output provided by the given adapter.
 * <p>
 * A model is acquired from the holder returned by {@code selector.getModelForThisThread()},
 * handed to the adapter together with the inputs and masks, and released again in a
 * {@code finally} block once a model has actually been acquired.
 *
 * @param adapter     adapter applied to the acquired model; must not be null
 * @param input       input arrays, passed through to the adapter
 * @param inputMasks  input mask arrays, passed through to the adapter
 * @param labelsMasks label mask arrays, passed through to the adapter
 * @param <T>         type of the result produced by the adapter
 * @return the value returned by {@code adapter.apply(...)}
 * @throws RuntimeException wrapping the {@link InterruptedException} if the calling thread is
 *                          interrupted; the thread's interrupt status is restored before throwing
 */
public <T> T output(@NonNull ModelAdapter<T> adapter, INDArray[] input, INDArray[] inputMasks, INDArray[] labelsMasks) {
    val holder = selector.getModelForThisThread();
    Model model = null;
    try {
        model = holder.acquireModel();
        return adapter.apply(model, input, inputMasks, labelsMasks);
    } catch (InterruptedException e) {
        // Restore the interrupt flag so callers further up the stack can still observe it;
        // catching InterruptedException clears it.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    } finally {
        // model stays null when acquireModel() did not complete, so nothing is released then.
        if (model != null) {
            holder.releaseModel(model);
        }
    }
}
示例4
/**
 * Sets the epoch count on the model and notifies its training listeners of an epoch
 * start or end.
 * <p>
 * Only {@code MultiLayerNetwork} and {@code ComputationGraph} models are handled; for any
 * other model type the method returns without doing anything.
 *
 * @param epochStart true to call {@code onEpochStart}, false to call {@code onEpochEnd}
 * @param model      the model whose listeners are notified
 * @param epochNum   the epoch count to store on the model
 */
protected void triggerEpochListeners(boolean epochStart, Model model, int epochNum){
    final Collection<TrainingListener> registered;
    if (model instanceof MultiLayerNetwork) {
        final MultiLayerNetwork network = (MultiLayerNetwork) model;
        registered = network.getListeners();
        network.setEpochCount(epochNum);
    } else if (model instanceof ComputationGraph) {
        final ComputationGraph graph = (ComputationGraph) model;
        registered = graph.getListeners();
        graph.getConfiguration().setEpochCount(epochNum);
    } else {
        return;
    }

    if (registered == null || registered.isEmpty()) {
        return;
    }

    for (TrainingListener listener : registered) {
        if (epochStart) {
            listener.onEpochStart(model);
        } else {
            listener.onEpochEnd(model);
        }
    }
}
示例5
/**
 * Runs the given evaluations on a model, reading batches from whichever iterator is supplied.
 * <p>
 * When {@code ds} is non-null it is wrapped in an {@code IteratorDataSetIterator}; otherwise
 * {@code mds} is wrapped in an {@code IteratorMultiDataSetIterator}. A model that is not a
 * {@code MultiLayerNetwork} is cast to {@code ComputationGraph}.
 *
 * @param m             the model to evaluate
 * @param e             the evaluations passed to {@code doEvaluation}
 * @param ds            DataSet source, or null to use {@code mds}
 * @param mds           MultiDataSet source, used only when {@code ds} is null
 * @param evalBatchSize batch size passed to the wrapping iterator
 */
private static void doEval(Model m, IEvaluation[] e, Iterator<DataSet> ds, Iterator<MultiDataSet> mds, int evalBatchSize){
    final boolean multiLayer = m instanceof MultiLayerNetwork;
    if (ds != null) {
        if (multiLayer) {
            ((MultiLayerNetwork) m).doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
        } else {
            ((ComputationGraph) m).doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
        }
        return;
    }

    if (multiLayer) {
        ((MultiLayerNetwork) m).doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
    } else {
        ((ComputationGraph) m).doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
    }
}
示例6
/**
 * Runs {@code doTest} for every combination of 1..5 inputs and 1..5 outputs, once with a
 * multi layer network model and once with a computation graph model, and checks that
 * all 50 combinations were executed.
 */
@Test
public void test() throws Exception {
    int executed = 0;
    for (int numInputs = 1; numInputs <= 5; numInputs++) {
        for (int numOutputs = 1; numOutputs <= 5; numOutputs++) {
            // Both models are built up front, before either of them is tested.
            final Model[] candidates = {
                    buildMultiLayerNetworkModel(numInputs, numOutputs),
                    buildComputationGraphModel(numInputs, numOutputs)
            };
            for (Model candidate : candidates) {
                doTest(candidate, numInputs, numOutputs);
                executed++;
            }
        }
    }
    assertEquals(50, executed);
}
示例7
/**
 * Writes a network together with a fitted min-max normalizer into a single file, then
 * loads both back through {@code ModelGuesser} and checks they equal the originals.
 */
@Test
public void testNormalizerInPlace() throws Exception {
    MultiLayerNetwork original = getNetwork();
    File target = testDir.newFile("testNormalizerInPlace.bin");

    NormalizerMinMaxScaler fitted = new NormalizerMinMaxScaler(0, 1);
    DataSet sample = new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2}));
    fitted.fit(sample);

    ModelSerializer.writeModel(original, target, true, fitted);

    String targetPath = target.getAbsolutePath();
    Model restoredModel = ModelGuesser.loadModelGuess(targetPath);
    Normalizer<?> restoredNormalizer = ModelGuesser.loadNormalizer(targetPath);

    assertEquals(restoredModel, original);
    assertEquals(fitted, restoredNormalizer);
}
示例8
/**
 * Copies the parameters of every {@code FrozenLayer} in the model.
 * <p>
 * Keys are the parameter name prefixed with the layer index and {@code "_"} for a
 * {@code MultiLayerNetwork}, or with the layer name and {@code "_"} otherwise (the model is
 * then cast to {@code ComputationGraph}). Values are duplicates ({@code dup()}) of the
 * parameter arrays.
 *
 * @param m the model to inspect
 * @return an insertion-ordered map from prefixed parameter name to a copy of the parameter
 */
private static Map<String,INDArray> getFrozenLayerParamCopies(Model m){
    final boolean multiLayer = m instanceof MultiLayerNetwork;

    final org.deeplearning4j.nn.api.Layer[] allLayers;
    if (multiLayer) {
        allLayers = ((MultiLayerNetwork) m).getLayers();
    } else {
        allLayers = ((ComputationGraph) m).getLayers();
    }

    final Map<String,INDArray> copies = new LinkedHashMap<>();
    for (org.deeplearning4j.nn.api.Layer layer : allLayers) {
        if (!(layer instanceof FrozenLayer)) {
            continue;
        }

        final String prefix;
        if (multiLayer) {
            prefix = layer.getIndex() + "_";
        } else {
            prefix = layer.conf().getLayer().getLayerName() + "_";
        }

        for (Map.Entry<String,INDArray> param : layer.paramTable().entrySet()) {
            copies.put(prefix + param.getKey(), param.getValue().dup());
        }
    }
    return copies;
}
示例9
/**
 * Loads a dl4j zip file (either computation graph or multi layer network).
 * <p>
 * The archive is treated as a computation graph when it contains both the coefficients and
 * configuration entries and the configuration JSON has a {@code "vertexInputs"} key; in
 * every other case it is restored as a multi layer network.
 *
 * @param path the path to the file to load
 * @return the loaded dl4j model, or {@code null} if the file is not a zip file
 * @throws Exception if loading a dl4j model fails
 */
public static Model loadDl4jGuess(String path) throws Exception {
    final File modelFile = new File(path);
    if (!isZipFile(modelFile)) {
        return null;
    }

    log.debug("Loading file " + path);
    boolean isComputationGraph = false;
    try (ZipFile archive = new ZipFile(path)) {
        List<String> entryNames = archive.stream()
                .map(ZipEntry::getName)
                .collect(Collectors.toList());
        log.debug("Entries " + entryNames);

        if (entryNames.contains(ModelSerializer.COEFFICIENTS_BIN)
                && entryNames.contains(ModelSerializer.CONFIGURATION_JSON)) {
            ZipEntry configEntry = archive.getEntry(ModelSerializer.CONFIGURATION_JSON);
            log.debug("Loaded configuration");
            try (InputStream in = archive.getInputStream(configEntry)) {
                JSONObject config = new JSONObject(IOUtils.toString(in, StandardCharsets.UTF_8));
                // Only computation graph configurations carry a "vertexInputs" key.
                isComputationGraph = config.has("vertexInputs");
                if (isComputationGraph) {
                    log.debug("Loading computation graph.");
                } else {
                    log.debug("Loading multi layer network.");
                }
            }
        }
    }

    if (isComputationGraph) {
        return ModelSerializer.restoreComputationGraph(modelFile);
    }
    return ModelSerializer.restoreMultiLayerNetwork(modelFile);
}
示例10
/**
 * Saves the model (including its updater state) as a zip file under the {@code model}
 * directory of the current working directory ({@code user.dir}).
 * <p>
 * The file name is {@code <name>_idx_<index>_<accuracy>.zip}. The {@code model} directory is
 * created if it does not exist yet, so saving also works in a fresh working directory.
 * Progress and failure messages are printed to {@code System.err}.
 *
 * @param name     prefix of the file name
 * @param model    the model to write
 * @param index    index embedded in the file name
 * @param accuracy accuracy value embedded in the file name
 * @return the file name (without directory) the model was saved under
 * @throws Exception if the directory cannot be created or writing the model fails;
 *                   an {@link IOException} is reported on {@code System.err} and rethrown
 */
public static String saveModel(String name, Model model, int index, int accuracy) throws Exception {
    System.err.println("Saving model, don't shutdown...");
    try {
        String fn = name + "_idx_" + index + "_" + accuracy + ".zip";
        File modelDir = new File(System.getProperty("user.dir"), "model");
        // mkdirs() returns false if the directory already exists (e.g. created concurrently),
        // so re-check isDirectory() before treating it as a failure.
        if (!modelDir.isDirectory() && !modelDir.mkdirs() && !modelDir.isDirectory()) {
            throw new IOException("Could not create model directory: " + modelDir);
        }
        File locationToSave = new File(modelDir, fn);
        boolean saveUpdater = true; //Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to train your network more in the future
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);
        System.err.println("Model saved");
        return fn;
    } catch (IOException e) {
        System.err.println("Save model failed");
        e.printStackTrace();
        throw e;
    }
}
示例11
/**
 * Writes a "backward pass" message to the file, provided printing on backward pass is
 * enabled and a target file is set.
 *
 * @param model the model (not used here)
 */
@Override
public void onBackwardPass(Model model) {
    if (printOnBackwardPass && printFileTarget != null) {
        writeFileWithMessage("backward pass");
    }
}
示例12
/**
 * Builds a single-output-layer network (3 inputs, 2 outputs, identity activation, MSE loss)
 * and sets its parameters to fixed values.
 *
 * @return the initialised network with its parameters set
 */
private static Model buildModel() throws Exception {
    final int numInputs = 3;
    final int numOutputs = 2;

    final OutputLayer outputLayer = new OutputLayer.Builder()
            .nIn(numInputs)
            .nOut(numOutputs)
            .activation(Activation.IDENTITY)
            .lossFunction(LossFunctions.LossFunction.MSE)
            .build();
    final MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .list(outputLayer)
            .build();

    final MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.init();

    // positive weight for first output, negative weight for second output, no biases
    final float[] parameterValues = new float[]{ +1, +1, +1, -1, -1, -1, 0, 0 };
    assertEquals((numInputs+1)*numOutputs, parameterValues.length);

    network.setParams(Nd4j.create(parameterValues));
    return network;
}
示例13
/**
 * Calls {@code sleep} with the current {@code lastIteration} value and the iteration timer,
 * then records the current time in {@code lastIteration}: the existing {@code AtomicLong} is
 * updated if there is one, otherwise a new one is stored.
 *
 * @param model     the model (not used here)
 * @param iteration the iteration number (not used here)
 * @param epoch     the epoch number (not used here)
 */
@Override
public void iterationDone(Model model, int iteration, int epoch) {
    sleep(lastIteration.get(), timerIteration);

    if (lastIteration.get() != null) {
        lastIteration.get().set(System.currentTimeMillis());
    } else {
        lastIteration.set(new AtomicLong(System.currentTimeMillis()));
    }
}
示例14
/**
 * Forwards the iteration event to every registered status listener, if any are set.
 *
 * @param model     the model passed on to the listeners
 * @param iteration the iteration number passed on to the listeners
 * @param epoch     the epoch number (not forwarded)
 */
@Override
public void iterationDone(Model model, int iteration, int epoch) {
    if (statusListeners != null) {
        for (StatusListener listener : statusListeners) {
            listener.onCandidateIteration(candidateInfo, model, iteration);
        }
    }
}
示例15
/**
 * Writes a "forward pass" message to the file, provided printing on forward pass is enabled
 * and a target file is set.
 * <p>
 * The guard uses {@code printOnForwardPass}, matching the
 * {@code onForwardPass(Model, Map)} overload; it previously tested the backward-pass flag.
 *
 * @param model       the model (not used here)
 * @param activations the activations (not used here)
 */
@Override
public void onForwardPass(Model model, List<INDArray> activations) {
    if (!printOnForwardPass || printFileTarget == null) {
        return;
    }
    writeFileWithMessage("forward pass");
}
示例16
/**
 * Asserts that every layer of the model reports the expected epoch and iteration counts.
 * <p>
 * A model that is not a {@code MultiLayerNetwork} is cast to {@code ComputationGraph}.
 *
 * @param m        the model whose layers are checked
 * @param expEpoch expected epoch count on each layer
 * @param expIter  expected iteration count on each layer
 */
private static void validateLayerIterCounts(Model m, int expEpoch, int expIter){
    // The counts stored on the individual layers must be in sync with the expected values.
    final org.deeplearning4j.nn.api.Layer[] allLayers = (m instanceof MultiLayerNetwork)
            ? ((MultiLayerNetwork) m).getLayers()
            : ((ComputationGraph) m).getLayers();

    for (org.deeplearning4j.nn.api.Layer layer : allLayers) {
        assertEquals("Epoch count", expEpoch, layer.getEpochCount());
        assertEquals("Iteration count", expIter, layer.getIterationCount());
    }
}
示例17
/**
 * Creates the updater matching the concrete type of the given model.
 *
 * @param layer the model; a {@code MultiLayerNetwork} yields a {@code MultiLayerUpdater}, a
 *              {@code ComputationGraph} yields a {@code ComputationGraphUpdater}, and anything
 *              else is cast to {@code Layer} and wrapped in a {@code LayerUpdater}
 * @return a new updater for the model
 */
public static org.deeplearning4j.nn.api.Updater getUpdater(Model layer) {
    if (layer instanceof MultiLayerNetwork) {
        return new MultiLayerUpdater((MultiLayerNetwork) layer);
    }
    if (layer instanceof ComputationGraph) {
        return new ComputationGraphUpdater((ComputationGraph) layer);
    }
    return new LayerUpdater((Layer) layer);
}
示例18
/**
 * Creates the optimizer and its back-tracking line search.
 *
 * @param conf              the network configuration; also supplies the maximum number of
 *                          line search iterations
 * @param stepFunction      the step function to use, or null to use the default returned by
 *                          {@code getDefaultStepFunctionForOptimizer} for this optimizer class
 * @param trainingListeners the training listeners, or null to start with an empty list
 * @param model             the model to optimize
 */
public BaseOptimizer(NeuralNetConfiguration conf, StepFunction stepFunction,
                Collection<TrainingListener> trainingListeners, Model model) {
    this.conf = conf;

    if (stepFunction == null) {
        this.stepFunction = getDefaultStepFunctionForOptimizer(this.getClass());
    } else {
        this.stepFunction = stepFunction;
    }

    if (trainingListeners == null) {
        this.trainingListeners = new ArrayList<TrainingListener>();
    } else {
        this.trainingListeners = trainingListeners;
    }

    this.model = model;

    // The line search uses the resolved step function, not the raw constructor argument.
    lineMaximizer = new BackTrackLineSearch(model, this.stepFunction, this);
    lineMaximizer.setStepMax(stepMax);
    lineMaximizer.setMaxIterations(conf.getMaxNumLineSearchIterations());
}
示例19
/**
 * Logs the system information as pretty-printed JSON when printing on forward pass is enabled.
 *
 * @param model       the model (not used here)
 * @param activations the activations (not used here)
 */
@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
    if (printOnForwardPass) {
        final SystemInfo info = new SystemInfo();
        log.info(SYSTEM_INFO);
        log.info(info.toPrettyJSON());
    }
}
示例20
/**
 * Reads the epoch count from the configuration of the given model.
 *
 * @param model a {@code MultiLayerNetwork}, a {@code ComputationGraph}, or any other model
 *              (in which case {@code model.conf()} is consulted)
 * @return the epoch count stored in the model's configuration
 */
public static int getEpochCount(Model model){
    if (model instanceof MultiLayerNetwork) {
        final MultiLayerNetwork network = (MultiLayerNetwork) model;
        return network.getLayerWiseConfigurations().getEpochCount();
    }
    if (model instanceof ComputationGraph) {
        final ComputationGraph graph = (ComputationGraph) model;
        return graph.getConfiguration().getEpochCount();
    }
    return model.conf().getEpochCount();
}
示例21
/**
 * Calls {@code sleep} with the current {@code lastBP} value and the backward-pass timer, then
 * records the current time in {@code lastBP}: the existing {@code AtomicLong} is updated if
 * there is one, otherwise a new one is stored.
 *
 * @param model the model (not used here)
 */
@Override
public void onBackwardPass(Model model) {
    sleep(lastBP.get(), timerBP);

    if (lastBP.get() != null) {
        lastBP.get().set(System.currentTimeMillis());
    } else {
        lastBP.set(new AtomicLong(System.currentTimeMillis()));
    }
}
示例22
/**
 * Returns a display name for the exact runtime class of the model.
 * <p>
 * The comparison is on the class itself, so subclasses of {@code MultiLayerNetwork} or
 * {@code ComputationGraph} are reported as {@code "Model"}.
 *
 * @param model the model to classify
 * @return {@code "MultiLayerNetwork"}, {@code "ComputationGraph"} or {@code "Model"}
 */
protected static String getModelType(Model model){
    final Class<?> type = model.getClass();
    if (type == MultiLayerNetwork.class) {
        return "MultiLayerNetwork";
    }
    if (type == ComputationGraph.class) {
        return "ComputationGraph";
    }
    return "Model";
}
示例23
/**
 * Restores a model from the stream using the
 * {@link ModelGuesser#loadModelGuess(InputStream, File)} method, passing the Solr instance
 * directory as the second argument.
 *
 * @param inputStream the stream to read the serialized model from
 * @return the restored model
 * @throws IOException wrapping any exception thrown by {@code ModelGuesser}
 */
protected Model restoreModel(InputStream inputStream) throws IOException {
    // Resolved outside the try block, so failures here are not wrapped in an IOException.
    final File instanceDirectory = solrResourceLoader.getInstancePath().toFile();
    try {
        return ModelGuesser.loadModelGuess(inputStream, instanceDirectory);
    } catch (Exception cause) {
        throw new IOException("Failed to restore model from given file (" + serializedModelFileName + ")", cause);
    }
}
示例24
/**
 * Writes a "forward pass" message to the file, provided printing on forward pass is enabled
 * and a target file is set.
 *
 * @param model       the model (not used here)
 * @param activations the activations (not used here)
 */
@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
    if (printOnForwardPass && printFileTarget != null) {
        writeFileWithMessage("forward pass");
    }
}
示例25
/**
 * Reads the epoch count from the configuration of the given model.
 *
 * @param model a {@code MultiLayerNetwork}, a {@code ComputationGraph}, or any other model
 *              (in which case {@code model.conf()} is consulted)
 * @return the epoch count stored in the model's configuration
 */
protected static int getEpoch(Model model) {
    final int epochCount;
    if (model instanceof MultiLayerNetwork) {
        epochCount = ((MultiLayerNetwork) model).getLayerWiseConfigurations().getEpochCount();
    } else if (model instanceof ComputationGraph) {
        epochCount = ((ComputationGraph) model).getConfiguration().getEpochCount();
    } else {
        epochCount = model.conf().getEpochCount();
    }
    return epochCount;
}
示例26
/**
 * Logs the system information as pretty-printed JSON when printing on gradient calculation
 * is enabled.
 *
 * @param model the model (not used here)
 */
@Override
public void onGradientCalculation(Model model) {
    if (printOnGradientCalculation) {
        final SystemInfo info = new SystemInfo();
        log.info(SYSTEM_INFO);
        log.info(info.toPrettyJSON());
    }
}
示例27
/**
 * Saves a checkpoint when an every-N-epochs interval is configured and the number of
 * completed epochs (the model's epoch count plus one) is a positive multiple of it.
 *
 * @param model the model to checkpoint
 */
@Override
public void onEpochEnd(Model model) {
    final int completedEpochs = getEpoch(model) + 1;
    if (saveEveryNEpochs == null || completedEpochs <= 0) {
        return;
    }
    if (completedEpochs % saveEveryNEpochs == 0) {
        // Epoch-based save condition met
        saveCheckpoint(model);
    }
    // The general saving conditions are not checked here - they are checked in iterationDone
}
示例28
/**
 * Asks every contained trigger whether a failure should be triggered.
 * <p>
 * All triggers are always consulted, even after one of them has already returned true.
 *
 * @param callType  the call type passed on to each trigger
 * @param iteration the iteration passed on to each trigger
 * @param epoch     the epoch passed on to each trigger
 * @param model     the model passed on to each trigger
 * @return true if at least one trigger returned true
 */
@Override
public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
    boolean anyTriggered = false;
    for (FailureTrigger trigger : triggers) {
        // Evaluate the trigger unconditionally; do not short-circuit.
        final boolean triggered = trigger.triggerFailure(callType, iteration, epoch, model);
        if (triggered) {
            anyTriggered = true;
        }
    }
    return anyTriggered;
}
示例29
/**
 * Collects the {@code sourceModel} of every holder into an array, in iteration order.
 *
 * @return an array with one model per holder
 */
@Override
protected synchronized Model[] getCurrentModelsFromWorkers() {
    final Model[] collected = new Model[holders.size()];
    int next = 0;
    for (val holder : holders) {
        collected[next] = holder.sourceModel;
        next++;
    }
    return collected;
}
示例30
/**
 * Reports the model class this implementation is associated with.
 *
 * @return always {@code ComputationGraph.class}
 */
@Override
public Class<? extends Model> modelType() {
return ComputationGraph.class;
}