Skip to content

Representation Quality

reliably.repr.mig.mig(z, factors, *, n_bins=20, ci='bca', n_bootstrap=2000, level=0.95, seed=0)

Mutual Information Gap (MIG) disentanglement metric.

Parameters:

Name Type Description Default
z array-like

Latent codes, shape (N, D).

required
factors array-like

Ground-truth generative factors, shape (N, K).

required
n_bins int

Bins for MI estimation.

20
ci str | None

CI method.

'bca'
n_bootstrap int

Bootstrap resamples.

2000
level float

Nominal coverage.

0.95
seed int

RNG seed.

0

Returns:

Type Description
MetricResult

Named "MIG".

Examples:

>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> z = rng.normal(0, 1, (200, 4))
>>> f = rng.normal(0, 1, (200, 3))
>>> r = mig(z, f, ci=None)
>>> 0.0 <= r.value <= 1.0
True
Source code in src/reliably/repr/mig.py
def mig(
    z: Any,
    factors: Any,
    *,
    n_bins: int = 20,
    ci: str | None = "bca",
    n_bootstrap: int = 2000,
    level: float = 0.95,
    seed: int = 0,
) -> MetricResult:
    """Mutual Information Gap (MIG) disentanglement metric.

    The point estimate is computed once on the full data. When ``ci`` is not
    ``None``, a bootstrap confidence interval is added by re-evaluating the
    metric on row-resampled copies of ``z`` and ``factors``.

    Parameters
    ----------
    z : array-like
        Latent codes, shape ``(N, D)``. A 1-D input is treated as ``(N, 1)``.
    factors : array-like
        Ground-truth generative factors, shape ``(N, K)``. A 1-D input is
        treated as ``(N, 1)``.
    n_bins : int
        Bins for MI estimation.
    ci : str | None
        Bootstrap CI method, forwarded to ``bootstrap_ci`` as ``method``.
        ``None`` skips the interval entirely.
    n_bootstrap : int
        Bootstrap resamples.
    level : float
        Nominal coverage.
    seed : int
        RNG seed for the bootstrap resampling.

    Returns
    -------
    MetricResult
        Named ``"MIG"``; ``n`` is the number of rows ``N``.

    Raises
    ------
    ValueError
        If ``z`` and ``factors`` do not have the same number of rows.

    Examples
    --------
    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> z = rng.normal(0, 1, (200, 4))
    >>> f = rng.normal(0, 1, (200, 3))
    >>> r = mig(z, f, ci=None)
    >>> 0.0 <= r.value <= 1.0
    True
    """
    z_np = to_numpy(z, dtype=np.float64)
    f_np = to_numpy(factors, dtype=np.float64)
    if z_np.ndim == 1:
        z_np = z_np[:, None]
    if f_np.ndim == 1:
        f_np = f_np[:, None]
    n = z_np.shape[0]
    # Rows of z and factors must describe the same samples; the bootstrap
    # below indexes both arrays with the same resampled row indices.
    if f_np.shape[0] != n:
        raise ValueError(
            "z and factors must have the same number of rows; "
            f"got {n} and {f_np.shape[0]}"
        )

    point = _mig_from_arrays(z_np, f_np, n_bins)

    if ci is None:
        return MetricResult(name="MIG", value=point, ci=None, n=n)

    def _est(idx: NDArray[np.intp]) -> float:
        # Re-evaluate MIG on the resampled rows (paired resampling of z/factors).
        return _mig_from_arrays(z_np[idx], f_np[idx], n_bins)

    ci_result = bootstrap_ci(_est, n, point=point, n_boot=n_bootstrap, level=level,
                             method=ci, seed=seed)
    return MetricResult(name="MIG", value=point, ci=ci_result, n=n)

reliably.repr.sap.sap(z, factors, *, ci='bca', n_bootstrap=2000, level=0.95, seed=0)

Separated Attribute Predictability (SAP).

Parameters:

Name Type Description Default
z array-like

Latent codes, shape (N, D).

required
factors array-like

