| |
| |
|
|
| |
| import argparse |
| import logging |
|
|
| import numpy as np |
| import onnx |
| import sympy |
| from onnx import helper, numpy_helper, shape_inference |
| from packaging import version |
|
|
| assert version.parse(onnx.__version__) >= version.parse("1.8.0") |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| def get_attribute(node, attr_name, default_value=None): |
| found = [attr for attr in node.attribute if attr.name == attr_name] |
| if found: |
| return helper.get_attribute_value(found[0]) |
| return default_value |
|
|
|
|
| def get_dim_from_proto(dim): |
| return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) is str else None |
|
|
|
|
| def is_sequence(type_proto): |
| cls_type = type_proto.WhichOneof("value") |
| assert cls_type in ["tensor_type", "sequence_type"] |
| return cls_type == "sequence_type" |
|
|
|
|
| def get_shape_from_type_proto(type_proto): |
| assert not is_sequence(type_proto) |
| if type_proto.tensor_type.HasField("shape"): |
| return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim] |
| else: |
| return None |
|
|
|
|
| def get_elem_type_from_type_proto(type_proto): |
| if is_sequence(type_proto): |
| return type_proto.sequence_type.elem_type.tensor_type.elem_type |
| else: |
| return type_proto.tensor_type.elem_type |
|
|
|
|
| def get_shape_from_value_info(vi): |
| cls_type = vi.type.WhichOneof("value") |
| if cls_type is None: |
| return None |
| if is_sequence(vi.type): |
| if vi.type.sequence_type.elem_type.WhichOneof("value") == "tensor_type": |
| return get_shape_from_type_proto(vi.type.sequence_type.elem_type) |
| else: |
| return None |
| else: |
| return get_shape_from_type_proto(vi.type) |
|
|
|
|
| def make_named_value_info(name): |
| vi = onnx.ValueInfoProto() |
| vi.name = name |
| return vi |
|
|
|
|
| def get_shape_from_sympy_shape(sympy_shape): |
| return [None if i is None else (int(i) if is_literal(i) else str(i)) for i in sympy_shape] |
|
|
|
|
| def is_literal(dim): |
| return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(dim, "is_number") and dim.is_number) |
|
|
|
|
| def handle_negative_axis(axis, rank): |
| assert axis < rank and axis >= -rank |
| return axis if axis >= 0 else rank + axis |
|
|
|
|
| def get_opset(mp, domain=None): |
| domain = domain or ["", "onnx", "ai.onnx"] |
| if type(domain) != list: |
| domain = [domain] |
| for opset in mp.opset_import: |
| if opset.domain in domain: |
| return opset.version |
|
|
| return None |
|
|
|
|
| def as_scalar(x): |
| if type(x) == list: |
| assert len(x) == 1 |
| return x[0] |
| elif type(x) == np.ndarray: |
| return x.item() |
| else: |
| return x |
|
|
|
|
| def as_list(x, keep_none): |
| if type(x) == list: |
| return x |
| elif type(x) == np.ndarray: |
| return list(x) |
| elif keep_none and x is None: |
| return None |
| else: |
| return [x] |
|
|
|
|
| def sympy_reduce_product(x): |
| if type(x) == list: |
| value = sympy.Integer(1) |
| for v in x: |
| value = value * v |
| else: |
| value = x |
| return value |
|
|
|
|
| class SymbolicShapeInference: |
| def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): |
| self.dispatcher_ = { |
| "Add": self._infer_symbolic_compute_ops, |
| "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor, |
| "AveragePool": self._infer_Pool, |
| "BatchNormalization": self._infer_BatchNormalization, |
| "Cast": self._infer_Cast, |
| "CategoryMapper": self._infer_CategoryMapper, |
| "Compress": self._infer_Compress, |
| "Concat": self._infer_Concat, |
| "ConcatFromSequence": self._infer_ConcatFromSequence, |
| "Constant": self._infer_Constant, |
| "ConstantOfShape": self._infer_ConstantOfShape, |
| "Conv": self._infer_Conv, |
| "CumSum": self._pass_on_shape_and_type, |
| "Div": self._infer_symbolic_compute_ops, |
| "Einsum": self._infer_Einsum, |
| "Expand": self._infer_Expand, |
| "Equal": self._infer_symbolic_compute_ops, |
| "Floor": self._infer_symbolic_compute_ops, |
| "Gather": self._infer_Gather, |
| "GatherElements": self._infer_GatherElements, |
| "GatherND": self._infer_GatherND, |
| "Identity": self._pass_on_shape_and_type, |
| "AllReduce": self._pass_on_shape_and_type, |
| "If": self._infer_If, |
| "Loop": self._infer_Loop, |
| "MatMul": self._infer_MatMul, |
| "MatMulInteger16": self._infer_MatMulInteger, |
| "MaxPool": self._infer_Pool, |
| "Max": self._infer_symbolic_compute_ops, |
| "MemcpyFromHost": self._pass_on_shape_and_type, |
| "MemcpyToHost": self._pass_on_shape_and_type, |
| "Min": self._infer_symbolic_compute_ops, |
| "MoE": self._pass_on_shape_and_type, |
| "Mul": self._infer_symbolic_compute_ops, |
| "NonMaxSuppression": self._infer_NonMaxSuppression, |
| "NonZero": self._infer_NonZero, |
| "OneHot": self._infer_OneHot, |
| "Pad": self._infer_Pad, |
| "Range": self._infer_Range, |
| "Reciprocal": self._pass_on_shape_and_type, |
| "ReduceSum": self._infer_ReduceSum, |
| "ReduceProd": self._infer_ReduceProd, |
| "Reshape": self._infer_Reshape, |
| "Resize": self._infer_Resize, |
| "Round": self._pass_on_shape_and_type, |
| "Scan": self._infer_Scan, |
| "ScatterElements": self._infer_ScatterElements, |
| "SequenceAt": self._infer_SequenceAt, |
| "SequenceInsert": self._infer_SequenceInsert, |
| "Shape": self._infer_Shape, |
| "Size": self._infer_Size, |
| "Slice": self._infer_Slice, |
| "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss, |
| "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss, |
| "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss, |
| "Split": self._infer_Split, |
| "SplitToSequence": self._infer_SplitToSequence, |
| "Squeeze": self._infer_Squeeze, |
| "Sub": self._infer_symbolic_compute_ops, |
| "Tile": self._infer_Tile, |
| "TopK": self._infer_TopK, |
| "Transpose": self._infer_Transpose, |
| "Unsqueeze": self._infer_Unsqueeze, |
| "Where": self._infer_symbolic_compute_ops, |
| "ZipMap": self._infer_ZipMap, |
| "Neg": self._infer_symbolic_compute_ops, |
| |
| "Attention": self._infer_Attention, |
| "BiasAdd": self._infer_BiasAdd, |
| "BiasGelu": self._infer_BiasGelu, |
| "BiasSplitGelu": self._infer_BiasSplitGelu, |
| "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention, |
| "DequantizeLinear": self._infer_DequantizeLinear, |
| "EmbedLayerNormalization": self._infer_EmbedLayerNormalization, |
| "FastGelu": self._infer_FastGelu, |
| "GatedRelativePositionBias": self._infer_GatedRelativePositionBias, |
| "Gelu": self._infer_Gelu, |
| "GemmFastGelu": self._infer_GemmFastGelu, |
| "GemmFloat8": self._infer_GemmFloat8, |
| "GroupNorm": self._infer_GroupNorm, |
| "GroupQueryAttention": self._infer_GroupQueryAttention, |
| "SkipGroupNorm": self._infer_SkipGroupNorm, |
| "LayerNormalization": self._infer_LayerNormalization, |
| "LongformerAttention": self._infer_LongformerAttention, |
| "MultiHeadAttention": self._infer_MultiHeadAttention, |
| "NhwcConv": self._infer_NhwcConv, |
| "PackedAttention": self._infer_PackedAttention, |
| "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, |
| "PagedAttention": self._infer_PagedAttention, |
| "PythonOp": self._infer_PythonOp, |
| "QuantizeLinear": self._infer_QuantizeLinear, |
| "QuickGelu": self._infer_FastGelu, |
| "RelativePositionBias": self._infer_RelativePositionBias, |
| "RemovePadding": self._infer_RemovePadding, |
| "RestorePadding": self._infer_RestorePadding, |
| "RotaryEmbedding": self._infer_RotaryEmbedding, |
| "SimplifiedLayerNormalization": self._infer_LayerNormalization, |
| "SkipLayerNormalization": self._infer_SkipLayerNormalization, |
| "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, |
| } |
| self.aten_op_dispatcher_ = { |
| "embedding": self._infer_Gather, |
| "bitwise_or": self._infer_aten_bitwise_or, |
| "diagonal": self._infer_aten_diagonal, |
| "max_pool2d_with_indices": self._infer_aten_pool2d, |
| "max": self._infer_aten_minmax, |
| "min": self._infer_aten_minmax, |
| "multinomial": self._infer_aten_multinomial, |
| "unfold": self._infer_aten_unfold, |
| "argmax": self._infer_aten_argmax, |
| "avg_pool2d": self._infer_aten_pool2d, |
| "_adaptive_avg_pool2d": self._infer_aten_pool2d, |
| "numpy_T": self._infer_Transpose, |
| "native_group_norm": self._infer_aten_group_norm, |
| "upsample_nearest1d": self._infer_aten_upsample, |
| "upsample_nearest2d": self._infer_aten_upsample, |
| "upsample_nearest3d": self._infer_aten_upsample, |
| "upsample_bicubic2d": self._infer_aten_upsample, |
| } |
| self.run_ = True |
| self.suggested_merge_ = {} |
| self.symbolic_dims_ = {} |
| self.input_symbols_ = {} |
| self.auto_merge_ = auto_merge |
| self.guess_output_rank_ = guess_output_rank |
| self.verbose_ = verbose |
| self.int_max_ = int_max |
| self.subgraph_id_ = 0 |
| self.prefix_ = prefix |
|
|
| def _add_suggested_merge(self, symbols, apply=False): |
| assert all([(type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols]) |
| symbols = set(symbols) |
| for k, v in self.suggested_merge_.items(): |
| if k in symbols: |
| symbols.remove(k) |
| symbols.add(v) |
| map_to = None |
| |
| for s in symbols: |
| if is_literal(s): |
| map_to = s |
| break |
| |
| if map_to is None: |
| for s in symbols: |
| if s in self.input_symbols_: |
| map_to = s |
| break |
| if map_to is None: |
| for s in symbols: |
| if type(self.symbolic_dims_[s]) == sympy.Symbol: |
| map_to = s |
| break |
| |
| if map_to is None: |
| if self.verbose_ > 0: |
| logger.warning("Potential unsafe merge between symbolic expressions: ({})".format(",".join(symbols))) |
| symbols_list = list(symbols) |
| lens = [len(s) for s in symbols_list] |
| map_to = symbols_list[lens.index(min(lens))] |
| symbols.remove(map_to) |
|
|
| for s in symbols: |
| if s == map_to: |
| continue |
| if is_literal(map_to) and is_literal(s): |
| assert int(map_to) == int(s) |
| self.suggested_merge_[s] = int(map_to) if is_literal(map_to) else map_to |
| for k, v in self.suggested_merge_.items(): |
| if v == s: |
| self.suggested_merge_[k] = map_to |
| if apply and self.auto_merge_: |
| self._apply_suggested_merge() |
|
|
| def _apply_suggested_merge(self, graph_input_only=False): |
| if not self.suggested_merge_: |
| return |
| for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)): |
| for d in i.type.tensor_type.shape.dim: |
| if d.dim_param in self.suggested_merge_: |
| v = self.suggested_merge_[d.dim_param] |
| if is_literal(v): |
| d.dim_value = int(v) |
| else: |
| d.dim_param = v |
|
|
| def _preprocess(self, in_mp): |
| self.out_mp_ = onnx.ModelProto() |
| self.out_mp_.CopyFrom(in_mp) |
| self.graph_inputs_ = {i.name: i for i in list(self.out_mp_.graph.input)} |
| self.initializers_ = {i.name: i for i in self.out_mp_.graph.initializer} |
| self.known_vi_ = {i.name: i for i in list(self.out_mp_.graph.input)} |
| self.known_vi_.update( |
| { |
| i.name: helper.make_tensor_value_info(i.name, i.data_type, list(i.dims)) |
| for i in self.out_mp_.graph.initializer |
| } |
| ) |
|
|
| def _merge_symbols(self, dims): |
| if not all([type(d) == str for d in dims]): |
| if self.auto_merge_: |
| unique_dims = list(set(dims)) |
| is_int = [is_literal(d) for d in unique_dims] |
| assert sum(is_int) <= 1 |
| if sum(is_int) == 1: |
| int_dim = is_int.index(1) |
| if self.verbose_ > 0: |
| logger.debug( |
| "dim {} has been merged with value {}".format( |
| unique_dims[:int_dim] + unique_dims[int_dim + 1 :], |
| unique_dims[int_dim], |
| ) |
| ) |
| self._check_merged_dims(unique_dims, allow_broadcast=False) |
| return unique_dims[int_dim] |
| else: |
| if self.verbose_ > 0: |
| logger.debug(f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}") |
| return dims[0] |
| else: |
| return None |
| if all([d == dims[0] for d in dims]): |
| return dims[0] |
| merged = [self.suggested_merge_.get(d, d) for d in dims] |
| if all([d == merged[0] for d in merged]): |
| assert merged[0] in self.symbolic_dims_ |
| return merged[0] |
| else: |
| return None |
|
|
| |
| def _broadcast_shapes(self, shape1, shape2): |
| new_shape = [] |
| rank1 = len(shape1) |
| rank2 = len(shape2) |
| new_rank = max(rank1, rank2) |
| for i in range(new_rank): |
| dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1 |
| dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1 |
| if dim1 == 1 or dim1 == dim2: |
| new_dim = dim2 |
| elif dim2 == 1: |
| new_dim = dim1 |
| else: |
| new_dim = self._merge_symbols([dim1, dim2]) |
| if not new_dim: |
| |
| |
| |
| if self.auto_merge_: |
| self._add_suggested_merge([dim1, dim2], apply=True) |
| else: |
| logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2)) |
| new_shape = [new_dim, *new_shape] |
| return new_shape |
|
|
| def _get_shape(self, node, idx): |
| name = node.input[idx] |
| if name in self.known_vi_: |
| vi = self.known_vi_[name] |
| return get_shape_from_value_info(vi) |
| else: |
| assert name in self.initializers_ |
| return list(self.initializers_[name].dims) |
|
|
| def _try_get_shape(self, node, idx): |
| if idx > len(node.input) - 1: |
| return None |
| name = node.input[idx] |
| if name in self.known_vi_: |
| vi = self.known_vi_[name] |
| return get_shape_from_value_info(vi) |
| if name in self.initializers_: |
| return list(self.initializers_[name].dims) |
| return None |
|
|
| def _get_shape_rank(self, node, idx): |
| return len(self._get_shape(node, idx)) |
|
|
| def _get_sympy_shape(self, node, idx): |
| sympy_shape = [] |
| for d in self._get_shape(node, idx): |
| if type(d) == str: |
| sympy_shape.append( |
| self.symbolic_dims_[d] |
| if d in self.symbolic_dims_ |
| else sympy.Symbol(d, integer=True, nonnegative=True) |
| ) |
| else: |
| assert None is not d |
| sympy_shape.append(d) |
| return sympy_shape |
|
|
| def _get_value(self, node, idx): |
| name = node.input[idx] |
| assert name in self.sympy_data_ or name in self.initializers_ |
| return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name]) |
|
|
| def _try_get_value(self, node, idx): |
| if idx >= len(node.input): |
| return None |
| name = node.input[idx] |
| if name in self.sympy_data_ or name in self.initializers_: |
| return self._get_value(node, idx) |
| return None |
|
|
| def _update_computed_dims(self, new_sympy_shape): |
| for i, new_dim in enumerate(new_sympy_shape): |
| if not is_literal(new_dim) and type(new_dim) != str: |
| str_dim = str(new_dim) |
| if str_dim in self.suggested_merge_: |
| if is_literal(self.suggested_merge_[str_dim]): |
| continue |
| new_sympy_shape[i] = self.symbolic_dims_[self.suggested_merge_[str_dim]] |
| else: |
| |
| if str(new_dim) not in self.symbolic_dims_: |
| self.symbolic_dims_[str(new_dim)] = new_dim |
|
|
| def _onnx_infer_single_node(self, node): |
| |
| skip_infer = node.op_type in [ |
| "If", |
| "Loop", |
| "Scan", |
| "SplitToSequence", |
| "ZipMap", |
| "Attention", |
| "BiasGelu", |
| "EmbedLayerNormalization", |
| "FastGelu", |
| "Gelu", |
| "GemmFastGelu", |
| "LayerNormalization", |
| "LongformerAttention", |
| "DequantizeLinear", |
| "QuantizeLinear", |
| "RelativePositionBias", |
| "RemovePadding", |
| "RestorePadding", |
| "SimplifiedLayerNormalization", |
| "SkipLayerNormalization", |
| "SkipSimplifiedLayerNormalization", |
| "PackedAttention", |
| "PagedAttention", |
| "PythonOp", |
| "MultiHeadAttention", |
| "GroupNorm", |
| "GroupQueryAttention", |
| "SkipGroupNorm", |
| "BiasSplitGelu", |
| "BiasAdd", |
| "NhwcConv", |
| "QuickGelu", |
| "RotaryEmbedding", |
| ] |
|
|
| if not skip_infer: |
| |
| |
| |
| |
| |
| initializers = [] |
| if (get_opset(self.out_mp_) >= 9) and node.op_type in ["Unsqueeze"]: |
| initializers = [ |
| self.initializers_[name] |
| for name in node.input |
| if (name in self.initializers_ and name not in self.graph_inputs_) |
| ] |
|
|
| |
| tmp_graph = helper.make_graph( |
| [node], |
| "tmp", |
| [self.known_vi_[i] for i in node.input if i], |
| [make_named_value_info(i) for i in node.output], |
| initializers, |
| ) |
|
|
| self.tmp_mp_.graph.CopyFrom(tmp_graph) |
|
|
| self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_) |
|
|
| for i_o in range(len(node.output)): |
| o = node.output[i_o] |
| if o: |
| vi = self.out_mp_.graph.value_info.add() |
| if not skip_infer: |
| vi.CopyFrom(self.tmp_mp_.graph.output[i_o]) |
| else: |
| vi.name = o |
| self.known_vi_[o] = vi |
|
|
| def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True): |
| if self.verbose_ > 2: |
| logger.debug(f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}") |
| |
| |
| |
| |
| subgraph_inputs = {i.name for i in list(subgraph.initializer) + list(subgraph.input)} |
| subgraph_implicit_input = {name for name in self.known_vi_ if name not in subgraph_inputs} |
| tmp_graph = helper.make_graph( |
| list(subgraph.node), |
| "tmp", |
| list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input], |
| [make_named_value_info(i.name) for i in subgraph.output], |
| ) |
| tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input]) |
| tmp_graph.initializer.extend(subgraph.initializer) |
| self.tmp_mp_.graph.CopyFrom(tmp_graph) |
|
|
| symbolic_shape_inference = SymbolicShapeInference( |
| self.int_max_, |
| self.auto_merge_, |
| self.guess_output_rank_, |
| self.verbose_, |
| prefix=self.prefix_ + "_" + str(self.subgraph_id_), |
| ) |
| if inc_subgraph_id: |
| self.subgraph_id_ += 1 |
|
|
| symbolic_shape_inference._preprocess(self.tmp_mp_) |
| symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy() |
| while symbolic_shape_inference.run_: |
| symbolic_shape_inference._infer_impl(self.sympy_data_.copy()) |
| symbolic_shape_inference._update_output_from_vi() |
| if use_node_input: |
| |
| subgraph.ClearField("input") |
| subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)]) |
| subgraph.ClearField("output") |
| subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output) |
| subgraph.ClearField("value_info") |
| subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info) |
| subgraph.ClearField("node") |
| subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node) |
| |
| subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output] |
| subgraph_new_symbolic_dims = { |
| d for s in subgraph_shapes if s for d in s if type(d) == str and d not in self.symbolic_dims_ |
| } |
| new_dims = {} |
| for d in subgraph_new_symbolic_dims: |
| assert d in symbolic_shape_inference.symbolic_dims_ |
| new_dims[d] = symbolic_shape_inference.symbolic_dims_[d] |
| self.symbolic_dims_.update(new_dims) |
| return symbolic_shape_inference |
|
|
| def _get_int_or_float_values(self, node, broadcast=False, allow_float_values=False): |
| def int_or_float(value, allow_float_values): |
| |
| if allow_float_values and value % 1 != 0: |
| return value |
| return int(value) |
|
|
| values = [self._try_get_value(node, i) for i in range(len(node.input))] |
| if all([v is not None for v in values]): |
| |
| for i, v in enumerate(values): |
| if type(v) != np.ndarray: |
| continue |
| if len(v.shape) > 1: |
| new_v = None |
| elif len(v.shape) == 0: |
| new_v = int_or_float(v.item(), allow_float_values) |
| else: |
| assert len(v.shape) == 1 |
| new_v = [int_or_float(vv, allow_float_values) for vv in v] |
| values[i] = new_v |
| values_len = [len(v) if isinstance(v, list) else 0 for v in values] |
| max_len = max(values_len) |
| if max_len >= 1 and broadcast: |
| |
| for i, v in enumerate(values): |
| if v is None: |
| continue |
| if isinstance(v, list): |
| if len(v) < max_len: |
| values[i] = v * max_len |
| else: |
| assert len(v) == max_len |
| else: |
| values[i] = [v] * max_len |
| return values |
|
|
| def _compute_on_sympy_data(self, node, op_func): |
| assert len(node.output) == 1 |
|
|
| |
| |
| |
| if node.op_type in ["Mul", "Div"]: |
| values = self._get_int_or_float_values(node, broadcast=True, allow_float_values=True) |
| else: |
| values = self._get_int_or_float_values(node, broadcast=True) |
|
|
| if all([v is not None for v in values]): |
| is_list = [isinstance(v, list) for v in values] |
| as_list = any(is_list) |
| if as_list: |
| self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)] |
| else: |
| self.sympy_data_[node.output[0]] = op_func(values) |
|
|
| def _pass_on_sympy_data(self, node): |
| assert len(node.input) == 1 or node.op_type in [ |
| "Reshape", |
| "Unsqueeze", |
| "Squeeze", |
| ] |
| self._compute_on_sympy_data(node, lambda x: x[0]) |
|
|
| def _pass_on_shape_and_type(self, node): |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type), |
| self._get_shape(node, 0), |
| ) |
| ) |
|
|
| def _new_symbolic_dim(self, prefix, dim): |
| new_dim = f"{prefix}_d{dim}" |
| if new_dim in self.suggested_merge_: |
| v = self.suggested_merge_[new_dim] |
| new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v |
| else: |
| new_symbolic_dim = sympy.Symbol(new_dim, integer=True, nonnegative=True) |
| self.symbolic_dims_[new_dim] = new_symbolic_dim |
| return new_symbolic_dim |
|
|
| def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): |
| return self._new_symbolic_dim( |
| "{}{}_{}_o{}_".format( |
| node.op_type, |
| self.prefix_, |
| list(self.out_mp_.graph.node).index(node), |
| out_idx, |
| ), |
| dim, |
| ) |
|
|
| def _new_symbolic_shape(self, rank, node, out_idx=0): |
| return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)] |
|
|
| def _compute_conv_pool_shape(self, node, channels_last=False): |
| sympy_shape = self._get_sympy_shape(node, 0) |
| if len(node.input) > 1: |
| W_shape = self._get_sympy_shape(node, 1) |
| rank = len(W_shape) - 2 |
| kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:] |
| sympy_shape[3 if channels_last else 1] = W_shape[0] |
| else: |
| W_shape = None |
| kernel_shape = get_attribute(node, "kernel_shape") |
| rank = len(kernel_shape) |
|
|
| assert len(sympy_shape) == rank + 2 |
|
|
| |
| spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:] |
| is_symbolic_dims = [not is_literal(i) for i in spatial_shape] |
|
|
| if not any(is_symbolic_dims): |
| shape = get_shape_from_value_info(self.known_vi_[node.output[0]]) |
| if len(shape) > 0: |
| assert len(sympy_shape) == len(shape) |
| if channels_last: |
| sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]] |
| else: |
| sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] |
| return sympy_shape |
|
|
| dilations = get_attribute(node, "dilations", [1] * rank) |
| strides = get_attribute(node, "strides", [1] * rank) |
| effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)] |
| pads = get_attribute(node, "pads") |
| if pads is None: |
| pads = [0] * (2 * rank) |
| auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8") |
| if auto_pad != "VALID" and auto_pad != "NOTSET": |
| try: |
| residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)] |
| total_pads = [ |
| max(0, (k - s) if r == 0 else (k - r)) |
| for k, s, r in zip(effective_kernel_shape, strides, residual) |
| ] |
| except TypeError: |
| total_pads = [ |
| max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides) |
| ] |
| elif auto_pad == "VALID": |
| total_pads = [] |
| else: |
| total_pads = [0] * rank |
| else: |
| assert len(pads) == 2 * rank |
| total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])] |
|
|
| ceil_mode = get_attribute(node, "ceil_mode", 0) |
| for i in range(rank): |
| effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)] |
| if len(total_pads) > 0: |
| effective_input_size = effective_input_size + total_pads[i] |
| if ceil_mode: |
| strided_kernel_positions = sympy.ceiling( |
| (effective_input_size - effective_kernel_shape[i]) / strides[i] |
| ) |
| else: |
| strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i] |
| sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1 |
| return sympy_shape |
|
|
| def _check_merged_dims(self, dims, allow_broadcast=True): |
| if allow_broadcast: |
| dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)] |
| if not all([d == dims[0] for d in dims]): |
| self._add_suggested_merge(dims, apply=True) |
|
|
| def _compute_matmul_shape(self, node, output_dtype=None): |
| lhs_shape = self._get_shape(node, 0) |
| rhs_shape = self._get_shape(node, 1) |
| lhs_rank = len(lhs_shape) |
| rhs_rank = len(rhs_shape) |
| lhs_reduce_dim = 0 |
| rhs_reduce_dim = 0 |
| assert lhs_rank > 0 and rhs_rank > 0 |
| if lhs_rank == 1 and rhs_rank == 1: |
| new_shape = [] |
| elif lhs_rank == 1: |
| rhs_reduce_dim = -2 |
| new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]] |
| elif rhs_rank == 1: |
| lhs_reduce_dim = -1 |
| new_shape = lhs_shape[:lhs_reduce_dim] |
| else: |
| lhs_reduce_dim = -1 |
| rhs_reduce_dim = -2 |
| new_shape = [*self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2]), lhs_shape[-2], rhs_shape[-1]] |
| |
| self._check_merged_dims( |
| [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]], |
| allow_broadcast=False, |
| ) |
| if output_dtype is None: |
| |
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) |
|
|
| def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): |
| """ |
| update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches |
| """ |
| dst_tensor_type = ( |
| dst_type.sequence_type.elem_type.tensor_type if is_sequence(dst_type) else dst_type.tensor_type |
| ) |
| src_tensor_type = ( |
| src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type |
| ) |
| if dst_tensor_type.elem_type != src_tensor_type.elem_type: |
| node_id = node.name if node.name else node.op_type |
| raise ValueError( |
| f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: " |
| f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs " |
| f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}" |
| ) |
| if dst_tensor_type.HasField("shape"): |
| for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): |
| if ds[0] != ds[1]: |
| |
| |
| new_dim = onnx.TensorShapeProto.Dimension() |
| if not is_sequence(dst_type): |
| new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, di)) |
| dst_tensor_type.shape.dim[di].CopyFrom(new_dim) |
| else: |
| dst_tensor_type.CopyFrom(src_tensor_type) |
|
|
| def _infer_ArrayFeatureExtractor(self, node): |
| data_shape = self._get_shape(node, 0) |
| indices_shape = self._get_shape(node, 1) |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| data_shape[:-1] + indices_shape, |
| ) |
| ) |
|
|
| def _infer_symbolic_compute_ops(self, node): |
| funcs = { |
| "Add": lambda l: l[0] + l[1], |
| "Div": lambda l: ( |
| int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1] |
| ), |
| "Equal": lambda l: l[0] == l[1], |
| "Floor": lambda l: sympy.floor(l[0]), |
| "Max": lambda l: ( |
| l[1] |
| if is_literal(l[0]) and int(l[0]) < -self.int_max_ |
| else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])) |
| ), |
| "Min": lambda l: ( |
| l[1] |
| if is_literal(l[0]) and int(l[0]) > self.int_max_ |
| else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])) |
| ), |
| "Mul": lambda l: int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1], |
| "Sub": lambda l: l[0] - l[1], |
| "Where": lambda l: l[1] if l[0] else l[2], |
| "Neg": lambda l: -l[0], |
| } |
| assert node.op_type in funcs |
| self._compute_on_sympy_data(node, funcs[node.op_type]) |
|
|
| def _infer_Cast(self, node): |
| self._pass_on_sympy_data(node) |
|
|
| def _infer_CategoryMapper(self, node): |
| input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| if input_type == onnx.TensorProto.STRING: |
| output_type = onnx.TensorProto.INT64 |
| else: |
| output_type = onnx.TensorProto.STRING |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_type, self._get_shape(node, 0))) |
|
|
| def _infer_Compress(self, node): |
| input_shape = self._get_shape(node, 0) |
| |
| compress_len = str(self._new_symbolic_dim_from_output(node)) |
| axis = get_attribute(node, "axis") |
| if axis is None: |
| |
| output_shape = [compress_len] |
| else: |
| output_shape = input_shape |
| output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| output_shape, |
| ) |
| ) |
|
|
| def _infer_Concat(self, node): |
| if any([i in self.sympy_data_ or i in self.initializers_ for i in node.input]): |
| values = self._get_int_or_float_values(node) |
| if all([v is not None for v in values]): |
| assert get_attribute(node, "axis") == 0 |
| self.sympy_data_[node.output[0]] = [] |
| for i in range(len(node.input)): |
| value = values[i] |
| if isinstance(value, list): |
| self.sympy_data_[node.output[0]].extend(value) |
| else: |
| self.sympy_data_[node.output[0]].append(value) |
|
|
| sympy_shape = self._get_sympy_shape(node, 0) |
| axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape)) |
| for i_idx in range(1, len(node.input)): |
| input_shape = self._get_sympy_shape(node, i_idx) |
| if input_shape: |
| sympy_shape[axis] = sympy_shape[axis] + input_shape[axis] |
| self._update_computed_dims(sympy_shape) |
| |
| for d in range(len(sympy_shape)): |
| if d == axis: |
| continue |
| dims = [self._get_shape(node, i_idx)[d] for i_idx in range(len(node.input)) if self._get_shape(node, i_idx)] |
| if all([d == dims[0] for d in dims]): |
| continue |
| merged = self._merge_symbols(dims) |
| if type(merged) == str: |
| sympy_shape[d] = self.symbolic_dims_[merged] if merged else None |
| else: |
| sympy_shape[d] = merged |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(sympy_shape), |
| ) |
| ) |
|
|
| def _infer_ConcatFromSequence(self, node): |
| seq_shape = self._get_shape(node, 0) |
| new_axis = 1 if get_attribute(node, "new_axis") else 0 |
| axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis) |
| concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis)) |
| new_shape = seq_shape |
| if new_axis: |
| new_shape = seq_shape[:axis] + [concat_dim] + seq_shape[axis:] |
| else: |
| new_shape[axis] = concat_dim |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type, |
| new_shape, |
| ) |
| ) |
|
|
| def _infer_Constant(self, node): |
| t = get_attribute(node, "value") |
| self.sympy_data_[node.output[0]] = numpy_helper.to_array(t) |
|
|
| def _infer_ConstantOfShape(self, node): |
| sympy_shape = self._get_int_or_float_values(node)[0] |
| vi = self.known_vi_[node.output[0]] |
| if sympy_shape is not None: |
| if type(sympy_shape) != list: |
| sympy_shape = [sympy_shape] |
| self._update_computed_dims(sympy_shape) |
| |
| if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all([is_literal(x) for x in sympy_shape]): |
| self.sympy_data_[node.output[0]] = np.ones( |
| [int(x) for x in sympy_shape], dtype=np.int64 |
| ) * numpy_helper.to_array(get_attribute(node, "value", 0)) |
| else: |
| |
| |
| sympy_shape = self._new_symbolic_shape(self._get_shape(node, 0)[0], node) |
|
|
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| vi.type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(sympy_shape), |
| ) |
| ) |
|
|
| def _infer_Conv(self, node): |
| sympy_shape = self._compute_conv_pool_shape(node) |
| self._update_computed_dims(sympy_shape) |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| vi.type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(sympy_shape), |
| ) |
| ) |
|
|
| def _infer_NhwcConv(self, node): |
| sympy_shape = self._compute_conv_pool_shape(node, channels_last=True) |
| self._update_computed_dims(sympy_shape) |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(sympy_shape), |
| ) |
| ) |
|
|
| def _infer_DequantizeLinear(self, node): |
| |
| output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type |
|
|
| |
| output_shape = self._get_shape(node, 0) |
|
|
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) |
|
|
| def _infer_QuantizeLinear(self, node): |
| |
| |
| output_dtype = onnx.TensorProto.UINT8 |
| if len(node.input) > 2 and node.input[2]: |
| output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type |
|
|
| |
| output_shape = self._get_shape(node, 0) |
|
|
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) |
|
|
| def _infer_Einsum(self, node): |
| |
| equation = get_attribute(node, "equation") |
| equation = equation.replace(b" ", b"") |
| mid_index = equation.find(b"->") |
| left_equation = equation[:mid_index] if mid_index != -1 else equation |
|
|
| num_operands = 0 |
| num_ellipsis = 0 |
| num_ellipsis_indices = 0 |
|
|
| letter_to_dim = {} |
|
|
| terms = left_equation.split(b",") |
| for term in terms: |
| ellipsis_index = term.find(b"...") |
| shape = self._get_shape(node, num_operands) |
| rank = len(shape) |
| if ellipsis_index != -1: |
| if num_ellipsis == 0: |
| num_ellipsis_indices = rank - len(term) + 3 |
| num_ellipsis = num_ellipsis + 1 |
| for i in range(1, rank + 1): |
| letter = term[-i] |
| if letter != 46: |
| dim = shape[-i] |
| if letter not in letter_to_dim: |
| letter_to_dim[letter] = dim |
| elif type(dim) != sympy.Symbol: |
| letter_to_dim[letter] = dim |
| num_operands = num_operands + 1 |
|
|
| new_sympy_shape = [] |
| from collections import OrderedDict |
|
|
| num_letter_occurrences = OrderedDict() |
| if mid_index != -1: |
| right_equation = equation[mid_index + 2 :] |
| right_ellipsis_index = right_equation.find(b"...") |
| if right_ellipsis_index != -1: |
| for i in range(num_ellipsis_indices): |
| new_sympy_shape.append(shape[i]) |
| for c in right_equation: |
| if c != 46: |
| new_sympy_shape.append(letter_to_dim[c]) |
| else: |
| for i in range(num_ellipsis_indices): |
| new_sympy_shape.append(shape[i]) |
| for c in left_equation: |
| if c != 44 and c != 46: |
| if c in num_letter_occurrences: |
| num_letter_occurrences[c] = num_letter_occurrences[c] + 1 |
| else: |
| num_letter_occurrences[c] = 1 |
| for key, value in num_letter_occurrences.items(): |
| if value == 1: |
| new_sympy_shape.append(letter_to_dim[key]) |
|
|
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape)) |
|
|
| def _infer_Expand(self, node): |
| expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True) |
| if expand_to_shape is not None: |
| |
| self._update_computed_dims(expand_to_shape) |
| shape = self._get_shape(node, 0) |
| new_shape = self._broadcast_shapes(shape, get_shape_from_sympy_shape(expand_to_shape)) |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| new_shape, |
| ) |
| ) |
|
|
| def _infer_Gather(self, node): |
| data_shape = self._get_shape(node, 0) |
| axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape)) |
| indices_shape = self._get_shape(node, 1) |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| data_shape[:axis] + indices_shape + data_shape[axis + 1 :], |
| ) |
| ) |
| |
| if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and get_attribute(node, "axis", 0) == 0: |
| idx = self._try_get_value(node, 1) |
| if idx is not None: |
| data = self.sympy_data_[node.input[0]] |
| if type(data) == list: |
| if type(idx) == np.ndarray and len(idx.shape) == 1: |
| self.sympy_data_[node.output[0]] = [data[int(i)] for i in idx] |
| else: |
| self.sympy_data_[node.output[0]] = data[int(idx)] |
| else: |
| assert idx == 0 or idx == -1 |
| self.sympy_data_[node.output[0]] = data |
|
|
| def _infer_GatherElements(self, node): |
| indices_shape = self._get_shape(node, 1) |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| indices_shape, |
| ) |
| ) |
|
|
| def _infer_GatherND(self, node): |
| data_shape = self._get_shape(node, 0) |
| data_rank = len(data_shape) |
| indices_shape = self._get_shape(node, 1) |
| len(indices_shape) |
| last_index_dimension = indices_shape[-1] |
| assert is_literal(last_index_dimension) and last_index_dimension <= data_rank |
| new_shape = indices_shape[:-1] + data_shape[last_index_dimension:] |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| new_shape, |
| ) |
| ) |
|
|
| def _infer_If(self, node): |
| |
| subgraphs = [ |
| get_attribute(node, "then_branch"), |
| get_attribute(node, "else_branch"), |
| ] |
| cond = self._try_get_value(node, 0) |
| if cond is not None: |
| if as_scalar(cond) > 0: |
| subgraphs[1].CopyFrom(subgraphs[0]) |
| else: |
| subgraphs[0].CopyFrom(subgraphs[1]) |
|
|
| for i_sub, subgraph in enumerate(subgraphs): |
| subgraph_infer = self._onnx_infer_subgraph(node, subgraph, use_node_input=False) |
| for i_out in range(len(node.output)): |
| vi = self.known_vi_[node.output[i_out]] |
| if i_sub == 0: |
| vi.CopyFrom(subgraph.output[i_out]) |
| vi.name = node.output[i_out] |
| else: |
| self._fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type) |
|
|
| |
| if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else 1): |
| if subgraph.output[i_out].name in subgraph_infer.sympy_data_: |
| self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name] |
|
|
| def _infer_Loop(self, node): |
| subgraph = get_attribute(node, "body") |
| assert len(subgraph.input) == len(node.input) |
| num_loop_carried = len(node.input) - 2 |
| |
| |
| for i, si in enumerate(subgraph.input): |
| si_name = si.name |
| si.CopyFrom(self.known_vi_[node.input[i]]) |
| si.name = si_name |
|
|
| self._onnx_infer_subgraph(node, subgraph) |
|
|
| |
| |
| |
| need_second_infer = False |
| for i_out in range(1, num_loop_carried + 1): |
| so = subgraph.output[i_out] |
| so_shape = get_shape_from_value_info(so) |
| if is_sequence(so.type): |
| if so_shape and None in so_shape: |
| |
| |
| |
| subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom(so.type.sequence_type.elem_type) |
| need_second_infer = True |
| else: |
| si = subgraph.input[i_out + 1] |
| si_shape = get_shape_from_value_info(si) |
| for di, dims in enumerate(zip(si_shape, so_shape)): |
| if dims[0] != dims[1]: |
| new_dim = onnx.TensorShapeProto.Dimension() |
| new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, i_out, di)) |
| si.type.tensor_type.shape.dim[di].CopyFrom(new_dim) |
| so.type.tensor_type.shape.dim[di].CopyFrom(new_dim) |
| need_second_infer = True |
|
|
| if need_second_infer: |
| if self.verbose_ > 2: |
| logger.debug( |
| "Rerun Loop: {}({}...), because of sequence in loop carried variables".format( |
| node.name, node.output[0] |
| ) |
| ) |
| self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False) |
|
|
| |
| loop_iter_dim = str(self._new_symbolic_dim_from_output(node)) |
| for i in range(len(node.output)): |
| vi = self.known_vi_[node.output[i]] |
| vi.CopyFrom(subgraph.output[i + 1]) |
| if i >= num_loop_carried: |
| assert not is_sequence(vi.type) |
| subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim |
| vi.type.tensor_type.shape.ClearField("dim") |
| vi_dim = vi.type.tensor_type.shape.dim |
| vi_dim.add().dim_param = loop_iter_dim |
| vi_dim.extend(list(subgraph_vi_dim)) |
| vi.name = node.output[i] |
|
|
| def _infer_MatMul(self, node): |
| self._compute_matmul_shape(node) |
|
|
| def _infer_MatMulInteger(self, node): |
| self._compute_matmul_shape(node, onnx.TensorProto.INT32) |
|
|
| def _infer_NonMaxSuppression(self, node): |
| selected = str(self._new_symbolic_dim_from_output(node)) |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [selected, 3])) |
|
|
| def _infer_NonZero(self, node): |
| input_rank = self._get_shape_rank(node, 0) |
| |
| nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1)) |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len])) |
|
|
| def _infer_OneHot(self, node): |
| sympy_shape = self._get_sympy_shape(node, 0) |
| depth = self._try_get_value(node, 1) |
| axis = get_attribute(node, "axis", -1) |
| axis = handle_negative_axis(axis, len(sympy_shape) + 1) |
| new_shape = get_shape_from_sympy_shape( |
| sympy_shape[:axis] |
| + [self._new_symbolic_dim_from_output(node) if not is_literal(depth) else depth] |
| + sympy_shape[axis:] |
| ) |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[2]].type.tensor_type.elem_type, |
| new_shape, |
| ) |
| ) |
|
|
| def _infer_Pad(self, node): |
| if get_opset(self.out_mp_) <= 10: |
| pads = get_attribute(node, "pads") |
| else: |
| pads = self._try_get_value(node, 1) |
|
|
| sympy_shape = self._get_sympy_shape(node, 0) |
| rank = len(sympy_shape) |
|
|
| if pads is not None: |
| assert len(pads) == 2 * rank |
| new_sympy_shape = [ |
| d + pad_up + pad_down for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:]) |
| ] |
| self._update_computed_dims(new_sympy_shape) |
| else: |
| |
| new_sympy_shape = self._new_symbolic_shape(rank, node) |
| output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
|
|
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape)) |
| ) |
|
|
| def _infer_Pool(self, node): |
| sympy_shape = self._compute_conv_pool_shape(node) |
| self._update_computed_dims(sympy_shape) |
| for o in node.output: |
| if not o: |
| continue |
| vi = self.known_vi_[o] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| o, |
| vi.type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(sympy_shape), |
| ) |
| ) |
|
|
| def _infer_aten_bitwise_or(self, node): |
| shape0 = self._get_shape(node, 0) |
| shape1 = self._get_shape(node, 1) |
| new_shape = self._broadcast_shapes(shape0, shape1) |
| t0 = self.known_vi_[node.input[0]] |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type, new_shape)) |
|
|
| def _infer_aten_diagonal(self, node): |
| sympy_shape = self._get_sympy_shape(node, 0) |
| rank = len(sympy_shape) |
| offset = self._try_get_value(node, 1) |
| dim1 = self._try_get_value(node, 2) |
| dim2 = self._try_get_value(node, 3) |
|
|
| assert offset is not None and dim1 is not None and dim2 is not None |
| dim1 = handle_negative_axis(dim1, rank) |
| dim2 = handle_negative_axis(dim2, rank) |
|
|
| new_shape = [] |
| for dim, val in enumerate(sympy_shape): |
| if dim not in [dim1, dim2]: |
| new_shape.append(val) |
|
|
| shape1 = sympy_shape[dim1] |
| shape2 = sympy_shape[dim2] |
| if offset >= 0: |
| diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset)) |
| else: |
| diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2)) |
| new_shape.append(diag_shape) |
|
|
| if node.output[0]: |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(new_shape), |
| ) |
| ) |
|
|
| def _infer_aten_multinomial(self, node): |
| sympy_shape = self._get_sympy_shape(node, 0) |
| rank = len(sympy_shape) |
| assert rank in [1, 2] |
| num_samples = self._try_get_value(node, 1) |
| di = rank - 1 |
| last_dim = num_samples if num_samples else str(self._new_symbolic_dim_from_output(node, 0, di)) |
| output_shape = sympy_shape[:-1] + [last_dim] |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| onnx.TensorProto.INT64, |
| get_shape_from_sympy_shape(output_shape), |
| ) |
| ) |
|
|
| def _infer_aten_pool2d(self, node): |
| sympy_shape = self._get_sympy_shape(node, 0) |
| assert len(sympy_shape) == 4 |
| sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3]] |
| self._update_computed_dims(sympy_shape) |
| for i, o in enumerate(node.output): |
| if not o: |
| continue |
| vi = self.known_vi_[o] |
| elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi.CopyFrom(helper.make_tensor_value_info(o, elem_type, get_shape_from_sympy_shape(sympy_shape))) |
|
|
| def _infer_aten_minmax(self, node): |
| vi = self.known_vi_[node.output[0]] |
| if len(node.input) == 1: |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, [] |
| ) |
| ) |
| else: |
| assert len(node.input) == 3 |
| keepdim = self._try_get_value(node, 2) |
| assert keepdim is not None |
| dim = self._try_get_value(node, 1) |
| if dim is None: |
| rank = self._get_shape_rank(node, 0) |
| output_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) |
| else: |
| shape = self._get_sympy_shape(node, 0) |
| dim = handle_negative_axis(dim, len(shape)) |
| output_shape = shape[:dim] |
| if keepdim: |
| output_shape += [1] |
| output_shape += shape[dim + 1 :] |
|
|
| output_shape = get_shape_from_sympy_shape(output_shape) |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, output_shape |
| ) |
| ) |
| vi1 = self.known_vi_[node.output[1]] |
| vi1.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT64, output_shape)) |
|
|
| def _infer_aten_unfold(self, node): |
| sympy_shape = self._get_sympy_shape(node, 0) |
| dimension = self._try_get_value(node, 1) |
| size = self._try_get_value(node, 2) |
| step = self._try_get_value(node, 3) |
| if dimension is not None and size is not None and step is not None: |
| assert dimension < len(sympy_shape) |
| sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1 |
| sympy_shape.append(size) |
| else: |
| rank = len(sympy_shape) |
| sympy_shape = self._new_symbolic_shape(rank + 1, node) |
| self._update_computed_dims(sympy_shape) |
| if node.output[0]: |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(sympy_shape), |
| ) |
| ) |
|
|
| def _infer_aten_argmax(self, node): |
| new_shape = None |
| if not node.input[1]: |
| |
| new_shape = [] |
| else: |
| dim = self._try_get_value(node, 1) |
| keepdim = self._try_get_value(node, 2) |
| if keepdim is not None: |
| sympy_shape = self._get_sympy_shape(node, 0) |
| if dim is not None: |
| dim = handle_negative_axis(dim, len(sympy_shape)) |
| if keepdim: |
| sympy_shape[dim] = 1 |
| else: |
| del sympy_shape[dim] |
| else: |
| rank = len(sympy_shape) |
| sympy_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) |
| self._update_computed_dims(sympy_shape) |
| new_shape = get_shape_from_sympy_shape(sympy_shape) |
| if node.output[0] and new_shape is not None: |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape)) |
|
|
| def _infer_aten_group_norm(self, node): |
| self._propagate_shape_and_type(node) |
| input_shape = self._get_shape(node, 0) |
| N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None |
| group = self._try_get_value(node, 6) |
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| for i in [1, 2]: |
| if node.output[i]: |
| vi = self.known_vi_[node.output[i]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[i], |
| output_dtype, |
| [ |
| N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0)), |
| ( |
| as_scalar(group) |
| if group is not None |
| else str(self._new_symbolic_dim_from_output(node, i, 1)) |
| ), |
| ], |
| ) |
| ) |
|
|
| def _infer_aten_upsample(self, node): |
| new_shape = None |
| input_shape = self._get_shape(node, 0) |
| if input_shape is not None: |
| new_shape = input_shape[:2] |
| output_size = self._try_get_value(node, 1) |
| if output_size is not None: |
| new_shape += [dim_size.item() if type(dim_size) == np.int64 else dim_size for dim_size in output_size] |
| else: |
| rank = len(input_shape) |
| new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)] |
| if node.output[0] and new_shape is not None: |
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) |
|
|
| def _infer_BatchNormalization(self, node): |
| self._propagate_shape_and_type(node) |
|
|
| |
| for i in [1, 2, 3, 4]: |
| if i < len(node.output) and node.output[i]: |
| |
| self._propagate_shape_and_type(node, input_index=1, output_index=i) |
|
|
| def _infer_Range(self, node): |
| vi = self.known_vi_[node.output[0]] |
| input_data = self._get_int_or_float_values(node) |
| if all([i is not None for i in input_data]): |
| start = as_scalar(input_data[0]) |
| limit = as_scalar(input_data[1]) |
| delta = as_scalar(input_data[2]) |
| new_sympy_shape = [sympy.Max(sympy.ceiling((limit - start) / delta), 0)] |
| else: |
| new_sympy_shape = [self._new_symbolic_dim_from_output(node)] |
| self._update_computed_dims(new_sympy_shape) |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(new_sympy_shape), |
| ) |
| ) |
|
|
| def _infer_ReduceSum(self, node): |
| keep_dims = get_attribute(node, "keepdims", 1) |
| if get_opset(self.out_mp_) >= 13 and len(node.input) > 1: |
| |
| axes = self._try_get_value(node, 1) |
| vi = self.known_vi_[node.output[0]] |
| if axes is None: |
| assert keep_dims |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(self._new_symbolic_shape(self._get_shape_rank(node, 0), node)), |
| ) |
| ) |
| else: |
| shape = self._get_shape(node, 0) |
| output_shape = [] |
| axes = [handle_negative_axis(a, len(shape)) for a in axes] |
| for i, d in enumerate(shape): |
| if i in axes: |
| if keep_dims: |
| output_shape.append(1) |
| else: |
| output_shape.append(d) |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| output_shape, |
| ) |
| ) |
|
|
| def _infer_ReduceProd(self, node): |
| axes = get_attribute(node, "axes") |
| keep_dims = get_attribute(node, "keepdims", 1) |
| if keep_dims == 0 and axes == [0]: |
| data = self._get_int_or_float_values(node)[0] |
| if data is not None: |
| self.sympy_data_[node.output[0]] = sympy_reduce_product(data) |
|
|
| def _infer_RelativePositionBias(self, node): |
| seq_len = self._try_get_value(node, 1) |
| real_seq_len = self._try_get_value(node, 2) |
| if seq_len is None or real_seq_len is None: |
| return |
| num_heads = self._get_sympy_shape(node, 0)[1] |
|
|
| new_shape = [1, num_heads, str(seq_len), str(real_seq_len)] |
|
|
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) |
|
|
| def _infer_Reshape(self, node): |
| shape_value = self._try_get_value(node, 1) |
| vi = self.known_vi_[node.output[0]] |
| if shape_value is None: |
| shape_shape = self._get_shape(node, 1) |
| assert len(shape_shape) == 1 |
| shape_rank = shape_shape[0] |
| assert is_literal(shape_rank) |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| vi.type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)), |
| ) |
| ) |
| else: |
| input_sympy_shape = self._get_sympy_shape(node, 0) |
| total = 1 |
| for d in input_sympy_shape: |
| total = total * d |
| new_sympy_shape = [] |
| deferred_dim_idx = -1 |
| non_deferred_size = 1 |
| for i, d in enumerate(shape_value): |
| if type(d) == sympy.Symbol: |
| new_sympy_shape.append(d) |
| elif d == 0: |
| new_sympy_shape.append(input_sympy_shape[i]) |
| non_deferred_size = non_deferred_size * input_sympy_shape[i] |
| else: |
| new_sympy_shape.append(d) |
| if d == -1: |
| deferred_dim_idx = i |
| elif d != 0: |
| non_deferred_size = non_deferred_size * d |
|
|
| assert new_sympy_shape.count(-1) < 2 |
| if -1 in new_sympy_shape: |
| new_dim = total // non_deferred_size |
| new_sympy_shape[deferred_dim_idx] = new_dim |
|
|
| self._update_computed_dims(new_sympy_shape) |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| vi.type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(new_sympy_shape), |
| ) |
| ) |
|
|
| self._pass_on_sympy_data(node) |
|
|
| def _infer_Resize(self, node): |
| vi = self.known_vi_[node.output[0]] |
| input_sympy_shape = self._get_sympy_shape(node, 0) |
| if get_opset(self.out_mp_) <= 10: |
| scales = self._try_get_value(node, 1) |
| if scales is not None: |
| new_sympy_shape = [sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales)] |
| self._update_computed_dims(new_sympy_shape) |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(new_sympy_shape), |
| ) |
| ) |
| else: |
| roi = self._try_get_value(node, 1) |
| scales = self._try_get_value(node, 2) |
| sizes = self._try_get_value(node, 3) |
| if sizes is not None: |
| new_sympy_shape = [sympy.simplify(sympy.floor(s)) for s in sizes] |
| self._update_computed_dims(new_sympy_shape) |
| elif scales is not None: |
| rank = len(scales) |
| if get_attribute(node, "coordinate_transformation_mode") == "tf_crop_and_resize": |
| assert len(roi) == 2 * rank |
| roi_start = list(roi)[:rank] |
| roi_end = list(roi)[rank:] |
| else: |
| roi_start = [0] * rank |
| roi_end = [1] * rank |
| scales = list(scales) |
| new_sympy_shape = [ |
| sympy.simplify(sympy.floor(d * (end - start) * scale)) |
| for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales) |
| ] |
| self._update_computed_dims(new_sympy_shape) |
| else: |
| new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) |
|
|
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(new_sympy_shape), |
| ) |
| ) |
|
|
| def _infer_Scan(self, node): |
| subgraph = get_attribute(node, "body") |
| num_scan_inputs = get_attribute(node, "num_scan_inputs") |
| scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs) |
| num_scan_states = len(node.input) - num_scan_inputs |
| scan_input_axes = [ |
| handle_negative_axis(ax, self._get_shape_rank(node, i + num_scan_states)) |
| for i, ax in enumerate(scan_input_axes) |
| ] |
| |
| |
| assert len(subgraph.input) >= len(node.input) |
| subgraph_inputs = subgraph.input[: len(node.input)] |
| for i, si in enumerate(subgraph_inputs): |
| subgraph_name = si.name |
| si.CopyFrom(self.known_vi_[node.input[i]]) |
| if i >= num_scan_states: |
| scan_input_dim = si.type.tensor_type.shape.dim |
| scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]]) |
| si.name = subgraph_name |
| self._onnx_infer_subgraph(node, subgraph) |
| num_scan_outputs = len(node.output) - num_scan_states |
| scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs) |
| scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]] |
| for i, o in enumerate(node.output): |
| vi = self.known_vi_[o] |
| if i >= num_scan_states: |
| shape = get_shape_from_type_proto(subgraph.output[i].type) |
| new_dim = handle_negative_axis(scan_output_axes[i - num_scan_states], len(shape) + 1) |
| shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:] |
| vi.CopyFrom(helper.make_tensor_value_info(o, subgraph.output[i].type.tensor_type.elem_type, shape)) |
| else: |
| vi.CopyFrom(subgraph.output[i]) |
| vi.name = o |
|
|
| def _infer_ScatterElements(self, node): |
| data_shape = self._get_shape(node, 0) |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| data_shape, |
| ) |
| ) |
|
|
| def _infer_SequenceAt(self, node): |
| |
| seq_shape = self._get_shape(node, 0) |
| vi = self.known_vi_[node.output[0]] |
| if seq_shape is not None: |
| for di, d in enumerate(seq_shape): |
| if d is not None: |
| continue |
| new_dim = onnx.TensorShapeProto.Dimension() |
| new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, 0, di)) |
| vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim) |
|
|
| def _infer_SequenceInsert(self, node): |
| |
| vi_seq = self.known_vi_[node.input[0]] |
| vi_tensor = self.known_vi_[node.input[1]] |
| vi_out_seq = self.known_vi_[node.output[0]] |
| vi_out_seq.CopyFrom(vi_seq) |
| vi_out_seq.name = node.output[0] |
| self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type) |
|
|
| def _infer_Shape(self, node): |
| self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0) |
|
|
| def _infer_Size(self, node): |
| sympy_shape = self._get_sympy_shape(node, 0) |
| self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape) |
| self.known_vi_[node.output[0]].CopyFrom( |
| helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []) |
| ) |
|
|
| def _infer_Slice(self, node): |
| |
| |
| |
| |
| |
| |
| |
| def flatten_min(expr): |
| assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}" |
| min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)] |
| if len(min_positions) == 1: |
| min_pos = min_positions[0] |
|
|
| def replace_min_with_arg(arg_idx): |
| replaced = list(expr.args) |
| assert isinstance( |
| replaced[min_pos], sympy.Min |
| ), f"Expected a sympy.Min() at position {min_pos}, got {replaced[min_pos]}" |
| assert ( |
| len(replaced[min_pos].args) == 2 |
| ), f"Expected a sympy.Min() with exactly 2 arguments, got {replaced[min_pos]}" |
| replaced[min_pos] = replaced[min_pos].args[arg_idx] |
| return sympy.Add(*replaced) |
|
|
| return [ |
| replace_min_with_arg(0), |
| replace_min_with_arg(1), |
| ] |
| return [expr] |
|
|
| def less_equal(x, y): |
| try: |
| return bool(x <= y) |
| except TypeError: |
| pass |
| try: |
| return bool(y >= x) |
| except TypeError: |
| pass |
| try: |
| return bool(-x >= -y) |
| except TypeError: |
| pass |
| try: |
| return bool(-y <= -x) |
| except TypeError: |
| pass |
| try: |
| return bool(y - x >= 0) |
| except TypeError: |
| |
| return all(bool(d >= 0) for d in flatten_min(y - x)) |
|
|
| def handle_negative_index(index, bound): |
| """normalizes a negative index to be in [0, bound)""" |
| try: |
| if not less_equal(0, index): |
| if is_literal(index) and index <= -self.int_max_: |
| |
| return index |
| return bound + index |
| except TypeError: |
| logger.warning(f"Cannot determine if {index} < 0") |
| return index |
|
|
| if get_opset(self.out_mp_) <= 9: |
| axes = get_attribute(node, "axes") |
| starts = get_attribute(node, "starts") |
| ends = get_attribute(node, "ends") |
| if not axes: |
| axes = list(range(len(starts))) |
| steps = [1] * len(axes) |
| else: |
| starts = as_list(self._try_get_value(node, 1), keep_none=True) |
| ends = as_list(self._try_get_value(node, 2), keep_none=True) |
| axes = self._try_get_value(node, 3) |
| steps = self._try_get_value(node, 4) |
| if axes is None and not (starts is None and ends is None): |
| axes = list(range(0, len(starts if starts is not None else ends))) |
| if steps is None and not (starts is None and ends is None): |
| steps = [1] * len(starts if starts is not None else ends) |
| axes = as_list(axes, keep_none=True) |
| steps = as_list(steps, keep_none=True) |
|
|
| new_sympy_shape = self._get_sympy_shape(node, 0) |
| if starts is None or ends is None: |
| if axes is None: |
| for i in range(len(new_sympy_shape)): |
| new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i) |
| else: |
| new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape) |
| for i in axes: |
| new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i) |
| else: |
| for i, s, e, t in zip(axes, starts, ends, steps): |
| e = handle_negative_index(e, new_sympy_shape[i]) |
| if is_literal(e): |
| if e >= self.int_max_: |
| e = new_sympy_shape[i] |
| elif e <= -self.int_max_: |
| e = 0 if s > 0 else -1 |
| elif is_literal(new_sympy_shape[i]): |
| if e < 0: |
| e = max(0, e + new_sympy_shape[i]) |
| e = min(e, new_sympy_shape[i]) |
| else: |
| if e > 0: |
| e = ( |
| sympy.Min(e, new_sympy_shape[i]) if e > 1 else e |
| ) |
| else: |
| if is_literal(new_sympy_shape[i]): |
| e = sympy.Min(e, new_sympy_shape[i]) |
| else: |
| try: |
| if not less_equal(e, new_sympy_shape[i]): |
| e = new_sympy_shape[i] |
| except Exception: |
| logger.warning(f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal") |
| e = new_sympy_shape[i] |
|
|
| s = handle_negative_index(s, new_sympy_shape[i]) |
| if is_literal(new_sympy_shape[i]) and is_literal(s): |
| s = max(0, min(s, new_sympy_shape[i])) |
|
|
| new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t) |
|
|
| self._update_computed_dims(new_sympy_shape) |
|
|
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| vi.type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(new_sympy_shape), |
| ) |
| ) |
|
|
| |
| if ( |
| node.input[0] in self.sympy_data_ |
| and [0] == axes |
| and starts is not None |
| and len(starts) == 1 |
| and ends is not None |
| and len(ends) == 1 |
| and steps is not None |
| and len(steps) == 1 |
| ): |
| input_sympy_data = self.sympy_data_[node.input[0]] |
| if type(input_sympy_data) == list or ( |
| type(input_sympy_data) == np.array and len(input_sympy_data.shape) == 1 |
| ): |
| self.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]] |
|
|
| def _infer_SoftmaxCrossEntropyLoss(self, node): |
| vi = self.known_vi_[node.output[0]] |
| elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
|
|
| |
| specified_output_type = get_attribute(node, "output_type", None) |
| if specified_output_type is not None: |
| elem_type = specified_output_type |
|
|
| vi.type.tensor_type.elem_type = elem_type |
| vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) |
|
|
| if len(node.output) > 1: |
| data_shape = self._get_shape(node, 0) |
| vi = self.known_vi_[node.output[1]] |
| vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, data_shape)) |
|
|
| def _infer_Split_Common(self, node, make_value_info_func): |
| input_sympy_shape = self._get_sympy_shape(node, 0) |
| axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape)) |
| op_set = get_opset(self.out_mp_) |
|
|
| |
| if op_set < 13: |
| split = get_attribute(node, "split") |
| assert self._try_get_value(node, 1) is None |
| else: |
| split = self._try_get_value(node, 1) |
| assert get_attribute(node, "split") is None |
|
|
| if split is None: |
| num_outputs = len(node.output) |
| split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs |
| self._update_computed_dims(split) |
| else: |
| split = [sympy.Integer(s) for s in split] |
|
|
| for i_o in range(len(split)): |
| vi = self.known_vi_[node.output[i_o]] |
| vi.CopyFrom( |
| make_value_info_func( |
| node.output[i_o], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis + 1 :]), |
| ) |
| ) |
| self.known_vi_[vi.name] = vi |
|
|
| def _infer_Split(self, node): |
| self._infer_Split_Common(node, helper.make_tensor_value_info) |
|
|
| def _infer_SplitToSequence(self, node): |
| self._infer_Split_Common(node, helper.make_sequence_value_info) |
|
|
| def _infer_Squeeze(self, node): |
| input_shape = self._get_shape(node, 0) |
| op_set = get_opset(self.out_mp_) |
|
|
| |
| if op_set < 13: |
| axes = get_attribute(node, "axes") |
| assert self._try_get_value(node, 1) is None |
| else: |
| axes = self._try_get_value(node, 1) |
| assert get_attribute(node, "axes") is None |
|
|
| if axes is None: |
| |
| |
| |
| output_shape = [s for s in input_shape if s != 1] |
| if self.verbose_ > 0: |
| symbolic_dimensions = [s for s in input_shape if type(s) != int] |
| if len(symbolic_dimensions) > 0: |
| logger.debug( |
| f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " |
| f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}" |
| ) |
| else: |
| axes = [handle_negative_axis(a, len(input_shape)) for a in axes] |
| output_shape = [] |
| for i in range(len(input_shape)): |
| if i not in axes: |
| output_shape.append(input_shape[i]) |
| else: |
| assert input_shape[i] == 1 or type(input_shape[i]) != int |
| if self.verbose_ > 0 and type(input_shape[i]) != int: |
| logger.debug( |
| f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " |
| f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1." |
| ) |
|
|
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| output_shape, |
| ) |
| ) |
| self._pass_on_sympy_data(node) |
|
|
| def _infer_Tile(self, node): |
| repeats_value = self._try_get_value(node, 1) |
| new_sympy_shape = [] |
| if repeats_value is not None: |
| input_sympy_shape = self._get_sympy_shape(node, 0) |
| for i, d in enumerate(input_sympy_shape): |
| new_dim = d * repeats_value[i] |
| new_sympy_shape.append(new_dim) |
| self._update_computed_dims(new_sympy_shape) |
| else: |
| new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| vi.type.tensor_type.elem_type, |
| get_shape_from_sympy_shape(new_sympy_shape), |
| ) |
| ) |
|
|
| def _infer_TopK(self, node): |
| rank = self._get_shape_rank(node, 0) |
| axis = handle_negative_axis(get_attribute(node, "axis", -1), rank) |
| new_shape = self._get_shape(node, 0) |
|
|
| if get_opset(self.out_mp_) <= 9: |
| k = get_attribute(node, "k") |
| else: |
| k = self._get_int_or_float_values(node)[1] |
|
|
| if k is None: |
| k = self._new_symbolic_dim_from_output(node) |
| else: |
| k = as_scalar(k) |
|
|
| if type(k) in [int, str]: |
| new_shape[axis] = k |
| else: |
| new_sympy_shape = self._get_sympy_shape(node, 0) |
| new_sympy_shape[axis] = k |
| self._update_computed_dims( |
| new_sympy_shape |
| ) |
| new_shape = get_shape_from_sympy_shape(new_sympy_shape) |
|
|
| for i_o in range(len(node.output)): |
| vi = self.known_vi_[node.output[i_o]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o], vi.type.tensor_type.elem_type, new_shape)) |
|
|
| def _infer_Transpose(self, node): |
| if node.input[0] in self.sympy_data_: |
| data_shape = self._get_shape(node, 0) |
| perm = get_attribute(node, "perm", reversed(list(range(len(data_shape))))) |
| input_data = self.sympy_data_[node.input[0]] |
| self.sympy_data_[node.output[0]] = ( |
| np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)).flatten().tolist() |
| ) |
|
|
| def _infer_Unsqueeze(self, node): |
| input_shape = self._get_shape(node, 0) |
| op_set = get_opset(self.out_mp_) |
|
|
| |
| if op_set < 13: |
| axes = get_attribute(node, "axes") |
| assert self._try_get_value(node, 1) is None |
| else: |
| axes = self._try_get_value(node, 1) |
| assert get_attribute(node, "axes") is None |
|
|
| output_rank = len(input_shape) + len(axes) |
| axes = [handle_negative_axis(a, output_rank) for a in axes] |
|
|
| input_axis = 0 |
| output_shape = [] |
| for i in range(output_rank): |
| if i in axes: |
| output_shape.append(1) |
| else: |
| output_shape.append(input_shape[input_axis]) |
| input_axis += 1 |
|
|
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| node.output[0], |
| self.known_vi_[node.input[0]].type.tensor_type.elem_type, |
| output_shape, |
| ) |
| ) |
|
|
| self._pass_on_sympy_data(node) |
|
|
| def _infer_ZipMap(self, node): |
| map_key_type = None |
| if get_attribute(node, "classlabels_int64s") is not None: |
| map_key_type = onnx.TensorProto.INT64 |
| elif get_attribute(node, "classlabels_strings") is not None: |
| map_key_type = onnx.TensorProto.STRING |
|
|
| assert map_key_type is not None |
| new_vi = onnx.ValueInfoProto() |
| new_vi.name = node.output[0] |
| new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT |
| new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(new_vi) |
|
|
| def _infer_Attention(self, node): |
| shape = self._get_shape(node, 0) |
| shape_weights = self._get_shape(node, 1) |
| shape_bias = self._try_get_shape(node, 2) |
| if shape_bias is not None: |
| assert len(shape_bias) == 1 |
| tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] |
| if shape and len(shape) == 3: |
| qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") |
| if qkv_hidden_sizes_attr is not None: |
| assert len(qkv_hidden_sizes_attr) == 3 |
| shape[2] = int(qkv_hidden_sizes_attr[2]) |
| elif isinstance(tripled_hidden_size, int): |
| shape[2] = int(tripled_hidden_size / 3) |
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) |
|
|
| if len(node.output) > 1: |
| |
| |
| |
| |
| input_shape = self._get_shape(node, 0) |
| past_shape = self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] else [] |
| mask_shape = self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else [] |
|
|
| if past_shape and len(past_shape) == 5: |
| if mask_shape and len(mask_shape) in [2, 3]: |
| past_shape[3] = mask_shape[-1] |
| elif input_shape and len(input_shape) == 3: |
| if isinstance(input_shape[1], int) and isinstance(past_shape[3], int): |
| past_shape[3] = input_shape[1] + past_shape[3] |
| else: |
| past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" |
| vi = self.known_vi_[node.output[1]] |
| vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) |
| |
| else: |
| num_heads = get_attribute(node, "num_heads") |
| head_size = input_shape[2] // num_heads |
| present_shape = [2, input_shape[0], num_heads, input_shape[1], head_size] |
| vi = self.known_vi_[node.output[1]] |
| vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) |
|
|
| def _infer_GatedRelativePositionBias(self, node): |
| |
| |
| |
| |
| |
| |
| |
| num_heads = get_attribute(node, "num_heads") |
|
|
| token_offset_shape = self._try_get_shape(node, 6) |
| if token_offset_shape is not None: |
| output_shape = [token_offset_shape[0], num_heads, token_offset_shape[1], token_offset_shape[1]] |
| else: |
| query_layer_shape = self._get_shape(node, 0) |
| assert query_layer_shape is not None and len(query_layer_shape) == 3 |
| output_shape = [query_layer_shape[0], num_heads, query_layer_shape[1], query_layer_shape[1]] |
|
|
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) |
|
|
| def _infer_PackedAttention(self, node): |
| shape = self._get_shape(node, 0) |
| shape_weights = self._get_shape(node, 1) |
| shape_bias = self._try_get_shape(node, 2) |
| if shape_bias is not None: |
| assert len(shape_bias) == 1 |
| tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] |
| if shape and len(shape) == 2: |
| qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") |
| if qkv_hidden_sizes_attr is not None: |
| assert len(qkv_hidden_sizes_attr) == 3 |
| shape[1] = int(qkv_hidden_sizes_attr[2]) |
| elif isinstance(tripled_hidden_size, int): |
| shape[1] = int(tripled_hidden_size / 3) |
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) |
|
|
| def _infer_PackedMultiHeadAttention(self, node): |
| shape_value = self._try_get_shape(node, 2) |
| if shape_value is not None and len(shape_value) == 2: |
| output_shape = shape_value |
| else: |
| shape_query = self._get_shape(node, 0) |
| assert shape_query is not None and len(shape_query) == 4 |
| output_shape = [shape_query[0], shape_query[1] * shape_query[3]] |
|
|
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) |
|
|
| def _infer_RemovePadding(self, node): |
| shape = self._get_shape(node, 0) |
| if shape and len(shape) == 3: |
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, ["token_count", shape[2]])) |
|
|
| vi_token_offset = self.known_vi_[node.output[1]] |
| vi_token_offset.CopyFrom( |
| helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]]) |
| ) |
|
|
| vi_cumulated_seq_len = self.known_vi_[node.output[2]] |
| vi_cumulated_seq_len.CopyFrom( |
| helper.make_tensor_value_info(node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"]) |
| ) |
|
|
| vi_max_seq_len = self.known_vi_[node.output[3]] |
| vi_max_seq_len.CopyFrom(helper.make_tensor_value_info(node.output[3], onnx.TensorProto.INT32, [1])) |
|
|
| def _infer_RestorePadding(self, node): |
| shape_input = self._get_shape(node, 0) |
| shape_token_offset = self._get_shape(node, 1) |
| if shape_input and len(shape_input) == 2 and shape_token_offset and len(shape_token_offset) == 2: |
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
|
|
| output_shape = [shape_token_offset[0], shape_token_offset[1], shape_input[1]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) |
|
|
| def _infer_BiasGelu(self, node): |
| self._propagate_shape_and_type(node) |
|
|
| def _infer_MultiHeadAttention(self, node): |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| query_shape = self._get_shape(node, 0) |
| total_sequence_length = None |
| output_dtype = None |
| if query_shape is not None: |
| if len(query_shape) == 3: |
| key_shape = self._try_get_shape(node, 1) |
| |
| output_shape = query_shape |
| if key_shape is not None and len(key_shape) == 3: |
| value_shape = self._try_get_shape(node, 2) |
| if value_shape is not None and len(value_shape) == 3: |
| output_shape[2] = value_shape[2] |
| total_sequence_length = key_shape[1] |
|
|
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) |
|
|
| elif len(query_shape) == 5: |
| if isinstance(query_shape[2], int) and isinstance(query_shape[4], int): |
| output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]] |
| else: |
| output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"] |
|
|
| total_sequence_length = query_shape[1] |
|
|
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) |
|
|
| if len(node.output) > 1: |
| batch_size = query_shape[0] |
| num_heads = get_attribute(node, "num_heads") |
|
|
| head_size = None |
| if len(query_shape) == 3: |
| head_size = ( |
| int(query_shape[2] / num_heads) |
| if isinstance(query_shape[2], int) |
| else f"{query_shape[2]}/{num_heads}" |
| ) |
| else: |
| head_size = query_shape[4] |
|
|
| past_shape = self._try_get_shape(node, 6) |
|
|
| if past_shape is not None: |
| if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int): |
| total_sequence_length = past_shape[2] + total_sequence_length |
| else: |
| total_sequence_length = f"{past_shape[2]}+{total_sequence_length}" |
|
|
| present_shape = [batch_size, num_heads, total_sequence_length, head_size] |
|
|
| assert output_dtype is not None |
| if len(node.output) > 2 and node.output[1] and node.output[2]: |
| vi = self.known_vi_[node.output[1]] |
| vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) |
| vi = self.known_vi_[node.output[2]] |
| vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) |
|
|
| def _infer_DecoderMaskedMultiHeadAttention(self, node): |
| |
| |
| |
| |
|
|
| query_shape = self._get_shape(node, 0) |
| if query_shape is not None: |
| output_shape = query_shape |
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| assert output_dtype is not None |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) |
|
|
| if len(node.output) > 2 and node.output[1] and node.output[2]: |
| past_shape = self._try_get_shape(node, 5) |
| if past_shape is not None: |
| vi = self.known_vi_[node.output[1]] |
| vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) |
| vi = self.known_vi_[node.output[2]] |
| vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) |
|
|
| def _infer_FastGelu(self, node): |
| self._propagate_shape_and_type(node) |
|
|
| def _infer_Gelu(self, node): |
| self._propagate_shape_and_type(node) |
|
|
| def _infer_QuickGelu(self, node): |
| self._propagate_shape_and_type(node) |
|
|
| def _infer_GemmFastGelu(self, node): |
| self._compute_matmul_shape(node) |
|
|
| def _infer_GemmFloat8(self, node): |
| self._compute_matmul_shape(node) |
|
|
| def _infer_LayerNormalization(self, node): |
| self._propagate_shape_and_type(node) |
| if len(node.output) > 1: |
| axis = get_attribute(node, "axis") |
| if axis is None: |
| axis = -1 |
| x_shape = self._get_shape(node, 0) |
| if x_shape is not None: |
| rank = len(x_shape) |
| axis = handle_negative_axis(axis, rank) |
| mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)] |
| mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| if mean_dtype == onnx.TensorProto.FLOAT16 or mean_dtype == onnx.TensorProto.BFLOAT16: |
| mean_dtype = onnx.TensorProto.FLOAT |
| vi = self.known_vi_[node.output[1]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape)) |
| if len(node.output) > 2: |
| vi = self.known_vi_[node.output[2]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[2], mean_dtype, mean_shape)) |
|
|
| def _infer_LongformerAttention(self, node): |
| self._propagate_shape_and_type(node) |
|
|
| def _infer_EmbedLayerNormalization(self, node): |
| input_ids_shape = self._get_shape(node, 0) |
| word_embedding_shape = self._get_shape(node, 2) |
| assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2 |
| output_shape = [*input_ids_shape, word_embedding_shape[1]] |
|
|
| word_embedding_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape)) |
|
|
| if len(node.output) > 1 and node.output[1]: |
| mask_index_shape = [input_ids_shape[0]] |
| vi = self.known_vi_[node.output[1]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape)) |
|
|
| if len(node.output) > 2: |
| |
| |
| vi = self.known_vi_[node.output[2]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[2], word_embedding_dtype, output_shape)) |
|
|
| def _infer_SkipLayerNormalization(self, node): |
| self._propagate_shape_and_type(node) |
|
|
| |
| |
| if len(node.output) > 3: |
| self._propagate_shape_and_type(node, 0, 3) |
|
|
| def _infer_GroupNorm(self, node): |
| self._propagate_shape_and_type(node) |
|
|
| def _infer_PagedAttention(self, node): |
| self._propagate_shape_and_type(node) |
|
|
| def _infer_GroupQueryAttention(self, node): |
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
|
|
| past_shape = self._try_get_shape(node, 3) |
| if past_shape is not None: |
| vi = self.known_vi_[node.output[1]] |
| vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) |
| vi = self.known_vi_[node.output[2]] |
| vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) |
|
|
| if node.input[1] != "" and node.input[2] != "": |
| self._propagate_shape_and_type(node, 0, 0) |
| else: |
| |
| assert node.input[1] == "" and node.input[2] == "" |
| num_heads = get_attribute(node, "num_heads") |
| kv_num_heads = get_attribute(node, "kv_num_heads") |
| query_shape = self._get_shape(node, 0) |
| if query_shape is not None: |
| hidden_size = query_shape[2] |
| if isinstance(hidden_size, int): |
| head_size = int(hidden_size / (num_heads + 2 * kv_num_heads)) |
| query_shape[2] = num_heads * head_size |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape)) |
|
|
| def _infer_SkipGroupNorm(self, node): |
| self._propagate_shape_and_type(node, 0, 0) |
| if len(node.output) > 1: |
| self._propagate_shape_and_type(node, 0, 1) |
|
|
| def _infer_BiasSplitGelu(self, node): |
| input_shape = self._get_shape(node, 0) |
| bias_shape = self._get_shape(node, 1) |
| if input_shape and bias_shape and isinstance(bias_shape[0], int): |
| output_shape = input_shape |
| output_shape[2] = int(bias_shape[0] / 2) |
| vi = self.known_vi_[node.output[0]] |
| output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape)) |
|
|
| def _infer_BiasAdd(self, node): |
| self._propagate_shape_and_type(node) |
|
|
| def _infer_RotaryEmbedding(self, node): |
| if len(node.output) == 1: |
| self._propagate_shape_and_type(node) |
| elif len(node.output) == 2: |
| |
| self._propagate_shape_and_type(node, input_index=1, output_index=0) |
| self._propagate_shape_and_type(node, input_index=0, output_index=1) |
| elif len(node.output) == 3: |
| |
| self._propagate_shape_and_type(node, input_index=1, output_index=0) |
| self._propagate_shape_and_type(node, input_index=1, output_index=1) |
| self._propagate_shape_and_type(node, input_index=0, output_index=2) |
|
|
| def _infer_PythonOp(self, node): |
| output_tensor_types = get_attribute(node, "output_tensor_types") |
| assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute." |
| output_tensor_ranks = get_attribute(node, "output_tensor_ranks") |
| assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute." |
|
|
| from onnxruntime.capi._pybind_state import get_shape_inference_function |
|
|
| func_name = get_attribute(node, "func_name").decode() |
| shape_inferer = get_shape_inference_function(func_name) |
|
|
| |
| |
| vi = self.known_vi_[node.output[0]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])) |
|
|
| if shape_inferer is not None: |
| input_shapes = [] |
| input_dtypes = [] |
| for input_index in range(len(node.input)): |
| shape = self._get_shape(node, input_index) |
| input_shapes.append(shape) |
| input_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type |
| input_dtypes.append(input_dtype) |
| output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes) |
| assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), ( |
| f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, " |
| f"but expected {len(node.output) - 1} outputs." |
| ) |
| for i in range(len(node.output) - 1): |
| output_index = i + 1 |
| vi = self.known_vi_[node.output[output_index]] |
| vi.CopyFrom( |
| helper.make_tensor_value_info(node.output[output_index], output_dtypes[i], output_shapes[i]) |
| ) |
| else: |
| |
| |
| |
| for i in range(len(node.output) - 1): |
| |
| vi = self.known_vi_[node.output[i + 1]] |
| sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node) |
| shape = get_shape_from_sympy_shape(sympy_shape) |
| value_info = helper.make_tensor_value_info(node.output[i + 1], output_tensor_types[i], shape) |
| vi.CopyFrom(value_info) |
|
|
| def _propagate_shape_and_type(self, node, input_index=0, output_index=0): |
| shape = self._get_shape(node, input_index) |
| output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type |
| vi = self.known_vi_[node.output[output_index]] |
| vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape)) |
|
|
| def _is_none_dim(self, dim_value): |
| if type(dim_value) != str: |
| return False |
| if "unk__" not in dim_value: |
| return False |
| if dim_value in self.symbolic_dims_: |
| return False |
| return True |
|
|
| def _is_shape_contains_none_dim(self, out_shape): |
| for out in out_shape: |
| if self._is_none_dim(out): |
| return out |
| return None |
|
|
| def _infer_impl(self, start_sympy_data=None): |
| self.sympy_data_ = start_sympy_data or {} |
| self.out_mp_.graph.ClearField("value_info") |
| self._apply_suggested_merge(graph_input_only=True) |
| self.input_symbols_ = set() |
| for i in self.out_mp_.graph.input: |
| input_shape = get_shape_from_value_info(i) |
| if input_shape is None: |
| continue |
|
|
| if is_sequence(i.type): |
| input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim |
| else: |
| input_dims = i.type.tensor_type.shape.dim |
|
|
| for i_dim, dim in enumerate(input_shape): |
| if dim is None: |
| |
| input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim)) |
|
|
| self.input_symbols_.update([d for d in input_shape if type(d) == str]) |
|
|
| for s in self.input_symbols_: |
| if s in self.suggested_merge_: |
| s_merge = self.suggested_merge_[s] |
| assert s_merge in self.symbolic_dims_ |
| self.symbolic_dims_[s] = self.symbolic_dims_[s_merge] |
| else: |
| |
| self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True) |
| |
| |
| |
| self.tmp_mp_ = onnx.ModelProto() |
| self.tmp_mp_.CopyFrom(self.out_mp_) |
| self.tmp_mp_.graph.ClearField("initializer") |
|
|
| |
| |
| prereq_for_node = {} |
|
|
| def get_prereq(node): |
| names = {i for i in node.input if i} |
| subgraphs = [] |
| if node.op_type == "If": |
| subgraphs = [ |
| get_attribute(node, "then_branch"), |
| get_attribute(node, "else_branch"), |
| ] |
| elif node.op_type in ["Loop", "Scan"]: |
| subgraphs = [get_attribute(node, "body")] |
| for g in subgraphs: |
| g_outputs_and_initializers = {i.name for i in g.initializer} |
| g_prereq = set() |
| for n in g.node: |
| g_outputs_and_initializers.update(n.output) |
| for n in g.node: |
| g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers]) |
| names.update(g_prereq) |
| |
| for i in g.input: |
| if i.name in names: |
| names.remove(i.name) |
| return names |
|
|
| for n in self.tmp_mp_.graph.node: |
| prereq_for_node[n.output[0]] = get_prereq(n) |
|
|
| |
| sorted_nodes = [] |
| sorted_known_vi = {i.name for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer)} |
| if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): |
| |
| sorted_nodes = self.out_mp_.graph.node |
| else: |
| while not all([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): |
| old_sorted_nodes_len = len(sorted_nodes) |
| for node in self.out_mp_.graph.node: |
| if (node.output[0] not in sorted_known_vi) and all( |
| [i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i] |
| ): |
| sorted_known_vi.update(node.output) |
| sorted_nodes.append(node) |
| if old_sorted_nodes_len == len(sorted_nodes) and not all( |
| [o.name in sorted_known_vi for o in self.out_mp_.graph.output] |
| ): |
| raise Exception("Invalid model with cyclic graph") |
|
|
| for node in sorted_nodes: |
| assert all([i in self.known_vi_ for i in node.input if i]) |
| self._onnx_infer_single_node(node) |
| known_aten_op = False |
| if node.op_type in self.dispatcher_: |
| self.dispatcher_[node.op_type](node) |
| elif node.op_type in ["ConvTranspose"]: |
| |
| |
| |
| vi = self.known_vi_[node.output[0]] |
| if len(vi.type.tensor_type.shape.dim) == 0: |
| vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED |
| elif node.op_type == "ATen" and node.domain == "org.pytorch.aten": |
| for attr in node.attribute: |
| |
| if attr.name == "operator": |
| aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s |
| if aten_op_name in self.aten_op_dispatcher_: |
| known_aten_op = True |
| self.aten_op_dispatcher_[aten_op_name](node) |
| break |
|
|
| if self.verbose_ > 2: |
| logger.debug(node.op_type + ": " + node.name) |
| for i, name in enumerate(node.input): |
| logger.debug( |
| " Input {}: {} {}".format(i, name, "initializer" if name in self.initializers_ else "") |
| ) |
|
|
| |
| |
| if node.op_type in [ |
| "Add", |
| "Sub", |
| "Mul", |
| "Div", |
| "MatMul", |
| "MatMulInteger", |
| "MatMulInteger16", |
| "Where", |
| "Sum", |
| ]: |
| vi = self.known_vi_[node.output[0]] |
| out_rank = len(get_shape_from_type_proto(vi.type)) |
| in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] |
| for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)): |
| in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] |
| if len(in_dims) > 1: |
| self._check_merged_dims(in_dims, allow_broadcast=True) |
|
|
| for i_o in range(len(node.output)): |
| |
| |
| |
| |
| |
| if ( |
| node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization" |
| ) and i_o in [1, 2]: |
| continue |
| if node.op_type == "RotaryEmbedding" and len(node.output) > 1: |
| |
| |
| continue |
|
|
| vi = self.known_vi_[node.output[i_o]] |
| out_type = vi.type |
| out_type_kind = out_type.WhichOneof("value") |
|
|
| |
| if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]: |
| if self.verbose_ > 2: |
| if out_type_kind == "sequence_type": |
| seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value") |
| if seq_cls_type == "tensor_type": |
| logger.debug( |
| " {}: sequence of {} {}".format( |
| node.output[i_o], |
| str(get_shape_from_value_info(vi)), |
| onnx.TensorProto.DataType.Name( |
| vi.type.sequence_type.elem_type.tensor_type.elem_type |
| ), |
| ) |
| ) |
| else: |
| logger.debug(f" {node.output[i_o]}: sequence of {seq_cls_type}") |
| else: |
| logger.debug(f" {node.output[i_o]}: {out_type_kind}") |
| continue |
|
|
| out_shape = get_shape_from_value_info(vi) |
| out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED |
| if self.verbose_ > 2: |
| logger.debug( |
| " {}: {} {}".format( |
| node.output[i_o], |
| str(out_shape), |
| onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type), |
| ) |
| ) |
| if node.output[i_o] in self.sympy_data_: |
| logger.debug(" Sympy Data: " + str(self.sympy_data_[node.output[i_o]])) |
|
|
| |
| if ( |
| out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape)) |
| ) or out_type_undefined: |
| if self.auto_merge_: |
| if node.op_type in [ |
| "Add", |
| "Sub", |
| "Mul", |
| "Div", |
| "MatMul", |
| "MatMulInteger", |
| "MatMulInteger16", |
| "Concat", |
| "Where", |
| "Sum", |
| "Equal", |
| "Less", |
| "Greater", |
| "LessOrEqual", |
| "GreaterOrEqual", |
| "Min", |
| "Max", |
| ]: |
| shapes = [self._get_shape(node, i) for i in range(len(node.input))] |
| if node.op_type in [ |
| "MatMul", |
| "MatMulInteger", |
| "MatMulInteger16", |
| ]: |
| if None in out_shape or self._is_shape_contains_none_dim(out_shape): |
| if None in out_shape: |
| idx = out_shape.index(None) |
| else: |
| idx = out_shape.index(self._is_shape_contains_none_dim(out_shape)) |
| dim_idx = [len(s) - len(out_shape) + idx for s in shapes] |
| |
| assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2 |
| assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2 |
| elif node.op_type == "Expand": |
| |
| shapes = [ |
| self._get_shape(node, 0), |
| self._get_value(node, 1), |
| ] |
| else: |
| shapes = [] |
|
|
| if shapes: |
| for idx in range(len(out_shape)): |
| if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]): |
| continue |
| |
| |
| dim_idx = [len(s) - len(out_shape) + idx for s in shapes] |
| if len(dim_idx) > 0: |
| self._add_suggested_merge( |
| [ |
| s[i] if is_literal(s[i]) else str(s[i]) |
| for s, i in zip(shapes, dim_idx) |
| if i >= 0 |
| ] |
| ) |
| self.run_ = True |
| else: |
| self.run_ = False |
| else: |
| self.run_ = False |
|
|
| |
| if self.run_ is False and node.op_type not in self.dispatcher_ and not known_aten_op: |
| is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0) |
| if is_unknown_op: |
| |
| |
| out_rank = self._get_shape_rank(node, 0) if self.guess_output_rank_ else -1 |
| else: |
| |
| out_rank = len(out_shape) |
|
|
| if out_rank >= 0: |
| new_shape = self._new_symbolic_shape(out_rank, node, i_o) |
| if out_type_undefined: |
| |
| out_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| else: |
| |
| out_dtype = vi.type.tensor_type.elem_type |
| vi.CopyFrom( |
| helper.make_tensor_value_info( |
| vi.name, |
| out_dtype, |
| get_shape_from_sympy_shape(new_shape), |
| ) |
| ) |
|
|
| if self.verbose_ > 0: |
| if is_unknown_op: |
| logger.debug( |
| "Possible unknown op: {} node: {}, guessing {} shape".format( |
| node.op_type, node.name, vi.name |
| ) |
| ) |
| if self.verbose_ > 2: |
| logger.debug( |
| " {}: {} {}".format( |
| node.output[i_o], |
| str(new_shape), |
| vi.type.tensor_type.elem_type, |
| ) |
| ) |
|
|
| self.run_ = True |
| continue |
|
|
| if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: |
| logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name) |
| logger.debug("node inputs:") |
| for i in node.input: |
| if i in self.known_vi_: |
| logger.debug(self.known_vi_[i]) |
| else: |
| logger.debug(f"not in known_vi_ for {i}") |
| logger.debug("node outputs:") |
| for o in node.output: |
| if o in self.known_vi_: |
| logger.debug(self.known_vi_[o]) |
| else: |
| logger.debug(f"not in known_vi_ for {o}") |
| if self.auto_merge_ and not out_type_undefined: |
| logger.debug("Merging: " + str(self.suggested_merge_)) |
| return False |
|
|
| self.run_ = False |
| return True |
|
|
| def _update_output_from_vi(self): |
| for output in self.out_mp_.graph.output: |
| if output.name in self.known_vi_: |
| output.CopyFrom(self.known_vi_[output.name]) |
|
|
| @staticmethod |
| def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0): |
| onnx_opset = get_opset(in_mp) |
| if (not onnx_opset) or onnx_opset < 7: |
| logger.warning("Only support models of onnx opset 7 and above.") |
| return None |
| symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose) |
| all_shapes_inferred = False |
| symbolic_shape_inference._preprocess(in_mp) |
| while symbolic_shape_inference.run_: |
| all_shapes_inferred = symbolic_shape_inference._infer_impl() |
| symbolic_shape_inference._update_output_from_vi() |
| if not all_shapes_inferred: |
| onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True) |
| raise Exception("Incomplete symbolic shape inference") |
| return symbolic_shape_inference.out_mp_ |
|
|
|
|
| def parse_arguments(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--input", required=True, help="The input model file") |
| parser.add_argument("--output", help="The output model file") |
| parser.add_argument( |
| "--auto_merge", |
| help="Automatically merge symbolic dims when confliction happens", |
| action="store_true", |
| default=False, |
| ) |
| parser.add_argument( |
| "--int_max", |
| help="maximum value for integer to be treated as boundless for ops like slice", |
| type=int, |
| default=2**31 - 1, |
| ) |
| parser.add_argument( |
| "--guess_output_rank", |
| help="guess output rank to be the same as input 0 for unknown ops", |
| action="store_true", |
| default=False, |
| ) |
| parser.add_argument( |
| "--verbose", |
| help="Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed", |
| type=int, |
| default=0, |
| ) |
| parser.add_argument( |
| "--save_as_external_data", |
| help="Saving an ONNX model to external data", |
| action="store_true", |
| default=False, |
| ) |
| parser.add_argument( |
| "--all_tensors_to_one_file", |
| help="Saving all the external data to one file", |
| action="store_true", |
| default=False, |
| ) |
| parser.add_argument( |
| "--external_data_location", |
| help="The file location to save the external file", |
| default="./", |
| ) |
| parser.add_argument( |
| "--external_data_size_threshold", |
| help="The size threshold for external data", |
| type=int, |
| default=1024, |
| ) |
| return parser.parse_args() |
|
|
|
|
| if __name__ == "__main__": |
| args = parse_arguments() |
| logger.info("input model: " + args.input) |
| if args.output: |
| logger.info("output model " + args.output) |
| logger.info("Doing symbolic shape inference...") |
| out_mp = SymbolicShapeInference.infer_shapes( |
| onnx.load(args.input), |
| args.int_max, |
| args.auto_merge, |
| args.guess_output_rank, |
| args.verbose, |
| ) |
| if args.output and out_mp: |
| if args.save_as_external_data: |
| onnx.save_model( |
| out_mp, |
| args.output, |
| save_as_external_data=True, |
| all_tensors_to_one_file=args.all_tensors_to_one_file, |
| location=args.external_data_location, |
| size_threshold=args.external_data_size_threshold, |
| convert_attribute=False, |
| ) |
| else: |
| onnx.save(out_mp, args.output) |
| logger.info("Done!") |
|
|