Skip to content

Commit 9707a0f

Browse files
committed
Categorical mapping via units on norm
1 parent 624991c commit 9707a0f

File tree

4 files changed

+87
-84
lines changed

4 files changed

+87
-84
lines changed

build_alllocal.cmd

Lines changed: 0 additions & 36 deletions
This file was deleted.

lib/matplotlib/category.py

Lines changed: 58 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,11 @@ class StrCategoryConverter(munits.ConversionInterface):
3838
3939
Conversion typically happens in the following order:
4040
1. default_units:
41-
creates unit_data category-integer mapping and binds to axis
41+
create unit_data category-integer mapping and bind it to axis
4242
2. axis_info:
43-
sets ticks/locator and label/formatter
43+
set ticks/locator and labels/formatter
4444
3. convert:
45-
maps input category data to integers using unit_data
46-
45+
map input category data to integers using unit_data
4746
"""
4847
@staticmethod
4948
def convert(value, unit, axis):
@@ -53,13 +52,13 @@ def convert(value, unit, axis):
5352
vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs))
5453

5554
if isinstance(value, six.string_types):
56-
return vmap[value]
55+
return vmap.get(value, None)
5756

5857
vals = to_array(value)
5958
for lab, loc in vmap.items():
6059
vals[vals == lab] = loc
6160

62-
return vals.astype('float')
61+
return vals.astype('float64')
6362

6463
@staticmethod
6564
def axisinfo(unit, axis):
@@ -74,16 +73,19 @@ def axisinfo(unit, axis):
7473
return munits.AxisInfo(majloc=majloc, majfmt=majfmt)
7574

7675
@staticmethod
77-
def default_units(data, axis):
76+
def default_units(data, axis, sort=True, normed=False):
7877
"""
7978
Create mapping between string categories in *data*
80-
and integers, then store in *axis.unit_data*
79+
and integers, and store in *axis.unit_data*
8180
"""
82-
if axis.unit_data is None:
83-
axis.unit_data = UnitData(data)
84-
else:
85-
axis.unit_data.update(data)
86-
return None
81+
if axis and axis.unit_data:
82+
axis.unit_data.update(data, sort)
83+
return axis.unit_data
84+
85+
unit_data = UnitData(data, sort)
86+
if axis:
87+
axis.unit_data = unit_data
88+
return unit_data
8789

8890

8991
class StrCategoryLocator(mticker.FixedLocator):
@@ -110,35 +112,37 @@ class CategoryNorm(mcolors.Normalize):
110112
"""
111113
Preserves ordering of discrete values
112114
"""
113-
def __init__(self, categories):
115+
def __init__(self, data):
114116
"""
115117
*data*
116118
distinct values for mapping
117119
118-
Out-of-range values are mapped to a value not in categories;
119-
these are then converted to valid indices by :meth:`Colormap.__call__`.
120+
Out-of-range values are mapped to np.nan
120121
"""
121-
self.categories = categories
122-
self.N = len(self.categories)
123-
self.vmin = 0
124-
self.vmax = self.N
125-
self._interp = False
126-
127-
def __call__(self, value, clip=None):
128-
if not cbook.iterable(value):
129-
value = [value]
130-
131-
value = np.asarray(value)
132-
ret = np.ones(value.shape) * np.nan
133122

134-
for i, c in enumerate(self.categories):
135-
ret[value == c] = i / (self.N * 1.0)
123+
self.units = StrCategoryConverter()
124+
self.unit_data = None
125+
self.units.default_units(data,
126+
self, sort=False)
127+
self.loc2seq = dict(zip(self.unit_data.locs, self.unit_data.seq))
128+
self.vmin = min(self.unit_data.locs)
129+
self.vmax = max(self.unit_data.locs)
136130

137-
return np.ma.array(ret, mask=np.isnan(ret))
131+
def __call__(self, value, clip=None):
132+
# gonna have to go into imshow and undo casting
133+
value = np.asarray(value, dtype=np.int)
134+
ret = self.units.convert(value, None, self)
135+
# knock out values not in the norm
136+
mask = np.in1d(ret, self.unit_data.locs).reshape(ret.shape)
137+
# normalize ret & locs
138+
ret /= self.vmax
139+
return np.ma.array(ret, mask=~mask)
138140

139141
def inverse(self, value):
140-
# not quite sure what invertible means in this context
141-
return ValueError("CategoryNorm is not invertible")
142+
if not cbook.iterable(value):
143+
value = np.asarray(value)
144+
vscaled = np.asarray(value) * self.vmax
145+
return [self.loc2seq[int(vs)] for vs in vscaled]
142146

143147

144148
def colors_from_categories(codings):
@@ -156,8 +160,7 @@ def colors_from_categories(codings):
156160
:class:`Normalize` instance
157161
"""
158162
if isinstance(codings, dict):
159-
codings = codings.items()
160-
163+
codings = cbook.sanitize_sequence(codings.items())
161164
values, colors = zip(*codings)
162165
cmap = mcolors.ListedColormap(list(colors))
163166
norm = CategoryNorm(list(values))
@@ -184,30 +187,43 @@ def convert_to_string(value):
184187

185188

186189
class UnitData(object):
187-
# debatable makes sense to special code missing values
190+
# debatable if it makes sense to special code missing values
188191
spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0}
189192

