Skip to content

Commit 8f9a91a

Browse files
Extract shape information and encode it as constants. This improves the step time of the ptb_word_lm model by 5 to 10%, and the training speed of the inception model by 5%.
We avoid encoding static shapes as constants when control flow operations are involved, since the static shape information may be incorrect in some cases. Change: 129636322
1 parent ce43d9b commit 8f9a91a

5 files changed

Lines changed: 94 additions & 23 deletions

File tree

tensorflow/python/framework/tensor_util_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ def testSizeOfScalar(self):
526526
tf_val = tf.size(tf.constant(0.0))
527527
c_val = tf.contrib.util.constant_value(tf_val)
528528
self.assertEqual(1, c_val)
529-
self.assertEqual(np.int32, type(c_val))
529+
self.assertEqual(np.ndarray, type(c_val))
530530

531531
def testRank(self):
532532
tf_val = tf.rank(tf.constant(0.0, shape=[1, 2, 3]))

tensorflow/python/ops/array_ops.py

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@
104104
# Aliases for some automatically-generated names.
105105
listdiff = gen_array_ops.list_diff
106106

107-
108107
def shape(input, name=None):
109108
"""Returns the shape of a tensor.
110109
@@ -121,13 +120,34 @@ def shape(input, name=None):
121120
input: A `Tensor` or `SparseTensor`.
122121
name: A name for the operation (optional).
123122
123+
Returns:
124+
A `Tensor` of type `int32`.
125+
"""
126+
return shape_internal(input, name, optimize=True)
127+
128+
129+
def shape_internal(input, name=None, optimize=True):
130+
"""Returns the shape of a tensor.
131+
132+
Args:
133+
input: A `Tensor` or `SparseTensor`.
134+
name: A name for the operation (optional).
135+
optimize: if true, encode the shape as a constant when possible.
136+
124137
Returns:
125138
A `Tensor` of type `int32`.
126139
"""
127140
with ops.op_scope([input], name, "Shape") as name:
128141
if isinstance(input, ops.SparseTensor):
129142
return gen_math_ops.cast(input.shape, dtypes.int32)
130143
else:
144+
input_tensor = ops.convert_to_tensor(input)
145+
input_shape = input_tensor.get_shape()
146+
# Static shape inference can be incorrect when loops are involved: disable
147+
# shape optimization in this case to avoid generating invalid constants.
148+
optimize &= input_tensor.graph._get_control_flow_context() is None
149+
if optimize and input_shape.is_fully_defined():
150+
return constant(input_shape.as_list(), dtypes.int32, name=name)
131151
return gen_array_ops.shape(input, name=name)
132152

133153

@@ -148,6 +168,20 @@ def size(input, name=None):
148168
input: A `Tensor` or `SparseTensor`.
149169
name: A name for the operation (optional).
150170
171+
Returns:
172+
A `Tensor` of type `int32`.
173+
"""
174+
return size_internal(input, name, optimize=True)
175+
176+
177+
def size_internal(input, name=None, optimize=True):
178+
"""Returns the size of a tensor.
179+
180+
Args:
181+
input: A `Tensor` or `SparseTensor`.
182+
name: A name for the operation (optional).
183+
optimize: if true, encode the size as a constant when possible.
184+
151185
Returns:
152186
A `Tensor` of type `int32`.
153187
"""
@@ -156,6 +190,13 @@ def size(input, name=None):
156190
return gen_math_ops._prod(gen_math_ops.cast(input.shape, dtypes.int32), 0,
157191
name=name)
158192
else:
193+
input_tensor = ops.convert_to_tensor(input)
194+
input_shape = input_tensor.get_shape()
195+
# Static shape inference can be incorrect when loops are involved: disable
196+
# shape optimization in this case to avoid generating invalid constants.
197+
optimize &= input_tensor.graph._get_control_flow_context() is None
198+
if optimize and input_shape.is_fully_defined():
199+
return constant(input_shape.num_elements(), dtypes.int32, name=name)
159200
return gen_array_ops.size(input, name=name)
160201

161202

@@ -180,13 +221,34 @@ def rank(input, name=None):
180221
input: A `Tensor` or `SparseTensor`.
181222
name: A name for the operation (optional).
182223
224+
Returns:
225+
A `Tensor` of type `int32`.
226+
"""
227+
return rank_internal(input, name, optimize=True)
228+
229+
230+
def rank_internal(input, name=None, optimize=True):
231+
"""Returns the rank of a tensor.
232+
233+
Args:
234+
input: A `Tensor` or `SparseTensor`.
235+
name: A name for the operation (optional).
236+
optimize: if true, encode the rank as a constant when possible.
237+
183238
Returns:
184239
A `Tensor` of type `int32`.
185240
"""
186241
with ops.op_scope([input], name, "Rank") as name:
187242
if isinstance(input, ops.SparseTensor):
188243
return gen_array_ops.size(input.shape, name=name)
189244
else:
245+
input_tensor = ops.convert_to_tensor(input)
246+
input_shape = input_tensor.get_shape()
247+
# Static shape inference can be incorrect when loops are involved: disable
248+
# shape optimization in this case to avoid generating invalid constants.
249+
optimize &= input_tensor.graph._get_control_flow_context() is None
250+
if optimize and input_shape.ndims is not None:
251+
return constant(input_shape.ndims, dtypes.int32, name=name)
190252
return gen_array_ops.rank(input, name=name)
191253