Ground-truth factors, shape (N, K).

required
ci str | None

CI method.

'bca'
n_bootstrap int

Bootstrap resamples.

2000
level float

Nominal coverage.

0.95
seed int

RNG seed.

0

Returns:

Type Description
MetricResult

Named "SAP".

Examples:

>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> z = rng.normal(0, 1, (200, 4))
>>> f = rng.normal(0, 1, (200, 3))
>>> r = sap(z, f, ci=None)
>>> r.value >= 0.0
True
Source code in src/reliably/repr/sap.py
def sap(
    z: Any,
    factors: Any,
    *,
    ci: str | None = "bca",
    n_bootstrap: int = 2000,
    level: float = 0.95,
    seed: int = 0,
) -> MetricResult:
    """Separated Attribute Predictability (SAP).

    The point estimate is computed once on the full data. When ``ci`` is not
    ``None``, a bootstrap confidence interval is added by re-evaluating the
    metric on row-resampled copies of ``z`` and ``factors``.

    Parameters
    ----------
    z : array-like
        Latent codes, shape ``(N, D)``. A 1-D input is treated as ``(N, 1)``.
    factors : array-like
        Ground-truth factors, shape ``(N, K)``. A 1-D input is treated as
        ``(N, 1)``.
    ci : str | None
        Bootstrap CI method, forwarded to ``bootstrap_ci`` as ``method``.
        ``None`` skips the interval entirely.
    n_bootstrap : int
        Bootstrap resamples.
    level : float
        Nominal coverage.
    seed : int
        RNG seed for the bootstrap resampling.

    Returns
    -------
    MetricResult
        Named ``"SAP"``; ``n`` is the number of rows ``N``.

    Raises
    ------
    ValueError
        If ``z`` and ``factors`` do not have the same number of rows.

    Examples
    --------
    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> z = rng.normal(0, 1, (200, 4))
    >>> f = rng.normal(0, 1, (200, 3))
    >>> r = sap(z, f, ci=None)
    >>> r.value >= 0.0
    True
    """
    z_np = to_numpy(z, dtype=np.float64)
    f_np = to_numpy(factors, dtype=np.float64)
    if z_np.ndim == 1:
        z_np = z_np[:, None]
    if f_np.ndim == 1:
        f_np = f_np[:, None]
    n = z_np.shape[0]
    # Rows of z and factors must describe the same samples; the bootstrap
    # below indexes both arrays with the same resampled row indices.
    if f_np.shape[0] != n:
        raise ValueError(
            "z and factors must have the same number of rows; "
            f"got {n} and {f_np.shape[0]}"
        )

    point = _sap_from_arrays(z_np, f_np)

    if ci is None:
        return MetricResult(name="SAP", value=point, ci=None, n=n)

    def _est(idx: NDArray[np.intp]) -> float:
        # Re-evaluate SAP on the resampled rows (paired resampling of z/factors).
        return _sap_from_arrays(z_np[idx], f_np[idx])

    ci_result = bootstrap_ci(_est, n, point=point, n_boot=n_bootstrap,
                             level=level, method=ci, seed=seed)
    return MetricResult(name="SAP", value=point, ci=ci_result, n=n)

reliably.repr.dci.dci(z, factors, *, ci='bca', n_bootstrap=200, level=0.95, seed=0)

DCI: Disentanglement, Completeness, Informativeness.

Parameters:

Name Type Description Default
z array-like

Latent codes, shape (N, D).

required
factors array-like

Ground-truth factors, shape (N, K).

required
ci str | None

CI method (bootstrap; note: DCI is slow, so default n_bootstrap=200).

'bca'
n_bootstrap int

Bootstrap resamples.

200
level float

Nominal coverage.

0.95
seed int

RNG seed.

0

Returns:

Type Description
MetricResult

Named "DCI" with extra containing {"disentanglement", "completeness", "informativeness"}.

Examples:

>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> z = rng.normal(0, 1, (100, 4))
>>> f = rng.normal(0, 1, (100, 3))
>>> r = dci(z, f, ci=None)
>>> 0.0 <= r.value <= 1.0
True
Source code in src/reliably/repr/dci.py
def dci(
    z: Any,
    factors: Any,
    *,
    ci: str | None = "bca",
    n_bootstrap: int = 200,
    level: float = 0.95,
    seed: int = 0,
) -> MetricResult:
    """DCI: Disentanglement, Completeness, Informativeness.

    The headline ``value`` is the mean of the disentanglement and
    completeness scores. Informativeness is reported in ``extra`` only and
    does not enter ``value``. When ``ci`` is not ``None``, the bootstrap
    interval is computed for that combined value only, not for the
    individual components.

    Parameters
    ----------
    z : array-like
        Latent codes, shape ``(N, D)``. A 1-D input is treated as ``(N, 1)``.
    factors : array-like
        Ground-truth factors, shape ``(N, K)``. A 1-D input is treated as
        ``(N, 1)``.
    ci : str | None
        Bootstrap CI method, forwarded to ``bootstrap_ci`` as ``method``.
        ``None`` skips the interval. DCI is slow, hence the smaller default
        ``n_bootstrap=200``.
    n_bootstrap : int
        Bootstrap resamples.
    level : float
        Nominal coverage.
    seed : int
        RNG seed for the bootstrap resampling.

    Returns
    -------
    MetricResult
        Named ``"DCI"`` with ``extra`` containing
        ``{"disentanglement", "completeness", "informativeness"}``.

    Raises
    ------
    ValueError
        If ``z`` and ``factors`` do not have the same number of rows.

    Examples
    --------
    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> z = rng.normal(0, 1, (100, 4))
    >>> f = rng.normal(0, 1, (100, 3))
    >>> r = dci(z, f, ci=None)
    >>> 0.0 <= r.value <= 1.0
    True
    """
    z_np = to_numpy(z, dtype=np.float64)
    f_np = to_numpy(factors, dtype=np.float64)
    if z_np.ndim == 1:
        z_np = z_np[:, None]
    if f_np.ndim == 1:
        f_np = f_np[:, None]
    n = z_np.shape[0]
    # Rows of z and factors must describe the same samples; the bootstrap
    # below indexes both arrays with the same resampled row indices.
    if f_np.shape[0] != n:
        raise ValueError(
            "z and factors must have the same number of rows; "
            f"got {n} and {f_np.shape[0]}"
        )

    d_val, c_val, info_val = _dci_from_arrays(z_np, f_np)
    # Headline score: mean of disentanglement and completeness.
    point = (d_val + c_val) / 2.0
    extra = {"disentanglement": d_val, "completeness": c_val, "informativeness": info_val}

    if ci is None:
        return MetricResult(name="DCI", value=point, ci=None, n=n, extra=extra)

    def _est(idx: NDArray[np.intp]) -> float:
        # Same combination as the point estimate; informativeness is discarded.
        d_, c_, _ = _dci_from_arrays(z_np[idx], f_np[idx])
        return (d_ + c_) / 2.0

    ci_result = bootstrap_ci(_est, n, point=point, n_boot=n_bootstrap,
                             level=level, method=ci, seed=seed)
    return MetricResult(name="DCI", value=point, ci=ci_result, n=n, extra=extra)

reliably.repr.factorvae.factorvae_metric(z, factors, *, n_votes=800, batch_size=64, ci='bca', n_bootstrap=200, level=0.95, seed=0)

FactorVAE disentanglement metric.

Parameters:

Name Type Description Default
z array-like

Latent codes, shape (N, D).

required
factors array-like

Ground-truth factors, shape (N, K).

required
n_votes int

Number of voting rounds.

800
batch_size int

Batch size per round.

64
ci str | None

CI method.

'bca'
n_bootstrap int

Bootstrap resamples.

200
level float

Nominal coverage.

0.95
seed int

RNG seed.

0

Returns:

Type Description
MetricResult

Named "FactorVAE", value in [0, 1].

Examples:

