@@ -38,12 +38,11 @@ class StrCategoryConverter(munits.ConversionInterface):
38
38
39
39
Conversion typically happens in the following order:
40
40
1. default_units:
41
- creates unit_data category-integer mapping and binds to axis
41
+ create unit_data category-integer mapping and bind it to the axis
42
42
2. axis_info:
43
- sets ticks/locator and label /formatter
43
+ set ticks/locator and labels/formatter
44
44
3. convert:
45
- maps input category data to integers using unit_data
46
-
45
+ map input category data to integers using unit_data
47
46
"""
48
47
@staticmethod
49
48
def convert (value , unit , axis ):
@@ -53,13 +52,13 @@ def convert(value, unit, axis):
53
52
vmap = dict (zip (axis .unit_data .seq , axis .unit_data .locs ))
54
53
55
54
if isinstance (value , six .string_types ):
56
- return vmap [ value ]
55
+ return vmap . get ( value , None )
57
56
58
57
vals = to_array (value )
59
58
for lab , loc in vmap .items ():
60
59
vals [vals == lab ] = loc
61
60
62
- return vals .astype ('float ' )
61
+ return vals .astype ('float64 ' )
63
62
64
63
@staticmethod
65
64
def axisinfo (unit , axis ):
@@ -74,16 +73,19 @@ def axisinfo(unit, axis):
74
73
return munits .AxisInfo (majloc = majloc , majfmt = majfmt )
75
74
76
75
@staticmethod
def default_units(data, axis, sort=True, normed=False):
    """
    Create mapping between string categories in *data*
    and integers, and store in *axis.unit_data*

    Parameters
    ----------
    data : iterable
        sequence of category values
    axis : object or None
        object carrying a ``unit_data`` attribute; when that attribute
        already holds a mapping, the mapping is extended in place
    sort : bool
        forwarded to ``UnitData`` / ``UnitData.update``
    normed : bool
        accepted in the signature but not used in this function

    Returns
    -------
    The mapping object that was updated or newly created.
    """
    # Compare against None explicitly: a truthiness test would treat a
    # falsy-but-valid axis or mapping object as "missing" and silently
    # replace the existing mapping.
    existing = None if axis is None else axis.unit_data
    if existing is not None:
        existing.update(data, sort)
        return existing

    unit_data = UnitData(data, sort)
    if axis is not None:
        axis.unit_data = unit_data
    return unit_data
87
89
88
90
89
91
class StrCategoryLocator (mticker .FixedLocator ):
@@ -110,35 +112,37 @@ class CategoryNorm(mcolors.Normalize):
110
112
"""
111
113
Preserves ordering of discrete values
112
114
"""
113
def __init__(self, data):
    """
    *data*
        distinct values for mapping; input order is preserved
        (the mapping is built with ``sort=False``)

    Sets ``loc2seq`` (location -> category lookup) and takes
    ``vmin`` / ``vmax`` from the smallest / largest location.
    """
    self.units = StrCategoryConverter()
    self.unit_data = None
    # default_units stores the mapping on the object passed as "axis",
    # which here is this norm itself
    self.units.default_units(data, self, sort=False)

    locations = self.unit_data.locs
    self.loc2seq = dict(zip(locations, self.unit_data.seq))
    self.vmin = min(locations)
    self.vmax = max(locations)
136
130
137
- return np .ma .array (ret , mask = np .isnan (ret ))
131
def __call__(self, value, clip=None):
    """
    Scale category locations in *value* by ``self.vmax``.

    Parameters
    ----------
    value : scalar or array-like
        values that are cast to integers before conversion
    clip : ignored
        present in the signature but not used by this method

    Returns
    -------
    numpy masked array of the scaled locations; entries whose converted
    value is not one of ``self.unit_data.locs`` are masked.
    """
    # gonna have to go into imshow and undo casting
    # builtin int: the np.int alias was removed in NumPy 1.24
    value = np.asarray(value, dtype=int)
    ret = np.asarray(self.units.convert(value, None, self), dtype=float)

    # knock out values not in the norm
    locs = np.asarray(self.unit_data.locs)
    isin = getattr(np, "isin", None)
    if isin is None:
        # older NumPy releases only provide in1d, which flattens
        mask = np.in1d(ret, locs).reshape(ret.shape)
    else:
        mask = isin(ret, locs)

    # normalize ret & locs; with a single category vmax is 0 and the
    # only valid location is already 0, so skip the division by zero
    if self.vmax:
        ret = ret / self.vmax
    return np.ma.array(ret, mask=~mask)
138
140
139
141
def inverse(self, value):
    """
    Map scaled values back to their categories.

    Parameters
    ----------
    value : scalar or 1-D sequence
        values as produced by scaling locations by ``self.vmax``

    Returns
    -------
    list with one ``self.loc2seq`` entry per input value (a scalar
    input yields a one-element list).

    Raises
    ------
    KeyError
        if a rescaled value is not a key of ``self.loc2seq``
    """
    # atleast_1d makes a bare scalar iterable; a 0-d array is not
    vscaled = np.atleast_1d(np.asarray(value, dtype=float)) * self.vmax
    # round before converting: float error can leave e.g.
    # 28.999999999999996, which plain int() would truncate to 28
    return [self.loc2seq[int(round(vs))] for vs in vscaled]
142
146
143
147
144
148
def colors_from_categories (codings ):
@@ -156,8 +160,7 @@ def colors_from_categories(codings):
156
160
:class:`Normalize` instance
157
161
"""
158
162
if isinstance (codings , dict ):
159
- codings = codings .items ()
160
-
163
+ codings = cbook .sanitize_sequence (codings .items ())
161
164
values , colors = zip (* codings )
162
165
cmap = mcolors .ListedColormap (list (colors ))
163
166
norm = CategoryNorm (list (values ))
@@ -184,30 +187,43 @@ def convert_to_string(value):
184
187
185
188
186
189
class UnitData (object ):
187
- # debatable makes sense to special code missing values
190
+ # debatable whether it makes sense to special-case missing values
188
191
spdict = {'nan' : - 1.0 , 'inf' : - 2.0 , '-inf' : - 3.0 }
189
192
190
def __init__(self, data, sort=True):
    """Create mapping between unique categorical values
    and numerical identifier

    Parameters
    ----------
    data: iterable
        sequence of values
    sort: bool
        sort input data, default is True
        False preserves input order
    """
    # start from an empty mapping, then number the categories from 0
    self.seq = []
    self.locs = []
    self._set_seq_locs(data, 0, sort)
    # remembered so later update() calls can reuse the setting
    self.sort = sort