192254

@@ -1074,7 +1136,7 @@ def zeros(shape, dtype=dtypes.float32, name=None):
10741136
return output
10751137

10761138

1077-
def zeros_like(tensor, dtype=None, name=None):
1139+
def zeros_like(tensor, dtype=None, name=None, optimize=True):
10781140
"""Creates a tensor with all elements set to zero.
10791141
10801142
Given a single tensor (`tensor`), this operation returns a tensor of the
@@ -1093,21 +1155,23 @@ def zeros_like(tensor, dtype=None, name=None):
10931155
dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
10941156
`int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
10951157
name: A name for the operation (optional).
1158+
optimize: if true, attempt to statically determine the shape of 'tensor'
1159+
and encode it as a constant.
10961160
10971161
Returns:
10981162
A `Tensor` with all elements set to zero.
10991163
"""
11001164
with ops.op_scope([tensor], name, "zeros_like") as name:
11011165
tensor = ops.convert_to_tensor(tensor, name="tensor")
11021166
if dtype is not None and tensor.dtype != dtype:
1103-
ret = zeros(shape(tensor), dtype, name=name)
1167+
ret = zeros(shape_internal(tensor, optimize=optimize), dtype, name=name)
11041168
ret.set_shape(tensor.get_shape())
11051169
return ret
11061170
else:
11071171
return gen_array_ops._zeros_like(tensor, name=name)
11081172

11091173

1110-
def ones_like(tensor, dtype=None, name=None):
1174+
def ones_like(tensor, dtype=None, name=None, optimize=True):
11111175
"""Creates a tensor with all elements set to 1.
11121176
11131177
Given a single tensor (`tensor`), this operation returns a tensor of the same
@@ -1126,13 +1190,15 @@ def ones_like(tensor, dtype=None, name=None):
11261190
dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
11271191
`int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
11281192
name: A name for the operation (optional).
1193+
optimize: if true, attempt to statically determine the shape of 'tensor'
1194+
and encode it as a constant.
11291195
11301196
Returns:
11311197
A `Tensor` with all elements set to 1.
11321198
"""
11331199
with ops.op_scope([tensor], name, "ones_like") as name:
11341200
tensor = ops.convert_to_tensor(tensor, name="tensor")
1135-
ones_shape = shape(tensor)
1201+
ones_shape = shape_internal(tensor, optimize=optimize)
11361202
if dtype is None:
11371203
dtype = tensor.dtype
11381204
ret = ones(ones_shape, dtype=dtype, name=name)

