While working with `Model` training, an issue with `Optimizer` has surfaced.

Currently, when `minimize(loss)` is called on an `Optimizer` instance, the `Optimizer` code walks the entire `Graph` and pulls out every `Variable` defined in the graph. The idea is that when you call `minimize(loss)`, the `Optimizer` builds gradients for all of those variables. However, when working with `Model`, this "all variables" approach breaks down, because some variables are not referenced in the `loss` operand's execution path. This produces the following error:
```
org.tensorflow.exceptions.TFInvalidArgumentException: Cannot compute the partial derivative for node 'model/mse_total' as it's unreachable from the output node(s).
```
This specific error occurs because the MSE metric's internal variables are not in the `loss` execution path. This pattern of non-trainable variables (weights) appears in most `Metric` classes, and in `Model` itself, so it is widespread. What we need is a way to distinguish trainable from non-trainable variables; only the trainable variables would then be used to compute the gradient values in the `Optimizer`.
In Python TensorFlow, each Keras `Layer` tracks its trainable variables as an attribute list, and the `Model` passes the collected lists to the `Optimizer`'s `minimize` method.
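For reference, that Keras-style bookkeeping could be mirrored in Java roughly along these lines. This is a hypothetical sketch only: the `Layer` and `Model` classes below are illustrative stand-ins (using plain `String` names rather than real `Variable<?>` handles), not the actual tensorflow-java API.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Illustrative sketch of Keras-style variable bookkeeping; these class and
// method names are assumptions, not part of the tensorflow-java API.
class Layer {
    private final List<String> trainableVariables = new ArrayList<>();
    private final List<String> nonTrainableVariables = new ArrayList<>();

    // Each layer registers its variables as trainable or not at creation time.
    void addVariable(String name, boolean trainable) {
        if (trainable) {
            trainableVariables.add(name);
        } else {
            nonTrainableVariables.add(name);
        }
    }

    List<String> trainableVariables() {
        return Collections.unmodifiableList(trainableVariables);
    }
}

class Model {
    private final List<Layer> layers = new ArrayList<>();

    void add(Layer layer) {
        layers.add(layer);
    }

    // The model concatenates the per-layer lists; this is what would be
    // handed to something like Optimizer.minimize(loss, trainableVariables).
    List<String> collectTrainableVariables() {
        List<String> result = new ArrayList<>();
        for (Layer layer : layers) {
            result.addAll(layer.trainableVariables());
        }
        return result;
    }
}
```

Metric state (like the `mse_total` variable in the error above) would be registered with `trainable = false` and therefore never reach the gradient computation.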
There are a couple of options here:
1. Mimic TF Keras, and have each `Layer` identify its trainable variables. Then pass the trainable variables as a `List<Variable<?>>` using a call like `Optimizer.minimize(loss, trainableVariables)`, and have the `Optimizer.minimize` routine call `addGradients` with this variable list, rather than walking the whole `Graph`, to compute the gradients.
2. Within `Optimizer.minimize(loss)`, walk the `loss` operand execution path to locate any variables contributing to the loss calculation, then pass these to `addGradients`. A solution based on this option may be facilitated by #232, "Add graph walking functions to Graph and GraphOperation".
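The second option amounts to a backward reachability walk from the loss node, collecting every variable encountered. A self-contained sketch of that traversal follows, using a toy `Node` structure rather than the real `GraphOperation` API (the actual input-walking support is what #232 would provide):

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Toy stand-in for a graph operation; a real implementation would walk
// GraphOperation inputs (see #232). These names are illustrative only.
class Node {
    final String name;
    final boolean isVariable;
    final List<Node> inputs = new ArrayList<>();

    Node(String name, boolean isVariable) {
        this.name = name;
        this.isVariable = isVariable;
    }
}

class LossWalker {
    // Depth-first walk backwards from the loss node, collecting every
    // variable that actually contributes to the loss computation.
    static List<Node> variablesReachableFrom(Node loss) {
        List<Node> variables = new ArrayList<>();
        Set<Node> visited = new HashSet<>();
        Deque<Node> stack = new ArrayDeque<>();
        stack.push(loss);
        while (!stack.isEmpty()) {
            Node node = stack.pop();
            if (!visited.add(node)) {
                continue; // already seen; graphs can share inputs
            }
            if (node.isVariable) {
                variables.add(node);
            }
            node.inputs.forEach(stack::push);
        }
        return variables;
    }
}
```

A variable like `model/mse_total` that is never an input (direct or transitive) to the loss node is simply never visited, so it would not be passed to `addGradients`, avoiding the `TFInvalidArgumentException` above.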