1
1
from typing import *
2
+ from warnings import warn
3
+
4
+ import numpy as np
2
5
3
6
from .features ._base import cleanup_slice
4
7
@@ -92,10 +95,19 @@ def _set_feature(self, feature: str, new_data: Any, indices: Any):
92
95
def _reset_feature (self , feature : str ):
93
96
pass
94
97
95
- def link (self , event_type : str , target : Any , feature : str , new_data : Any , callback_function : callable = None ):
98
+ def link (
99
+ self ,
100
+ event_type : str ,
101
+ target : Any ,
102
+ feature : str ,
103
+ new_data : Any ,
104
+ callback : callable = None ,
105
+ bidirectional : bool = False
106
+ ):
96
107
if event_type in self .pygfx_events :
97
108
self .world_object .add_event_handler (self .event_handler , event_type )
98
109
110
+ # make sure event is valid
99
111
elif event_type in self .feature_events :
100
112
if isinstance (self , GraphicCollection ):
101
113
feature_instance = getattr (self [:], event_type )
@@ -105,15 +117,35 @@ def link(self, event_type: str, target: Any, feature: str, new_data: Any, callba
105
117
feature_instance .add_event_handler (self .event_handler )
106
118
107
119
else :
108
- raise ValueError (" event not possible " )
120
+ raise ValueError (f"Invalid event, valid events are: { self . pygfx_events + self . feature_events } " )
109
121
110
- if event_type in self .registered_callbacks .keys ():
111
- self .registered_callbacks [event_type ].append (
112
- CallbackData (target = target , feature = feature , new_data = new_data , callback_function = callback_function ))
113
- else :
122
+ # make sure target feature is valid
123
+ if feature is not None :
124
+ if feature not in target .feature_events :
125
+ raise ValueError (f"Invalid feature for target, valid features are: { target .feature_events } " )
126
+
127
+ if event_type not in self .registered_callbacks .keys ():
114
128
self .registered_callbacks [event_type ] = list ()
115
- self .registered_callbacks [event_type ].append (
116
- CallbackData (target = target , feature = feature , new_data = new_data , callback_function = callback_function ))
129
+
130
+ callback_data = CallbackData (target = target , feature = feature , new_data = new_data , callback_function = callback )
131
+
132
+ for existing_callback_data in self .registered_callbacks [event_type ]:
133
+ if existing_callback_data == callback_data :
134
+ warn ("linkage already exists for given event, target, and data, skipping" )
135
+ return
136
+
137
+ self .registered_callbacks [event_type ].append (callback_data )
138
+
139
+ if bidirectional :
140
+ target .link (
141
+ event_type = event_type ,
142
+ target = self ,
143
+ feature = feature ,
144
+ new_data = new_data ,
145
+ callback = callback ,
146
+ bidirectional = False # else infinite recursion, otherwise target will call
147
+ # this instance .link(), and then it will happen again etc.
148
+ )
117
149
118
150
def event_handler (self , event ):
119
151
if event .type in self .registered_callbacks .keys ():
@@ -145,6 +177,28 @@ class CallbackData:
145
177
new_data : Any
146
178
callback_function : callable = None
147
179
180
+ def __eq__ (self , other ):
181
+ if not isinstance (other , CallbackData ):
182
+ raise TypeError ("Can only compare against other <CallbackData> types" )
183
+
184
+ if other .target is not self .target :
185
+ return False
186
+
187
+ if not other .feature == self .feature :
188
+ return False
189
+
190
+ if not other .new_data == self .new_data :
191
+ return False
192
+
193
+ if (self .callback_function is None ) and (other .callback_function is None ):
194
+ return True
195
+
196
+ if other .callback_function is self .callback_function :
197
+ return True
198
+
199
+ else :
200
+ return False
201
+
148
202
149
203
@dataclass
150
204
class PreviouslyModifiedData :
@@ -156,10 +210,6 @@ class PreviouslyModifiedData:
156
210
class GraphicCollection (Graphic ):
157
211
"""Graphic Collection base class"""
158
212
159
- pygfx_events = [
160
- "click"
161
- ]
162
-
163
213
def __init__ (self , name : str = None ):
164
214
super (GraphicCollection , self ).__init__ (name )
165
215
self ._items : List [Graphic ] = list ()
@@ -207,14 +257,21 @@ def __getitem__(self, key):
207
257
selection = self ._items [key ]
208
258
209
259
# fancy-ish indexing
210
- elif isinstance (key , (tuple , list )):
260
+ elif isinstance (key , (tuple , list , np .ndarray )):
261
+ if isinstance (key , np .ndarray ):
262
+ if not key .ndim == 1 :
263
+ raise TypeError (f"{ self .__class__ .__name__ } indexing supports "
264
+ f"1D numpy arrays, int, slice, tuple or list of integers, "
265
+ f"your numpy arrays has <{ key .ndim } > dimensions." )
211
266
selection = list ()
267
+
212
268
for ix in key :
213
269
selection .append (self ._items [ix ])
214
270
215
271
selection_indices = key
216
272
else :
217
- raise TypeError (f"Graphic Collection indexing supports int, slice, tuple or list of integers, "
273
+ raise TypeError (f"{ self .__class__ .__name__ } indexing supports "
274
+ f"1D numpy arrays, int, slice, tuple or list of integers, "
218
275
f"you have passed a <{ type (key )} >" )
219
276
220
277
return CollectionIndexer (
0 commit comments