NicheC commited on
Commit
c6c6312
·
verified ·
1 Parent(s): 3148df0

Upload unitary_protocol.py

Browse files
Files changed (1) hide show
  1. unitary_protocol.py +201 -0
unitary_protocol.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The Cirq Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from types import NotImplementedType
16
+ from typing import Any, Optional, TypeVar, Union
17
+
18
+ import numpy as np
19
+ from typing_extensions import Protocol
20
+
21
+ from cirq._doc import doc_private
22
+ from cirq.protocols import qid_shape_protocol
23
+ from cirq.protocols.apply_unitary_protocol import apply_unitaries, ApplyUnitaryArgs
24
+ from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
25
+
26
+ # This is a special indicator value used by the unitary method to determine
27
+ # whether or not the caller provided a 'default' argument. It must be of type
28
+ # np.ndarray to ensure the method has the correct type signature in that case.
29
+ # It is checked for using `is`, so it won't have a false positive if the user
30
+ # provides a different np.array([]) value.
31
+ RaiseTypeErrorIfNotProvided: np.ndarray = np.array([])
32
+
33
+ TDefault = TypeVar('TDefault')
34
+
35
+
36
+ class SupportsUnitary(Protocol):
37
+ """An object that may be describable by a unitary matrix."""
38
+
39
+ @doc_private
40
+ def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
41
+ """A unitary matrix describing this value, e.g. the matrix of a gate.
42
+
43
+ This method is used by the global `cirq.unitary` method. If this method
44
+ is not present, or returns NotImplemented, it is assumed that the
45
+ receiving object doesn't have a unitary matrix (resulting in a TypeError
46
+ or default result when calling `cirq.unitary` on it). (The ability to
47
+ return NotImplemented is useful when a class cannot know if it has a
48
+ matrix until runtime, e.g. cirq.X**c normally has a matrix but
49
+ cirq.X**sympy.Symbol('a') doesn't.)
50
+
51
+ The order of cells in the matrix is always implicit with respect to the
52
+ object being called. For example, for gates the matrix must be ordered
53
+ with respect to the list of qubits that the gate is applied to. For
54
+ operations, the matrix is ordered to match the list returned by its
55
+ `qubits` attribute. The qubit-to-amplitude order mapping matches the
56
+ ordering of numpy.kron(A, B), where A is a qubit earlier in the list
57
+ than the qubit B.
58
+
59
+ Returns:
60
+ A unitary matrix describing this value, or NotImplemented if there
61
+ is no such matrix.
62
+ """
63
+
64
+ @doc_private
65
+ def _has_unitary_(self) -> bool:
66
+ """Whether this value has a unitary matrix representation.
67
+
68
+ This method is used by the global `cirq.has_unitary` method. If this
69
+ method is not present, or returns NotImplemented, it will fallback
70
+ to using _unitary_ with a default value, or False if neither exist.
71
+
72
+ Returns:
73
+ True if the value has a unitary matrix representation, False
74
+ otherwise.
75
+ """
76
+
77
+
78
+ def unitary(
79
+ val: Any, default: Union[np.ndarray, TDefault] = RaiseTypeErrorIfNotProvided
80
+ ) -> Union[np.ndarray, TDefault]:
81
+ """Returns a unitary matrix describing the given value.
82
+
83
+ The matrix is determined by any one of the following techniques:
84
+
85
+ - The value has a `_unitary_` method that returns something besides None or
86
+ NotImplemented. The matrix is whatever the method returned.
87
+ - The value has a `_decompose_` method that returns a list of operations,
88
+ and each operation in the list has a unitary effect. The matrix is
89
+ created by aggregating the sub-operations' unitary effects.
90
+ - The value has an `_apply_unitary_` method, and it returns something
91
+ besides None or NotImplemented. The matrix is created by applying
92
+ `_apply_unitary_` to an identity matrix.
93
+
94
+ If none of these techniques succeeds, it is assumed that `val` doesn't have
95
+ a unitary effect. The order in which techniques are attempted is
96
+ unspecified.
97
+
98
+ Args:
99
+ val: The value to describe with a unitary matrix.
100
+ default: Determines the fallback behavior when `val` doesn't have
101
+ a unitary effect. If `default` is not set, a TypeError is raised. If
102
+ `default` is set to a value, that value is returned.
103
+
104
+ Returns:
105
+ If `val` has a unitary effect, the corresponding unitary matrix.
106
+ Otherwise, if `default` is specified, it is returned.
107
+
108
+ Raises:
109
+ TypeError: `val` doesn't have a unitary effect and no default value was
110
+ specified.
111
+ """
112
+ strats = [
113
+ _strat_unitary_from_unitary,
114
+ _strat_unitary_from_apply_unitary,
115
+ _strat_unitary_from_decompose,
116
+ ]
117
+ for strat in strats:
118
+ result = strat(val)
119
+ if result is None:
120
+ break
121
+ if result is not NotImplemented:
122
+ return result
123
+
124
+ if default is not RaiseTypeErrorIfNotProvided:
125
+ return default
126
+ raise TypeError(
127
+ "cirq.unitary failed. "
128
+ "Value doesn't have a (non-parameterized) unitary effect.\n"
129
+ "\n"
130
+ f"type: {type(val)}\n"
131
+ f"value: {val!r}\n"
132
+ "\n"
133
+ "The value failed to satisfy any of the following criteria:\n"
134
+ "- A `_unitary_(self)` method that returned a value "
135
+ "besides None or NotImplemented.\n"
136
+ "- A `_decompose_(self)` method that returned a "
137
+ "list of unitary operations.\n"
138
+ "- An `_apply_unitary_(self, args) method that returned a value "
139
+ "besides None or NotImplemented."
140
+ )
141
+
142
+
143
+ def _strat_unitary_from_unitary(val: Any) -> Optional[np.ndarray]:
144
+ """Attempts to compute a value's unitary via its _unitary_ method."""
145
+ getter = getattr(val, '_unitary_', None)
146
+ if getter is None:
147
+ return NotImplemented
148
+ return getter()
149
+
150
+
151
+ def _strat_unitary_from_apply_unitary(val: Any) -> Optional[np.ndarray]:
152
+ """Attempts to compute a value's unitary via its _apply_unitary_ method."""
153
+ # Check for the magic method.
154
+ method = getattr(val, '_apply_unitary_', None)
155
+ if method is None:
156
+ return NotImplemented
157
+
158
+ # Get the qid_shape.
159
+ val_qid_shape = qid_shape_protocol.qid_shape(val, None)
160
+ if val_qid_shape is None:
161
+ return NotImplemented
162
+
163
+ # Apply unitary effect to an identity matrix.
164
+ result = method(ApplyUnitaryArgs.for_unitary(qid_shape=val_qid_shape))
165
+
166
+ if result is NotImplemented or result is None:
167
+ return result
168
+ state_len = np.prod(val_qid_shape, dtype=np.int64)
169
+ return result.reshape((state_len, state_len))
170
+
171
+
172
+ def _strat_unitary_from_decompose(val: Any) -> Optional[np.ndarray]:
173
+ """Attempts to compute a value's unitary via its _decompose_ method."""
174
+ # Check if there's a decomposition.
175
+ operations, qubits, val_qid_shape = _try_decompose_into_operations_and_qubits(val)
176
+ if operations is None:
177
+ return NotImplemented
178
+
179
+ all_qubits = frozenset(q for op in operations for q in op.qubits)
180
+ work_qubits = frozenset(qubits)
181
+ ancillas = tuple(sorted(all_qubits.difference(work_qubits)))
182
+
183
+ ordered_qubits = ancillas + tuple(qubits)
184
+ val_qid_shape = qid_shape_protocol.qid_shape(ancillas) + val_qid_shape
185
+
186
+ # Apply sub-operations' unitary effects to an identity matrix.
187
+ result = apply_unitaries(
188
+ operations, ordered_qubits, ApplyUnitaryArgs.for_unitary(qid_shape=val_qid_shape), None
189
+ )
190
+
191
+ # Package result.
192
+ if result is None:
193
+ return None
194
+
195
+ state_len = np.prod(val_qid_shape, dtype=np.int64)
196
+ result = result.reshape((state_len, state_len))
197
+ # Assuming borrowable qubits are restored to their original state and
198
+ # clean qubits restord to the zero state then the desired unitary is
199
+ # the upper left square.
200
+ work_state_len = np.prod(val_qid_shape[len(ancillas) :], dtype=np.int64)
201
+ return result[:work_state_len, :work_state_len]