forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstring_ops.py
More file actions
97 lines (74 loc) · 2.98 KB
/
string_ops.py
File metadata and controls
97 lines (74 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""## Hashing
String hashing ops take a string input tensor and map each element to an
integer.
@@string_to_hash_bucket
## Joining
String joining ops concatenate elements of input string tensors to produce a new
string tensor.
@@reduce_join
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import common_shapes
# pylint: disable=unused-import
from tensorflow.python.ops import gen_string_ops
# pylint: enable=unused-import
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_string_ops import *
# pylint: enable=wildcard-import
# Neither op has a meaningful gradient: hashing maps strings to integers and
# joining produces strings, so both are registered as non-differentiable.
ops.NoGradient("StringToHashBucket")
ops.NoGradient("ReduceJoin")
# StringToHashBucket is elementwise (string -> int64), so the output shape is
# identical to the input shape.
ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape)
@ops.RegisterShape("ReduceJoin")
def _ReduceJoinShape(op):
  """Shape function for the ReduceJoin op.

  Each reduced axis is dropped from the output shape, or kept with size 1
  when the "keep_dims" attr is set; all other axes pass through unchanged.

  Args:
    op: The ReduceJoin Operation. inputs[0] is the input string tensor and
      inputs[1] holds the reduction indices.

  Returns:
    A single-element list containing the output TensorShape.

  Raises:
    ValueError: If the input is a scalar, a reduction index is out of range
      or duplicated, or a reduced dimension has size 0.
  """
  shape = op.inputs[0].get_shape()
  # constant_value is None when the indices are not statically known;
  # np.ravel then yields object-dtype None entries handled in the loop below.
  indices = np.ravel(tensor_util.constant_value(op.inputs[1]))
  keep_dims = op.get_attr("keep_dims")

  ndims = shape.ndims
  if ndims is None:
    return [tensor_shape.unknown_shape()]
  if ndims == 0:
    raise ValueError("Input string tensor cannot be a scalar.")

  reduced_axes = set()
  for index in indices:
    if index is None:
      # Indices unknown at graph-construction time: shape cannot be inferred.
      return [tensor_shape.unknown_shape()]
    if not -ndims <= index < ndims:
      raise ValueError("Invalid reduction dimension %d for input with %d "
                       "dimensions" % (index, ndims))
    axis = index % ndims  # Canonicalize negative indices to [0, ndims).
    if axis in reduced_axes:
      raise ValueError("Duplicate reduction index %d." % index)
    if shape.dims[axis] == 0:
      raise ValueError("Cannot reduce dimension %d with size 0." % index)
    reduced_axes.add(axis)

  output_dims = []
  for axis, dim in enumerate(shape.dims):
    if axis not in reduced_axes:
      output_dims.append(dim)
    elif keep_dims:
      output_dims.append(1)
  return [tensor_shape.TensorShape(output_dims)]