Hi,

We are using tf.lookup.experimental.DenseHashTable within a tf.map_fn to perform specialised sorting of items within rows of a RaggedTensor.
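For illustration, the intended per-row behaviour is that items present in sorted_list are rearranged to follow the order of sorted_list, while items not present keep their positions (the values below are made-up examples, not our real data):

```python
sorted_list = ["alpha", "bravo", "charlie", "delta"]  # reference order (illustrative)
rail        = ["charlie", "unknown", "alpha"]         # one row ("rail") of rails_to_sort
# expected:   ["alpha", "unknown", "charlie"]         # known items follow sorted_list;
#                                                     # "unknown" keeps its position
```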
When loading the SavedModel with TensorFlow Serving, we get the following MLIR failure message:
```
error: 'tfg.While' op body function argument #6 type 'tensor<!tf_type.resource<tensor<!tf_type.string>>>' is not compatible with corresponding operand type: 'tensor<!tf_type.resource<tensor<!tf_type.string>, tensor<i32>>>'
2024-05-15 06:42:15.209491: E external/org_tensorflow/tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] tfg_optimizer{tfg-consolidate-attrs,tfg-toposort,tfg-shape-inference{graph-version=0},tfg-prepare-attrs-export} failed: INVALID_ARGUMENT: MLIR Graph Optimizer failed:
```
The tfg.While is generated from the map_fn. Argument #6 is the DenseHashTable, which has tf.string keys and tf.int32 values.
The code:
```python
@tf.function
def sort_rail_horizontally_hashmap_map_fn(sorted_list: tf.Tensor, rails_to_sort: tf.RaggedTensor):
    @tf.function
    def reorder_rail(rail):
        lookup_indexes = lookup_table.lookup(rail)
        match_mask = ~tf.equal(lookup_indexes, -1)
        match_indexes = tf.where(condition=match_mask)
        extracted_match_indexes = tf.gather(
            lookup_indexes, match_indexes[:, 0], name="gather_extracted_match_indexes"
        )
        extracted_sorted_list = tf.gather(
            sorted_list, extracted_match_indexes, name="gather_extracted_sorted_list"
        )
        sorted_match_indexes = tf.argsort(extracted_match_indexes)
        reordered_extracted_sorted_list = tf.gather(
            extracted_sorted_list,
            sorted_match_indexes,
            name="gather_reordered_extracted_sorted_list",
        )
        composite_rail = tf.tensor_scatter_nd_update(
            tensor=rail,
            indices=match_indexes,
            updates=reordered_extracted_sorted_list,
            name="composite_rail_scatter",
        )
        return composite_rail

    lookup_table = tf.lookup.experimental.DenseHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int32,
        default_value=-1,
        empty_key="$",
        deleted_key="£",
    )
    lookup_table.insert(
        sorted_list,
        tf.range(0, tf.size(sorted_list), dtype=tf.int32),
        name="lookup_table_insert",
    )
    ragged_rails = tf.map_fn(
        reorder_rail,
        rails_to_sort,
        parallel_iterations=50,
        fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.string),
        name="rails_to_sort_map_fn",
    )
    return ragged_rails
```
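For completeness, a minimal invocation sketch (the tensor values are again illustrative, not from our real data):

```python
import tensorflow as tf

sorted_list = tf.constant(["alpha", "bravo", "charlie", "delta"])
rails_to_sort = tf.ragged.constant([
    ["charlie", "unknown", "alpha"],  # -> ["alpha", "unknown", "charlie"]
    ["delta", "bravo"],               # -> ["bravo", "delta"]
])
result = sort_rail_horizontally_hashmap_map_fn(sorted_list, rails_to_sort)
```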
The While node looks like this:
```
node_def {
  name: "rails_to_sort_map_fn/while"
  op: "While"
  input: "rails_to_sort_map_fn/while/loop_counter:output:0"
  input: "rails_to_sort_map_fn/strided_slice:output:0"
  input: "rails_to_sort_map_fn/Const:output:0"
  input: "rails_to_sort_map_fn/TensorArrayV2_1:handle:0"
  input: "rails_to_sort_map_fn/strided_slice:output:0"
  input: "rails_to_sort_map_fn/TensorArrayUnstack/TensorListFromTensor:output_handle:0"
  input: "MutableDenseHashTable:table_handle:0"
  input: "default_value:output:0"
  input: "sorted_list"
  input: "^lookup_table_insert/LookupTableInsertV2"
  attr {
    key: "T"
    value {
      list {
        type: DT_INT32
        type: DT_INT32
        type: DT_INT32
        type: DT_VARIANT
        type: DT_INT32
        type: DT_VARIANT
        type: DT_RESOURCE
        type: DT_INT32
        type: DT_STRING
      }
    }
  }
  attr {
    key: "_lower_using_switch_merge"
    value { b: true }
  }
  attr {
    key: "_num_original_outputs"
    value { i: 9 }
  }
  attr {
    key: "_read_only_resource_inputs"
    value { list { } }
  }
  attr {
    key: "body"
    value { func { name: "rails_to_sort_map_fn_while_body_18098" } }
  }
  attr {
    key: "cond"
    value { func { name: "rails_to_sort_map_fn_while_cond_18097" } }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape { }
        shape { }
        shape { }
        shape { }
        shape { }
        shape { }
        shape { }
        shape { }
        shape { dim { size: 828 } }
      }
    }
  }
  attr {
    key: "parallel_iterations"
    value { i: 50 }
  }
}
```
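(A node_def like this can be dumped from the SavedModel's function library along the following lines; the export path here is hypothetical:)

```python
from tensorflow.core.protobuf import saved_model_pb2

sm = saved_model_pb2.SavedModel()
with open("export/1/saved_model.pb", "rb") as f:  # hypothetical path
    sm.ParseFromString(f.read())

# The While op sits inside the tf.function's FunctionDef, so search the library.
for fn in sm.meta_graphs[0].graph_def.library.function:
    for node in fn.node_def:
        if node.op == "While":
            print(node)
```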
The graph executes as expected, so we guess that either MLIR does not yet support references to a DenseHashTable, or this is an MLIR bug.
Our primary concern is the effect of the failure: does it stop all graph optimization on the target server, or only optimization of that node of the graph?
Thanks
Adrian