>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> z = rng.normal(0, 1, (300, 4))
>>> f = rng.normal(0, 1, (300, 3))
>>> r = factorvae_metric(z, f, ci=None)
>>> 0.0 <= r.value <= 1.0
True
Source code in src/reliably/repr/factorvae.py
def factorvae_metric(
    z: Any,
    factors: Any,
    *,
    n_votes: int = 800,
    batch_size: int = 64,
    ci: str | None = "bca",
    n_bootstrap: int = 200,
    level: float = 0.95,
    seed: int = 0,
) -> MetricResult:
    """FactorVAE disentanglement metric.

    The metric is stochastic: it is driven by an RNG created from ``seed``.
    Bootstrap replicates reuse that same seed. Replicates therefore differ
    only through the resampled rows, and each replicate is a deterministic
    function of its resample indices.

    Parameters
    ----------
    z : array-like
        Latent codes, shape ``(N, D)``. A 1-D input is treated as ``(N, 1)``.
    factors : array-like
        Ground-truth factors, shape ``(N, K)``. A 1-D input is treated as
        ``(N, 1)``.
    n_votes : int
        Number of voting rounds.
    batch_size : int
        Batch size per round, capped at the number of available rows.
    ci : str | None
        Bootstrap CI method, forwarded to ``bootstrap_ci`` as ``method``.
        ``None`` skips the interval entirely.
    n_bootstrap : int
        Bootstrap resamples.
    level : float
        Nominal coverage.
    seed : int
        Seed for both the metric's internal RNG and the bootstrap resampling.

    Returns
    -------
    MetricResult
        Named ``"FactorVAE"``, value in ``[0, 1]``.

    Raises
    ------
    ValueError
        If ``z`` and ``factors`` do not have the same number of rows.

    Examples
    --------
    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> z = rng.normal(0, 1, (300, 4))
    >>> f = rng.normal(0, 1, (300, 3))
    >>> r = factorvae_metric(z, f, ci=None)
    >>> 0.0 <= r.value <= 1.0
    True
    """
    z_np = to_numpy(z, dtype=np.float64)
    f_np = to_numpy(factors, dtype=np.float64)
    if z_np.ndim == 1:
        z_np = z_np[:, None]
    if f_np.ndim == 1:
        f_np = f_np[:, None]
    n = z_np.shape[0]
    # Rows of z and factors must describe the same samples; the bootstrap
    # below indexes both arrays with the same resampled row indices.
    if f_np.shape[0] != n:
        raise ValueError(
            "z and factors must have the same number of rows; "
            f"got {n} and {f_np.shape[0]}"
        )

    # Cap the batch the same way the bootstrap replicates do, so the point
    # estimate and its replicates are computed under the same rule.
    point = _factorvae_from_arrays(
        z_np, f_np, n_votes=n_votes, batch_size=min(batch_size, n), rng=make_rng(seed)
    )

    if ci is None:
        return MetricResult(name="FactorVAE", value=point, ci=None, n=n)

    def _est(idx: NDArray[np.intp]) -> float:
        # Common random numbers: each replicate uses a fresh RNG from the
        # user seed, exactly like the point estimate. This keeps the estimator
        # a deterministic function of ``idx`` and honours ``seed``. Deriving
        # the seed from a few resample indices would ignore most of the
        # resample and collide between unrelated resamples.
        return _factorvae_from_arrays(
            z_np[idx], f_np[idx], n_votes=n_votes,
            batch_size=min(batch_size, len(idx)), rng=make_rng(seed),
        )

    ci_result = bootstrap_ci(_est, n, point=point, n_boot=n_bootstrap,
                             level=level, method=ci, seed=seed)
    return MetricResult(name="FactorVAE", value=point, ci=ci_result, n=n)

reliably.repr.irs.irs(z, factors, *, n_interventions=100, ci='bca', n_bootstrap=200, level=0.95, seed=0)

Interventional Robustness Score (IRS).

Measures the maximum change in the matched latent under interventions on nuisance factors while the target factor is held fixed (Suter et al., 2019). A higher score means a more robust, better disentangled representation.

