104104# Aliases for some automatically-generated names.
105105listdiff = gen_array_ops .list_diff
106106
107-
108107def shape (input , name = None ):
109108 """Returns the shape of a tensor.
110109
@@ -121,13 +120,34 @@ def shape(input, name=None):
121120 input: A `Tensor` or `SparseTensor`.
122121 name: A name for the operation (optional).
123122
123+ Returns:
124+ A `Tensor` of type `int32`.
125+ """
126+ return shape_internal (input , name , optimize = True )
127+
128+
129+ def shape_internal (input , name = None , optimize = True ):
130+ """Returns the shape of a tensor.
131+
132+ Args:
133+ input: A `Tensor` or `SparseTensor`.
134+ name: A name for the operation (optional).
135+ optimize: if true, encode the shape as a constant when possible.
136+
124137 Returns:
125138 A `Tensor` of type `int32`.
126139 """
127140 with ops .op_scope ([input ], name , "Shape" ) as name :
128141 if isinstance (input , ops .SparseTensor ):
129142 return gen_math_ops .cast (input .shape , dtypes .int32 )
130143 else :
144+ input_tensor = ops .convert_to_tensor (input )
145+ input_shape = input_tensor .get_shape ()
146+ # Static shape inference can be incorrect when loops are involved: disable
147+ # shape optimization in this case to avoid generating invalid constants.
148+ optimize &= input_tensor .graph ._get_control_flow_context () is None
149+ if optimize and input_shape .is_fully_defined ():
150+ return constant (input_shape .as_list (), dtypes .int32 , name = name )
131151 return gen_array_ops .shape (input , name = name )
132152
133153
@@ -148,6 +168,20 @@ def size(input, name=None):
148168 input: A `Tensor` or `SparseTensor`.
149169 name: A name for the operation (optional).
150170
171+ Returns:
172+ A `Tensor` of type `int32`.
173+ """
174+ return size_internal (input , name , optimize = True )
175+
176+
177+ def size_internal (input , name = None , optimize = True ):
178+ """Returns the size of a tensor.
179+
180+ Args:
181+ input: A `Tensor` or `SparseTensor`.
182+ name: A name for the operation (optional).
183+ optimize: if true, encode the size as a constant when possible.
184+
151185 Returns:
152186 A `Tensor` of type `int32`.
153187 """
@@ -156,6 +190,13 @@ def size(input, name=None):
156190 return gen_math_ops ._prod (gen_math_ops .cast (input .shape , dtypes .int32 ), 0 ,
157191 name = name )
158192 else :
193+ input_tensor = ops .convert_to_tensor (input )
194+ input_shape = input_tensor .get_shape ()
195+ # Static shape inference can be incorrect when loops are involved: disable
196+ # shape optimization in this case to avoid generating invalid constants.
197+ optimize &= input_tensor .graph ._get_control_flow_context () is None
198+ if optimize and input_shape .is_fully_defined ():
199+ return constant (input_shape .num_elements (), dtypes .int32 , name = name )
159200 return gen_array_ops .size (input , name = name )
160201
161202
@@ -180,13 +221,34 @@ def rank(input, name=None):
180221 input: A `Tensor` or `SparseTensor`.
181222 name: A name for the operation (optional).
182223
224+ Returns:
225+ A `Tensor` of type `int32`.
226+ """
227+ return rank_internal (input , name , optimize = True )
228+
229+
230+ def rank_internal (input , name = None , optimize = True ):
231+ """Returns the rank of a tensor.
232+
233+ Args:
234+ input: A `Tensor` or `SparseTensor`.
235+ name: A name for the operation (optional).
236+ optimize: if true, encode the rank as a constant when possible.
237+
183238 Returns:
184239 A `Tensor` of type `int32`.
185240 """
186241 with ops .op_scope ([input ], name , "Rank" ) as name :
187242 if isinstance (input , ops .SparseTensor ):
188243 return gen_array_ops .size (input .shape , name = name )
189244 else :
245+ input_tensor = ops .convert_to_tensor (input )
246+ input_shape = input_tensor .get_shape ()
247+ # Static shape inference can be incorrect when loops are involved: disable
248+ # shape optimization in this case to avoid generating invalid constants.
249+ optimize &= input_tensor .graph ._get_control_flow_context () is None
250+ if optimize and input_shape .ndims is not None :
251+ return constant (input_shape .ndims , dtypes .int32 , name = name )
190252 return gen_array_ops .rank (input , name = name )
191253
192254
@@ -1074,7 +1136,7 @@ def zeros(shape, dtype=dtypes.float32, name=None):
10741136 return output
10751137
10761138
1077- def zeros_like (tensor , dtype = None , name = None ):
1139+ def zeros_like (tensor , dtype = None , name = None , optimize = True ):
10781140 """Creates a tensor with all elements set to zero.
10791141
10801142 Given a single tensor (`tensor`), this operation returns a tensor of the
@@ -1093,21 +1155,23 @@ def zeros_like(tensor, dtype=None, name=None):
10931155 dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
10941156 `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
10951157 name: A name for the operation (optional).
1158+ optimize: if true, attempt to statically determine the shape of 'tensor'
1159+ and encode it as a constant.
10961160
10971161 Returns:
10981162 A `Tensor` with all elements set to zero.
10991163 """
11001164 with ops .op_scope ([tensor ], name , "zeros_like" ) as name :
11011165 tensor = ops .convert_to_tensor (tensor , name = "tensor" )
11021166 if dtype is not None and tensor .dtype != dtype :
1103- ret = zeros (shape (tensor ), dtype , name = name )
1167+ ret = zeros (shape_internal (tensor , optimize = optimize ), dtype , name = name )
11041168 ret .set_shape (tensor .get_shape ())
11051169 return ret
11061170 else :
11071171 return gen_array_ops ._zeros_like (tensor , name = name )
11081172
11091173
1110- def ones_like (tensor , dtype = None , name = None ):
1174+ def ones_like (tensor , dtype = None , name = None , optimize = True ):
11111175 """Creates a tensor with all elements set to 1.
11121176
11131177 Given a single tensor (`tensor`), this operation returns a tensor of the same
@@ -1126,13 +1190,15 @@ def ones_like(tensor, dtype=None, name=None):
11261190 dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
11271191 `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
11281192 name: A name for the operation (optional).
1193+ optimize: if true, attempt to statically determine the shape of 'tensor'
1194+ and encode it as a constant.
11291195
11301196 Returns:
11311197 A `Tensor` with all elements set to 1.
11321198 """
11331199 with ops .op_scope ([tensor ], name , "ones_like" ) as name :
11341200 tensor = ops .convert_to_tensor (tensor , name = "tensor" )
1135- ones_shape = shape (tensor )
1201+ ones_shape = shape_internal (tensor , optimize = optimize )
11361202 if dtype is None :
11371203 dtype = tensor .dtype
11381204 ret = ones (ones_shape , dtype = dtype , name = name )
0 commit comments