5
5
Dict ,
6
6
Iterable ,
7
7
MutableMapping ,
8
+ Optional ,
8
9
Type ,
9
10
TypeVar ,
10
11
Union ,
@@ -37,9 +38,9 @@ class ColumnStorage:
37
38
38
39
def __init__ (
39
40
self ,
40
- tensor_columns : Dict [str , AbstractTensor ],
41
- doc_columns : Dict [str , 'DocVec' ],
42
- docs_vec_columns : Dict [str , ListAdvancedIndexing ['DocVec' ]],
41
+ tensor_columns : Dict [str , Optional [ AbstractTensor ] ],
42
+ doc_columns : Dict [str , Optional [ 'DocVec' ] ],
43
+ docs_vec_columns : Dict [str , Optional [ ListAdvancedIndexing ['DocVec' ] ]],
43
44
any_columns : Dict [str , ListAdvancedIndexing ],
44
45
tensor_type : Type [AbstractTensor ] = NdArray ,
45
46
):
@@ -63,12 +64,22 @@ def __len__(self) -> int:
63
64
def __getitem__ (self : T , item : IndexIterType ) -> T :
64
65
if isinstance (item , tuple ):
65
66
item = list (item )
66
- tensor_columns = {key : col [item ] for key , col in self .tensor_columns .items ()}
67
- doc_columns = {key : col [item ] for key , col in self .doc_columns .items ()}
67
+ tensor_columns = {
68
+ key : col [item ] if col is not None else None
69
+ for key , col in self .tensor_columns .items ()
70
+ }
71
+ doc_columns = {
72
+ key : col [item ] if col is not None else None
73
+ for key , col in self .doc_columns .items ()
74
+ }
68
75
docs_vec_columns = {
69
- key : col [item ] for key , col in self .docs_vec_columns .items ()
76
+ key : col [item ] if col is not None else None
77
+ for key , col in self .docs_vec_columns .items ()
78
+ }
79
+ any_columns = {
80
+ key : col [item ] if col is not None else None
81
+ for key , col in self .any_columns .items ()
70
82
}
71
- any_columns = {key : col [item ] for key , col in self .any_columns .items ()}
72
83
73
84
return self .__class__ (
74
85
tensor_columns ,
@@ -91,15 +102,34 @@ def __init__(self, index: int, storage: ColumnStorage):
91
102
def __getitem__ (self , name : str ) -> Any :
92
103
if name in self .storage .tensor_columns .keys ():
93
104
tensor = self .storage .tensor_columns [name ]
105
+ if tensor is None :
106
+ return None
94
107
if tensor .get_comp_backend ().n_dim (tensor ) == 1 :
95
108
# to ensure consistensy between numpy and pytorch
96
109
# we wrap the scalr in a tensor of ndim = 1
97
110
# otherwise numpy pass by value whereas torch by reference
98
- return self .storage .tensor_columns [name ][ self . index : self . index + 1 ]
111
+ col = self .storage .tensor_columns [name ]
99
112
100
- return self .storage .columns [name ][self .index ]
113
+ if col is not None :
114
+ return col [self .index : self .index + 1 ]
115
+ else :
116
+ return None
117
+
118
+ col = self .storage .columns [name ]
119
+
120
+ if col is None :
121
+ return None
122
+ return col [self .index ]
101
123
102
124
def __setitem__ (self , name , value ) -> None :
125
+ if self .storage .columns [name ] is None :
126
+ raise ValueError (
127
+ f'Cannot set an item to a None column. This mean that '
128
+ f'the DocVec that encapsulate this doc has the field '
129
+ f'{ name } set to None. If you want to modify that you need to do it at the'
130
+ f'DocVec level. `docs.field = np.zeros(10)`'
131
+ )
132
+
103
133
self .storage .columns [name ][self .index ] = value
104
134
105
135
def __delitem__ (self , key ):
0 commit comments