Parameters:

Name Type Description Default
z array-like

Latent codes, shape (N, D).

required
factors array-like

Ground-truth factors, shape (N, K).

required
n_interventions int

Number of random interventions to sample.

100
ci str | None

CI method.

'bca'
n_bootstrap int

Bootstrap resamples.

200
level float

Nominal coverage.

0.95
seed int

RNG seed.

0

Returns:

Type Description
MetricResult

Named "IRS", value in [0, 1].

Examples:

>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> z = rng.normal(0, 1, (300, 4))
>>> f = rng.normal(0, 1, (300, 3))
>>> r = irs(z, f, ci=None)
>>> 0.0 <= r.value <= 1.0
True
Source code in src/reliably/repr/irs.py
def irs(
    z: Any,
    factors: Any,
    *,
    n_interventions: int = 100,
    ci: str | None = "bca",
    n_bootstrap: int = 200,
    level: float = 0.95,
    seed: int = 0,
) -> MetricResult:
    """Interventional Robustness Score (IRS).

    Measures the maximum change in the matched latent under interventions on
    nuisance factors while the target factor is held fixed (Suter et al., 2019).
    A higher score means a more robust, better disentangled representation.

    The interventions are sampled with an RNG created from ``seed``.
    Bootstrap replicates reuse that same seed, so they differ only through
    the resampled rows.

    Parameters
    ----------
    z : array-like
        Latent codes, shape ``(N, D)``. A 1-D input is treated as ``(N, 1)``.
    factors : array-like
        Ground-truth factors, shape ``(N, K)``. A 1-D input is treated as
        ``(N, 1)``.
    n_interventions : int
        Number of random interventions to sample.
    ci : str | None
        Bootstrap CI method, forwarded to ``bootstrap_ci`` as ``method``.
        ``None`` skips the interval entirely.
    n_bootstrap : int
        Bootstrap resamples.
    level : float
        Nominal coverage.
    seed : int
        Seed for both the intervention sampling and the bootstrap resampling.

    Returns
    -------
    MetricResult
        Named ``"IRS"``, value in ``[0, 1]``.

    Raises
    ------
    ValueError
        If ``z`` and ``factors`` do not have the same number of rows.

    Examples
    --------
    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> z = rng.normal(0, 1, (300, 4))
    >>> f = rng.normal(0, 1, (300, 3))
    >>> r = irs(z, f, ci=None)
    >>> 0.0 <= r.value <= 1.0
    True
    """
    z_np = to_numpy(z, dtype=np.float64)
    f_np = to_numpy(factors, dtype=np.float64)
    if z_np.ndim == 1:
        z_np = z_np[:, None]
    if f_np.ndim == 1:
        f_np = f_np[:, None]
    n = z_np.shape[0]
    # Rows of z and factors must describe the same samples; the bootstrap
    # below indexes both arrays with the same resampled row indices.
    if f_np.shape[0] != n:
        raise ValueError(
            "z and factors must have the same number of rows; "
            f"got {n} and {f_np.shape[0]}"
        )

    point = _irs_from_arrays(z_np, f_np, n_interventions=n_interventions,
                             rng=make_rng(seed))

    if ci is None:
        return MetricResult(name="IRS", value=point, ci=None, n=n)

    def _est(idx: NDArray[np.intp]) -> float:
        # Common random numbers: each replicate uses a fresh RNG from the
        # user seed, exactly like the point estimate. This keeps the estimator
        # a deterministic function of ``idx`` and honours ``seed``. Deriving
        # the seed from a few resample indices would ignore most of the
        # resample and collide between unrelated resamples.
        return _irs_from_arrays(z_np[idx], f_np[idx],
                                n_interventions=n_interventions, rng=make_rng(seed))

    ci_result = bootstrap_ci(_est, n, point=point, n_boot=n_bootstrap,
                             level=level, method=ci, seed=seed)
    return MetricResult(name="IRS", value=point, ci=ci_result, n=n)