# test_common.py (1.68 KB)
"""Test loaders for common functionality.
"""
import pytest
import numpy as np


def check_pandas_dependency_message(fetch_func):
    """Check the error raised by ``fetch_func`` when pandas is missing.

    The test is skipped whenever pandas can be imported. Otherwise,
    ``fetch_func(as_frame=True)`` must raise an ``ImportError`` whose
    message matches ``"<name> with as_frame=True requires pandas"``, where
    ``<name>`` is ``fetch_func.__name__``.

    Parameters
    ----------
    fetch_func : callable
        Callable accepting an ``as_frame`` keyword argument and exposing a
        ``__name__`` attribute.
    """
    try:
        import pandas  # noqa
    except ImportError:
        pandas_missing = True
    else:
        pandas_missing = False

    if not pandas_missing:
        pytest.skip("This test requires pandas to be not installed")

    # pandas is unavailable: requesting a frame must fail with an
    # informative error message that names the fetcher.
    fetcher_name = fetch_func.__name__
    expected_msg = '{} with as_frame=True requires pandas'.format(
        fetcher_name)
    with pytest.raises(ImportError, match=expected_msg):
        fetch_func(as_frame=True)


def check_return_X_y(bunch, fetch_func_partial):
    """Check the ``return_X_y=True`` output of a fetcher against ``bunch``.

    Parameters
    ----------
    bunch : object
        Reference object exposing ``data`` and ``target`` attributes, each
        with a ``shape``.
    fetch_func_partial : callable
        Callable accepting a ``return_X_y`` keyword argument.

    Raises
    ------
    AssertionError
        If the returned value is not a tuple, or if the shape of its first
        (resp. second) item differs from ``bunch.data.shape``
        (resp. ``bunch.target.shape``).
    """
    returned = fetch_func_partial(return_X_y=True)
    assert isinstance(returned, tuple)

    # Items are looked up one at a time so that the data check runs before
    # the second item is accessed.
    X = returned[0]
    assert X.shape == bunch.data.shape
    y = returned[1]
    assert y.shape == bunch.target.shape


def check_as_frame(bunch, fetch_func_partial,
                   expected_data_dtype=None, expected_target_dtype=None):
    """Check the ``as_frame=True`` output of a fetcher against ``bunch``.

    The test is skipped when pandas cannot be imported.

    Parameters
    ----------
    bunch : object
        Reference object exposing ``data`` and ``target`` attributes, each
        with a ``shape``.
    fetch_func_partial : callable
        Callable accepting an ``as_frame`` keyword argument and returning
        an object with ``frame``, ``data`` and ``target`` attributes.
    expected_data_dtype : object, default=None
        When not None, every entry of the returned ``data.dtypes`` must
        compare equal to it.
    expected_target_dtype : object, default=None
        When not None, the returned ``target.dtypes`` must compare equal
        to it.
    """
    pandas = pytest.importorskip('pandas')
    result = fetch_func_partial(as_frame=True)

    # The full frame must be present and be a DataFrame.
    assert hasattr(result, 'frame')
    assert isinstance(result.frame, pandas.DataFrame)

    # Data is always a DataFrame with the same shape as the reference.
    data = result.data
    assert isinstance(data, pandas.DataFrame)
    assert data.shape == bunch.data.shape

    # A target with more than one dimension must be a DataFrame, otherwise
    # a Series; only the number of rows is compared with the reference.
    target = result.target
    target_type = pandas.DataFrame if target.ndim > 1 else pandas.Series
    assert isinstance(target, target_type)
    assert target.shape[0] == bunch.target.shape[0]

    # The dtype checks are optional.
    if expected_data_dtype is not None:
        assert np.all(data.dtypes == expected_data_dtype)
    if expected_target_dtype is not None:
        assert np.all(target.dtypes == expected_target_dtype)