200
207
201
def update(self, new_data, sort=True):
    """Extend the mapping with the categories in *new_data*.

    Parameters
    ----------
    new_data: iterable
        sequence of values to add
    sort: bool
        a truthy value switches sorting on for this and later
        updates; a falsy value leaves the stored ``self.sort``
        setting unchanged
    """
    if sort:
        self.sort = sort
    # an empty mapping has no maximum, so fall back to -1 and let
    # numbering start at 0 instead of raising ValueError
    highest = max(self.locs) if self.locs else -1
    # so as not to conflict with spdict, whose codes are negative
    value = max(highest + 1, 0)
    self._set_seq_locs(new_data, value, self.sort)
205
214
206
- def _set_seq_locs (self , data , value ):
215
+ def _set_seq_locs (self , data , value , sort ):
207
216
# magic to make it work under np1.6
208
217
strdata = to_array (data )
218
+
209
219
# np.unique makes dataframes work
210
- new_s = [d for d in np .unique (strdata ) if d not in self .seq ]
220
+ if sort :
221
+ unq = np .unique (strdata )
222
+ else :
223
+ _ , idx = np .unique (strdata , return_index = ~ sort )
224
+ unq = strdata [np .sort (idx )]
225
+
226
+ new_s = [d for d in unq if d not in self .seq ]
211
227
for ns in new_s :
212
228
self .seq .append (convert_to_string (ns ))
213
229
if ns in UnitData .spdict .keys ():
0 commit comments