tensorflow/python/ops/control_flow_ops.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def merge(inputs, name=None):
371371
return ops.SparseTensor(indices, values, dense_shape), chosen_index
372372
else:
373373
# For now convert all the inputs as IndexedSlices.
374-
inputs = math_ops._as_indexed_slices_list(inputs)
374+
inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)
375375
values, _ = merge([inp.values for inp in inputs], name=name)
376376
indices, chosen_index = gen_control_flow_ops._merge(
377377
[inp.indices for inp in inputs], name="indices")
@@ -452,7 +452,7 @@ def _AddNextAndBackEdge(m, v):
452452
m.op._update_input(1, v) # pylint: disable=protected-access
453453
elif isinstance(m, ops.IndexedSlices):
454454
# pylint: disable=protected-access
455-
v = math_ops._as_indexed_slices(v)
455+
v = math_ops._as_indexed_slices(v, optimize=False)
456456
v = _NextIteration(v)
457457
m.values.op._update_input(1, v.values)
458458
m.indices.op._update_input(1, v.indices)
@@ -902,7 +902,7 @@ def ZerosLikeForExit(self, val):
902902
else:
903903
# Only the shape of value is needed for backprop.
904904
forward_ctxt.outer_context.Enter()
905-
shape = array_ops.shape(val)
905+
shape = array_ops.shape_internal(val, optimize=False)
906906
forward_ctxt.outer_context.Exit()
907907
# Save the shape to a stack.
908908
history_shape = outer_grad_state.AddForwardAccumulator(shape)
@@ -920,7 +920,7 @@ def ZerosLikeForExit(self, val):
920920
# with the right shape.
921921
result = array_ops.zeros(val_shape.dims, val.dtype)
922922
else:
923-
result = array_ops.zeros_like(val)
923+
result = array_ops.zeros_like(val, optimize=False)
924924
return result
925925

926926
def ZerosLike(self, op, index):
@@ -963,13 +963,13 @@ def ZerosLike(self, op, index):
963963
branch = op_ctxt.branch
964964
op_ctxt.outer_context.Enter()
965965
val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
966-
zeros_shape = array_ops.shape(val)
966+
zeros_shape = array_ops.shape_internal(val, optimize=False)
967967
op_ctxt.outer_context.Exit()
968968
val.op._set_control_flow_context(op_ctxt)
969969
zeros_shape.op._set_control_flow_context(op_ctxt)
970970
else:
971971
op_ctxt.Enter()
972-
zeros_shape = array_ops.shape(val)
972+
zeros_shape = array_ops.shape_internal(val, optimize=False)
973973
op_ctxt.Exit()
974974

975975
# Add forward accumulator for shape.
@@ -1054,13 +1054,13 @@ def ZerosLikeOutsideLoop(op, index):
10541054
"""Create zeros_like for the specified output of an op."""
10551055
val = op.outputs[index]
10561056
if not IsSwitch(op):
1057-
return array_ops.zeros_like(val)
1057+
return array_ops.zeros_like(val, optimize=False)
10581058
else:
10591059
op_ctxt = op._get_control_flow_context()
10601060
pred = op_ctxt.pred
10611061
branch = op_ctxt.branch
10621062
switch_val = switch(op.inputs[0], pred)[1 - branch]
1063-
zeros_shape = array_ops.shape(switch_val)
1063+
zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
10641064
return array_ops.zeros(zeros_shape, dtype=val.dtype)
10651065

10661066

@@ -1664,7 +1664,7 @@ def AddBackPropAccumulator(self, op, grad):
16641664
if self.outer_context:
16651665
forward_ctxt = self.grad_state.forward_ctxt
16661666
forward_ctxt.outer_context.Enter()
1667-
zeros_shape = array_ops.shape(value)
1667+
zeros_shape = array_ops.shape_internal(value, optimize=False)
16681668
forward_ctxt.outer_context.Exit()
16691669
history_zeros_shape = grad_state.AddForwardAccumulator(zeros_shape)
16701670
self.outer_context.Enter()
@@ -1673,7 +1673,7 @@ def AddBackPropAccumulator(self, op, grad):
16731673
acc = array_ops.zeros(real_shape, grad.dtype)
16741674
self.outer_context.Exit()
16751675
else:
1676-
zeros_shape = array_ops.shape(value)
1676+
zeros_shape = array_ops.shape_internal(value, optimize=False)
16771677
acc = array_ops.zeros(zeros_shape, grad.dtype)
16781678
acc._shape = grad.get_shape() # pylint: disable=protected-access
16791679