190-
def __init__(self, data):
193+
def __init__(self, data, sort=True):
191194
"""Create mapping between unique categorical values
192195
and numerical identifier
193196
Parameters
194197
---------
195198
data: iterable
196199
sequence of values
200+
sort: bool
201+
sort input data, default is True
202+
False preserves input order
197203
"""
198204
self.seq, self.locs = [], []
199-
self._set_seq_locs(data, 0)
205+
self._set_seq_locs(data, 0, sort)
206+
self.sort = sort
200207

201-
def update(self, new_data):
208+
def update(self, new_data, sort=True):
209+
if sort:
210+
self.sort = sort
202211
# so as not to conflict with spdict
203212
value = max(max(self.locs) + 1, 0)
204-
self._set_seq_locs(new_data, value)
213+
self._set_seq_locs(new_data, value, self.sort)
205214

206-
def _set_seq_locs(self, data, value):
215+
def _set_seq_locs(self, data, value, sort):
207216
# magic to make it work under np1.6
208217
strdata = to_array(data)
218+
209219
# np.unique makes dataframes work
210-
new_s = [d for d in np.unique(strdata) if d not in self.seq]
220+
if sort:
221+
unq = np.unique(strdata)
222+
else:
223+
_, idx = np.unique(strdata, return_index=~sort)
224+
unq = strdata[np.sort(idx)]
225+
226+
new_s = [d for d in unq if d not in self.seq]
211227
for ns in new_s:
212228
self.seq.append(convert_to_string(ns))
213229
if ns in UnitData.spdict.keys():

lib/matplotlib/colorbar.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import matplotlib as mpl
3232
import matplotlib.artist as martist
33+
import matplotlib.category as category
3334
import matplotlib.cbook as cbook
3435
import matplotlib.collections as collections
3536
import matplotlib.colors as colors
@@ -312,6 +313,8 @@ def __init__(self, ax, cmap=None,
312313
if format is None:
313314
if isinstance(self.norm, colors.LogNorm):
314315
self.formatter = ticker.LogFormatterMathtext()
316+
elif isinstance(self.norm, category.CategoryNorm):
317+
self.formatter = ticker.FixedFormatter(self.norm.unit_data.seq)
315318
else:
316319
self.formatter = ticker.ScalarFormatter()
317320
elif cbook.is_string_like(format):
@@ -580,6 +583,8 @@ def _ticker(self):
580583
locator = ticker.FixedLocator(b, nbins=10)
581584
elif isinstance(self.norm, colors.LogNorm):
582585
locator = ticker.LogLocator()
586+
elif isinstance(self.norm, category.CategoryNorm):
587+
locator = ticker.FixedLocator(self.norm.unit_data.locs)
583588
else:
584589
if mpl.rcParams['_internal.classic_mode']:
585590
locator = ticker.MaxNLocator()

lib/matplotlib/tests/test_category.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_axisinfo(self):
106106

107107
def test_default_units(self):
108108
axis = FakeAxis(None)
109-
assert self.cc.default_units(["a"], axis) is None
109+
assert isinstance(self.cc.default_units(["a"], axis), cat.UnitData)
110110

111111

112112
class TestStrCategoryLocator(object):
@@ -129,17 +129,35 @@ def test_StrCategoryFormatterUnicode(self):
129129

130130

131131
class TestCategoryNorm(object):
132-
testdata = [[[205, 302, 205, 101], [0, 2. / 3., 0, 1. / 3.]],
133-
[[205, np.nan, 101, 305], [0, 9999, 1. / 3., 2. / 3.]],
134-
[[205, 101, 504, 101], [0, 9999, 1. / 3., 1. / 3.]]]
132+
testdata = [[[205, 302, 205, 101], [0, 1, 0, .5]],
133+
[[205, np.nan, 101, 305], [0, np.nan, .5, 1]],
134+
[[205, 101, 504, 101], [0, .5, np.nan, .5]]]
135135

136136
ids = ["regular", "nan", "exclude"]
137137

138138
@pytest.mark.parametrize("data, nmap", testdata, ids=ids)
139139
def test_norm(self, data, nmap):
140140
norm = cat.CategoryNorm([205, 101, 302])
141-
test = np.ma.masked_equal(nmap, 9999)
142-
np.testing.assert_allclose(norm(data), test)
141+
masked_nmap = np.ma.masked_equal(nmap, np.nan)
142+
assert np.ma.allequal(norm(data), masked_nmap)
143+
144+
def test_invert(self):
145+
data = [205, 302, 101]
146+
strdata = ['205', '302', '101']
147+
value = [0, .5, 1]
148+
norm = cat.CategoryNorm(data)
149+
assert norm.inverse(value) == strdata
150+
151+
152+
class TestColorsFromCategories(object):
153+
testdata = [[{'101': "blue", '205': "red", '302': "green"}, dict],
154+
[[('205', "red"), ('101', "blue"), ('302', "green")], list]]
155+
ids = ["dict", "tuple"]
156+
157+
@pytest.mark.parametrize("codings, mtype", testdata, ids=ids)
158+
def test_colors_from_categories(self, codings, mtype):
159+
cmap, norm = cat.colors_from_categories(codings)
160+
assert mtype(zip(norm.unit_data.seq, cmap.colors)) == codings
143161

144162

145163
def lt(tl):

0 commit comments

Comments
 (0)