Example of training a model (and saving and restoring checkpoints) using the TensorFlow Java API.
Train for a few steps:

    mvn -q compile exec:java -Dexec.args="model/graph.pb checkpoint"

The first argument is the graph file and the second is the directory where checkpoints are written.

Resume training from the previous checkpoint and train some more:

    mvn -q exec:java -Dexec.args="model/graph.pb checkpoint"

Delete the checkpoint:

    rm -rf checkpoint
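To make the train / resume / delete cycle concrete, here is a hypothetical plain-Java analogue. It does not use TensorFlow or its checkpoint format. The class name CheckpointSketch, the file name params.txt, the starting values W = 5.0 and b = 3.0, and the placeholder training step are all assumptions made for illustration. Each call to run plays the role of one mvn invocation: restore W and b if the checkpoint directory holds saved values, otherwise start fresh, train, then save. It needs Java 11 or later for Files.readString and Files.writeString.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.function.UnaryOperator;

/**
 * Hypothetical plain-Java analogue of the train / resume / delete workflow.
 * This is NOT the TensorFlow checkpoint format: it just stores W and b as text.
 */
public class CheckpointSketch {

    // Hypothetical file name inside the checkpoint directory.
    static final String FILE_NAME = "params.txt";

    /** Restores {W, b} if a checkpoint file exists, otherwise returns a copy of the initial values. */
    static double[] restoreOrInit(Path dir, double[] initial) throws IOException {
        Path file = dir.resolve(FILE_NAME);
        if (!Files.exists(file)) {
            return initial.clone();
        }
        String[] parts = Files.readString(file).trim().split("\\s+");
        return new double[] {Double.parseDouble(parts[0]), Double.parseDouble(parts[1])};
    }

    /** Writes {W, b} to the checkpoint directory, creating the directory if needed. */
    static void save(Path dir, double[] params) throws IOException {
        Files.createDirectories(dir);
        // Double.toString output round-trips exactly through Double.parseDouble.
        Files.writeString(dir.resolve(FILE_NAME), params[0] + " " + params[1]);
    }

    /** One "run" of the program: restore or initialize, apply the training step, then save. */
    static double[] run(Path dir, int steps, UnaryOperator<double[]> trainStep) throws IOException {
        double[] params = restoreOrInit(dir, new double[] {5.0, 3.0}); // assumed starting values
        for (int i = 0; i < steps; i++) {
            params = trainStep.apply(params);
        }
        save(dir, params);
        return params;
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("sketch").resolve("checkpoint");
        // Placeholder training step (not gradient descent): moves W and b halfway toward 3.0 and 2.0.
        UnaryOperator<double[]> step =
                p -> new double[] {p[0] + 0.5 * (3.0 - p[0]), p[1] + 0.5 * (2.0 - p[1])};
        double[] first = run(dir, 5, step);  // like the first mvn invocation
        double[] second = run(dir, 5, step); // resumes from the saved values
        System.out.println(first[0] + " " + first[1] + " -> " + second[0] + " " + second[1]);
        // Equivalent of `rm -rf checkpoint` for this single-file layout.
        Files.deleteIfExists(dir.resolve(FILE_NAME));
        Files.deleteIfExists(dir);
    }
}
```

The second run starts from the values the first run saved rather than from the initial ones, which is the behaviour the resume command relies on.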
The model in model/graph.pb represents a very simple linear model:
y = x * W + b
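Fitting this model by gradient descent can be written out explicitly. The loss is not specified here, so the derivation below assumes a mean-squared-error loss over a batch of n examples (x_i, y_i) and a learning rate η; both are assumptions made for illustration:

```latex
L(W, b) = \frac{1}{n} \sum_{i=1}^{n} \left( x_i W + b - y_i \right)^2

\frac{\partial L}{\partial W} = \frac{2}{n} \sum_{i=1}^{n} \left( x_i W + b - y_i \right) x_i,
\qquad
\frac{\partial L}{\partial b} = \frac{2}{n} \sum_{i=1}^{n} \left( x_i W + b - y_i \right)

W \leftarrow W - \eta \, \frac{\partial L}{\partial W},
\qquad
b \leftarrow b - \eta \, \frac{\partial L}{\partial b}
```

If the targets are generated exactly as y_i = 3.0 x_i + 2.0, every residual x_i W + b - y_i is zero at W = 3.0, b = 2.0, so both partial derivatives vanish there and the updates stop changing the parameters.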
The graph.pb file is generated by executing create_graph.py in Python.
Training is orchestrated by src/main/java/Train.java, which generates
training data of the form y = 3.0 * x + 2.0. Over time, gradient descent
should let the model "learn" these coefficients: W should converge to 3.0
and b to 2.0.
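The training loop can be sketched in plain Java without TensorFlow. This is a hypothetical stand-in for Train.java, not its actual code: it assumes a mean-squared-error loss, inputs x drawn uniformly from [0, 1), mini-batches of 10 examples, a learning rate of 0.1, 2000 steps, and starting values W = 5.0 and b = 3.0. The class and method names are illustrative.

```java
import java.util.Random;

/**
 * Hypothetical plain-Java sketch of the training loop: fits y = x * W + b to
 * synthetic data drawn from y = 3.0 * x + 2.0 using mini-batch gradient descent
 * on the mean squared error. No TensorFlow is involved.
 */
public class LinearRegressionSketch {

    /** Returns {W, b} after the given number of gradient-descent steps. */
    static double[] train(double w, double b, int steps, int batchSize,
                          double learningRate, long seed) {
        Random rng = new Random(seed);
        double[] x = new double[batchSize];
        double[] y = new double[batchSize];
        for (int step = 0; step < steps; step++) {
            // Generate a fresh batch of noise-free training data: y = 3.0 * x + 2.0.
            for (int i = 0; i < batchSize; i++) {
                x[i] = rng.nextDouble();
                y[i] = 3.0 * x[i] + 2.0;
            }
            // Gradients of L = mean((x * W + b - y)^2) with respect to W and b.
            double gradW = 0.0;
            double gradB = 0.0;
            for (int i = 0; i < batchSize; i++) {
                double error = x[i] * w + b - y[i];
                gradW += 2.0 * error * x[i] / batchSize;
                gradB += 2.0 * error / batchSize;
            }
            // Gradient-descent update.
            w -= learningRate * gradW;
            b -= learningRate * gradB;
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        // Starting values and hyperparameters are arbitrary choices for this sketch.
        double[] params = train(5.0, 3.0, 2000, 10, 0.1, 42L);
        System.out.printf("W = %.4f, b = %.4f%n", params[0], params[1]);
    }
}
```

Because the generated data has no noise, the gradient of every batch is exactly zero at W = 3.0, b = 2.0, so the printed values should be very close to those targets. A smaller learning rate or fewer steps would leave them visibly further away.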