@@ -1720,7 +1720,7 @@ def AddBackPropIndexedSlicesAccumulator(self, op, grad):
17201720
name="b_acc")
17211721
if self.outer_context: self.outer_context.Exit()
17221722
else:
1723-
values_shape = array_ops.shape(op.inputs[0])[1:]
1723+
values_shape = array_ops.shape_internal(op.inputs[0], optimize=False)[1:]
17241724
values_shape = array_ops.concat(0, [[1], values_shape])
17251725
values_acc = array_ops.zeros(values_shape)
17261726
indices_acc = constant_op.constant([0], indices.dtype)
@@ -1732,7 +1732,10 @@ def AddBackPropIndexedSlicesAccumulator(self, op, grad):
17321732
shape=dense_shape.get_shape())
17331733
if self.outer_context: self.outer_context.Exit()
17341734
else:
1735-
shape_acc = array_ops.zeros_like(array_ops.shape(op.inputs[0]))
1735+
shape_acc = array_ops.zeros_like(
1736+
array_ops.shape_internal(
1737+
op.inputs[0], optimize=False),
1738+
optimize=False)
17361739

17371740
if self.outer_context: self.outer_context.Exit()
17381741

tensorflow/python/ops/math_ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,13 +1387,14 @@ def _calc_mat_mul_weight_parameters(graph, node):
13871387
(int(weights_shape[1]) * int(weights_shape[0])))
13881388

13891389

1390-
def _as_indexed_slices(x):
1390+
def _as_indexed_slices(x, optimize=True):
13911391
"""Convert 'x' to IndexedSlices.
13921392
13931393
Convert a dense Tensor to a block-sparse IndexedSlices.
13941394
13951395
Args:
13961396
x: Either a Tensor object, or an IndexedSlices object.
1397+
optimize: if true, attempt to optimize the conversion of 'x'.
13971398
13981399
Returns:
13991400
An IndexedSlices object.
@@ -1406,18 +1407,19 @@ def _as_indexed_slices(x):
14061407
raise TypeError("Not a Tensor or IndexedSlices: %s" % type(x))
14071408
if isinstance(x, ops.IndexedSlices):
14081409
return x
1409-
x_shape = array_ops.shape(x)
1410+
x_shape = array_ops.shape_internal(x, optimize=optimize)
14101411
return ops.IndexedSlices(x, range(0, x_shape[0]), x_shape)
14111412

14121413

1413-
def _as_indexed_slices_list(inputs):
1414+
def _as_indexed_slices_list(inputs, optimize=True):
14141415
"""Convert all elements of 'inputs' to IndexedSlices.
14151416
14161417
Additionally, homogenize the types of all the indices to
14171418
either int32 or int64.
14181419
14191420
Args:
14201421
inputs: List containing either Tensor or IndexedSlices objects.
1422+
optimize: if true, attempt to optimize the conversion of each input.
14211423
14221424
Returns:
14231425
A list of IndexedSlices objects.
@@ -1427,7 +1429,7 @@ def _as_indexed_slices_list(inputs):
14271429
"""
14281430
if not isinstance(inputs, (list, tuple)):
14291431
raise TypeError("Expected a list or tuple, not a %s" % type(inputs))
1430-
outputs = [_as_indexed_slices(i) for i in inputs]
1432+
outputs = [_as_indexed_slices(i, optimize=optimize) for i in inputs]
14311433
with_int32_index = [o.indices for o in outputs
14321434
if o.indices.dtype == dtypes.int32]
14331435
if not with_int32_index or len(with_int32_index) == len(outputs):

tensorflow/python/ops/variables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,7 @@ def assert_variables_initialized(var_list=None):
999999
ranks = []
10001000
for var in var_list:
10011001
with ops.colocate_with(var.op):
1002-
ranks.append(array_ops.rank(var))
1002+
ranks.append(array_ops.rank_internal(var, optimize=False))
10031003
if len(ranks) == 1:
10041004
return ranks[0]
10051005
else:

0 commit comments

Comments
 (0)