Skip to content

[WIP] Categorical Color Mapping #6934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
CategoryNorm
  • Loading branch information
story645 committed Aug 17, 2016
commit 624991c1475111bf760ab2cb9b6e3da99e30b8dd
113 changes: 99 additions & 14 deletions lib/matplotlib/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

import numpy as np

import matplotlib.colors as mcolors
import matplotlib.cbook as cbook
import matplotlib.units as units
import matplotlib.ticker as ticker

import matplotlib.units as munits
import matplotlib.ticker as mticker

# pure hack for numpy 1.6 support
from distutils.version import LooseVersion
Expand All @@ -33,11 +33,22 @@ def to_array(data, maxlen=100):
return vals


class StrCategoryConverter(units.ConversionInterface):
class StrCategoryConverter(munits.ConversionInterface):
"""Converts categorical (or string) data to numerical values

Conversion typically happens in the following order:
1. default_units:
creates unit_data category-integer mapping and binds to axis
2. axis_info:
sets ticks/locator and label/formatter
3. convert:
maps input category data to integers using unit_data

"""
@staticmethod
def convert(value, unit, axis):
"""Uses axis.unit_data map to encode
data as floats
"""
Encode value as floats using axis.unit_data
"""
vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs))

Expand All @@ -52,33 +63,107 @@ def convert(value, unit, axis):

@staticmethod
def axisinfo(unit, axis):
    """
    Build the :class:`~matplotlib.units.AxisInfo` for a categorical axis.

    *unit* is not used by this method.
    *axis.unit_data* supplies the tick positions (``locs``) for the
    major locator and the category labels (``seq``) for the major
    formatter.
    """
    unit_data = axis.unit_data
    return munits.AxisInfo(
        majloc=StrCategoryLocator(unit_data.locs),
        majfmt=StrCategoryFormatter(unit_data.seq))

@staticmethod
def default_units(data, axis):
    """
    Record the string categories found in *data* on *axis*.

    If *axis.unit_data* is None, a new ``UnitData(data)`` is created
    and stored there; otherwise the existing object is extended via
    ``axis.unit_data.update(data)``.

    Always returns None.
    """
    existing = axis.unit_data
    if existing is not None:
        existing.update(data)
        return None
    axis.unit_data = UnitData(data)
    return None


class StrCategoryLocator(ticker.FixedLocator):
class StrCategoryLocator(mticker.FixedLocator):
    """
    Tick locator for categorical axes.

    Subclasses :class:`~matplotlib.ticker.FixedLocator` so that every
    category has a tick.
    """
    def __init__(self, locs):
        # Store the positions exactly as given; nbins is left unset.
        self.nbins = None
        self.locs = locs


class StrCategoryFormatter(ticker.FixedFormatter):
class StrCategoryFormatter(mticker.FixedFormatter):
    """
    Tick formatter for categorical axes.

    Subclasses :class:`~matplotlib.ticker.FixedFormatter` so that every
    category is labelled.
    """
    def __init__(self, seq):
        # seq holds the category labels; the offset string starts empty.
        self.offset_string = ''
        self.seq = seq


class CategoryNorm(mcolors.Normalize):
    """
    Map discrete category values to floats while preserving their order.

    The i-th entry of *categories* is mapped to ``i / N``, where ``N`` is
    the number of categories. Values that match no category are masked
    in the result.
    """
    def __init__(self, categories):
        """
        Parameters
        ----------
        categories : sequence
            Distinct values for mapping; the position of a value in this
            sequence determines the float it is mapped to.
        """
        self.categories = categories
        self.N = len(self.categories)
        self.vmin = 0
        self.vmax = self.N
        # NOTE(review): _interp is not read anywhere in this class;
        # presumably it is consulted elsewhere -- confirm.
        self._interp = False

    def __call__(self, value, clip=None):
        """
        Normalize *value* to ``i / N`` for each matching category.

        Parameters
        ----------
        value : scalar or array-like
            Data to normalize; a scalar is wrapped in a one-element list.
        clip : ignored
            Accepted for signature compatibility; not used here.

        Returns
        -------
        numpy.ma.MaskedArray
            Same shape as *value*; entries that match no category
            (including NaN) are masked.
        """
        if not cbook.iterable(value):
            value = [value]

        value = np.asarray(value)
        # Start from all-NaN so unmatched entries can be masked below.
        ret = np.ones(value.shape) * np.nan

        for i, c in enumerate(self.categories):
            # Multiply by 1.0 to force float division.
            ret[value == c] = i / (self.N * 1.0)

        return np.ma.array(ret, mask=np.isnan(ret))

    def inverse(self, value):
        """
        Not supported.

        Raises
        ------
        ValueError
            Always; a category norm has no inverse defined here.
        """
        # Raise the error: returning the exception instance would hand
        # callers an exception object as if it were a result.
        raise ValueError("CategoryNorm is not invertible")


def colors_from_categories(codings):
    """
    Build a colormap and a matching norm from category/color pairs.

    Parameters
    ----------
    codings : dict or sequence of (value, color) pairs
        Each pair gives a category value followed by its color. A dict
        is read through its ``items()``, i.e. keys are the category
        values and dict values are the colors.

    Returns
    -------
    (cmap, norm) : tuple containing a :class:`ListedColormap` built from
        the colors and a :class:`CategoryNorm` built from the values
    """
    pairs = codings.items() if isinstance(codings, dict) else codings

    # Transpose the pairs into one tuple of values and one of colors.
    categories, palette = zip(*pairs)
    norm = CategoryNorm(list(categories))
    cmap = mcolors.ListedColormap(list(palette))
    return cmap, norm


def convert_to_string(value):
"""Helper function for numpy 1.6, can be replaced with
np.array(...,dtype=unicode) for all later versions of numpy"""
Expand Down Expand Up @@ -132,6 +217,6 @@ def _set_seq_locs(self, data, value):
value += 1

# Connects the converter to matplotlib
units.registry[str] = StrCategoryConverter()
units.registry[bytes] = StrCategoryConverter()
units.registry[six.text_type] = StrCategoryConverter()
munits.registry[str] = StrCategoryConverter()
munits.registry[bytes] = StrCategoryConverter()
munits.registry[six.text_type] = StrCategoryConverter()
14 changes: 14 additions & 0 deletions lib/matplotlib/tests/test_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ def test_StrCategoryFormatterUnicode(self):
assert labels('a', 1) == "привет"


class TestCategoryNorm(object):
    """
    Check that CategoryNorm maps each category to ``index / N`` and masks
    every value that is not a category.
    """
    # Each entry is [input data, expected output]. 9999 is a sentinel for
    # positions that must be masked. With categories [205, 101, 302]:
    # 205 -> 0, 101 -> 1/3, 302 -> 2/3; NaN, 305 and 504 are not
    # categories and therefore masked.
    testdata = [[[205, 302, 205, 101], [0, 2. / 3., 0, 1. / 3.]],
                [[205, np.nan, 101, 305], [0, 9999, 1. / 3., 9999]],
                [[205, 101, 504, 101], [0, 1. / 3., 9999, 1. / 3.]]]

    ids = ["regular", "nan", "exclude"]

    @pytest.mark.parametrize("data, nmap", testdata, ids=ids)
    def test_norm(self, data, nmap):
        norm = cat.CategoryNorm([205, 101, 302])
        expected = np.ma.masked_equal(nmap, 9999)
        result = norm(data)
        # Compare the masks explicitly so that a value masked in only one
        # of the two arrays cannot go unnoticed.
        np.testing.assert_array_equal(np.ma.getmaskarray(result),
                                      np.ma.getmaskarray(expected))
        # Then compare the unmasked values.
        np.testing.assert_allclose(result.compressed(),
                                   expected.compressed())


def lt(tl):
    """Return the text of every label in *tl*, in order, as a list."""
    texts = []
    for label in tl:
        texts.append(label.get_text())
    return texts

Expand Down