arviz_base.extract

arviz_base.extract(data, group='posterior', sample_dims=None, *, combined=True, var_names=None, filter_vars=None, num_samples=None, weights=None, resampling_method=None, keep_dataset=False, random_seed=None)

Extract a group or group subset from a DataTree.

Parameters:
data : DataTree-like

DataTree from which to extract the data.

group : str, optional

Which group to extract data from.

sample_dims : sequence of hashable, optional

List of dimensions that should be considered sampling dimensions. Random subsets and potential stacking if combined=True happen over these dimensions only. Defaults to rcParams["data.sample_dims"].

combined : bool, optional

Combine sample_dims dimensions into sample. Won’t work if a dimension named sample already exists. It is irrelevant and ignored when sample_dims is a single dimension.

var_names : str or list of str, optional

Variables to be extracted. Prefix the variables with ~ when you want to exclude them.

filter_vars : {None, “like”, “regex”}, optional

If None (default), interpret var_names as the real variable names. If “like”, interpret var_names as substrings of the real variable names. If “regex”, interpret var_names as regular expressions on the real variable names. À la pandas.filter. As with plotting, it is sometimes easier to subset by saying what to exclude instead of what to include.

num_samples : int, optional

Extract only a subset of the samples. Only valid if combined=True or sample_dims represents a single dimension.

weights : array_like, optional

Extract a weighted subset of the samples. Only valid if num_samples is not None.

resampling_method : str, optional

Method to use for resampling. Options are “multinomial” and “stratified”. For stratified resampling, weights must be provided. Defaults to “stratified” if weights are provided, “multinomial” otherwise.

keep_dataset : bool, optional

If True, always return a Dataset. If False (default), return a DataArray when there is a single variable.

random_seed : int, numpy.Generator, optional

Random number generator or seed. Only used if weights is not None or if num_samples is not None.
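The interplay of var_names and filter_vars can be illustrated with a small pure-Python sketch. The helper select_var_names below is hypothetical, not part of arviz_base: it only approximates the documented rules (exact names by default, substrings with “like”, regular expressions with “regex”, and a leading ~ for exclusion) and does not attempt to reproduce the library's ordering or error handling for names that match nothing.

```python
import re


def select_var_names(all_vars, var_names=None, filter_vars=None):
    """Hypothetical approximation of the var_names/filter_vars rules.

    all_vars: names available in the group, in the group's order.
    var_names: a name or list of names; a leading "~" marks an exclusion.
    filter_vars: None (exact match), "like" (substring) or "regex".
    """
    if var_names is None:
        return list(all_vars)
    if isinstance(var_names, str):
        var_names = [var_names]

    def matches(pattern, name):
        if filter_vars is None:
            return pattern == name
        if filter_vars == "like":
            return pattern in name
        if filter_vars == "regex":
            return re.search(pattern, name) is not None
        raise ValueError(f"unknown filter_vars: {filter_vars!r}")

    excluded = [v[1:] for v in var_names if v.startswith("~")]
    included = [v for v in var_names if not v.startswith("~")]
    # Start from the explicit inclusions, or from everything if only exclusions were given.
    # This sketch keeps the group's own variable order.
    if included:
        selected = [n for n in all_vars if any(matches(p, n) for p in included)]
    else:
        selected = list(all_vars)
    return [n for n in selected if not any(matches(p, n) for p in excluded)]
```

With the centered_eight variables mu, theta and tau, select_var_names(["mu", "theta", "tau"], "~theta") returns ["mu", "tau"], and select_var_names(["mu", "theta", "tau"], "^t", filter_vars="regex") returns ["theta", "tau"].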

Returns:
xarray.DataArray or xarray.Dataset
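The constraints spread across the parameter descriptions above can be collected in one place. The function resolve_resampling below is a hypothetical sketch that encodes those documented rules as explicit checks and returns the resampling method that would apply. It is not how arviz_base validates its arguments, and raising ValueError in each case is an assumption: the documentation only says the combinations are invalid or “won't work”.

```python
def resolve_resampling(sample_dims, existing_dims, combined=True,
                       num_samples=None, weights=None, resampling_method=None):
    """Hypothetical check of the documented extract() argument rules.

    Returns the resampling method that would apply. Raising ValueError
    for each invalid combination is an assumption made for this sketch.
    """
    single_dim = len(sample_dims) == 1
    # combined is ignored for a single sample dimension, so no name clash is possible then.
    if combined and not single_dim and "sample" in existing_dims:
        raise ValueError("cannot combine: a dimension named 'sample' already exists")
    if num_samples is not None and not (combined or single_dim):
        raise ValueError("num_samples requires combined=True or a single sample dimension")
    if weights is not None and num_samples is None:
        raise ValueError("weights are only valid when num_samples is given")
    if resampling_method is None:
        # Documented default: stratified when weights are given, multinomial otherwise.
        resampling_method = "stratified" if weights is not None else "multinomial"
    if resampling_method not in ("multinomial", "stratified"):
        raise ValueError(f"unknown resampling_method: {resampling_method!r}")
    if resampling_method == "stratified" and weights is None:
        raise ValueError("stratified resampling requires weights")
    return resampling_method
```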

Examples

The default behaviour is to return the posterior group after stacking the chain and draw dimensions.

import arviz_base as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset> Size: 209kB
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * sample   (sample) object 16kB MultiIndex
  * chain    (sample) int64 16kB 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 16kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    mu       (sample) float64 16kB 1.716 1.903 1.903 1.903 ... 5.409 7.721 10.24
    theta    (school, sample) float64 128kB 2.317 0.8892 0.8892 ... 9.754 14.02
    tau      (sample) float64 16kB 0.8775 0.8027 0.8027 ... 2.236 2.99 3.052
Attributes:
    created_at:                 2025-01-19T14:32:33.071271+00:00
    arviz_version:              0.20.0
    inference_library:          pymc
    inference_library_version:  5.20.0
    sampling_time:              3.159093141555786
    tuning_steps:               1000

You can also request a subset to be returned, both in variables and in samples:

az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)> Size: 6kB
array([[ 6.68471590e+00,  2.57249107e+00,  2.99172113e+00,
         5.81653080e+00,  1.55119685e+00, -2.13347506e+00,
         8.12191423e+00,  3.78557008e+00,  1.90457029e+01,
         1.17409423e+01,  4.43616091e+00,  2.92038998e+00,
         4.71273898e+00, -7.02366780e-01,  4.27337284e+00,
         1.32758740e+01,  9.37066645e+00,  1.10972038e+01,
         1.91515897e+00,  5.37903339e+00,  1.09675811e+01,
         6.24439453e+00,  5.07493796e+00,  7.87700852e+00,
         6.70987225e+00, -7.95927409e+00, -2.28779404e+00,
         7.16565453e+00,  6.69068199e+00,  3.33481048e+00,
         4.87010661e+00,  2.35939213e+01,  5.87166557e+00,
         1.54030624e+01,  5.88693731e+00,  1.29617378e+01,
         4.29774953e+00,  4.34980010e+00,  1.13150317e+01,
         7.47838234e+00,  1.78722902e+01,  7.20466974e-01,
         2.29337646e+01,  2.05025159e+00,  1.52621772e+00,
         6.28214233e+00,  1.30671870e+00,  1.05316393e+01,
         4.14922272e+00,  1.47220436e+01,  6.69690701e+00,
        -6.20691096e+00,  6.06725954e+00,  2.71076287e+00,
         3.35049267e+00,  1.27021144e+01,  1.19449175e+01,
         1.00552225e+01,  6.28250875e+00,  9.55189677e+00,
...
         1.59686541e+01, -1.06621507e-01,  3.49501735e+00,
        -1.05716230e+01,  4.45165075e+00,  1.11020636e+00,
         9.69629831e+00,  6.21592516e+00,  3.67578666e+00,
         4.18452816e+00,  3.03240922e+00,  5.27501577e+00,
         2.74181051e+00,  7.05670536e+00,  5.81672064e+00,
         1.18808099e+00,  2.60479843e+00,  4.00740765e+00,
        -1.53972551e-01,  9.28109060e+00,  8.99880439e+00,
         3.07043795e+00,  1.46427630e+01,  1.83572572e+01,
        -6.72358848e+00,  6.88936888e+00,  3.82680610e+00,
         9.42741342e+00,  1.10157749e+00,  4.24093460e+00,
        -5.16899739e+00,  7.48235996e+00,  4.33901016e+00,
         1.02765035e+01,  2.64341240e+00,  6.81192833e+00,
         6.95804318e-01,  1.97179248e+00, -8.68273653e-01,
        -2.17777716e+00,  3.78542983e+00,  5.86804626e+00,
         9.40307510e+00,  2.56739760e+00,  1.62283160e+00,
         1.04302129e+01,  3.62829097e+00,  3.00465019e+00,
         6.68384338e+00,  7.86257298e+00, -1.21359955e+00,
         9.87984716e+00,  1.72647854e+00,  7.70648424e+00,
         1.52228819e+01,  7.42756817e+00,  2.88561246e+00,
         6.11788501e+00]])
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
  * sample   (sample) object 800B MultiIndex
  * chain    (sample) int64 800B 0 0 1 1 1 0 0 3 3 3 0 ... 1 0 3 3 1 0 2 1 1 1 3
  * draw     (sample) int64 800B 28 438 482 406 240 43 ... 390 412 28 364 263
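When weights are given, the subset is drawn by resampling with probability proportional to the weights. The sketch below implements the two textbook schemes named under resampling_method: multinomial resampling draws every index independently, while stratified resampling draws one uniform point inside each of num_samples equal-width strata of [0, 1), which spreads the draws more evenly. The helper resample_indices is hypothetical and uses Python's random module rather than a numpy.Generator; it is not the arviz_base implementation, whose details may differ.

```python
import random
from bisect import bisect_right
from itertools import accumulate


def resample_indices(weights, num_samples, method="multinomial", seed=None):
    """Return num_samples indices drawn with probability proportional to weights.

    Hypothetical helper illustrating the two textbook resampling schemes;
    it is not the arviz_base implementation.
    """
    rng = random.Random(seed)
    cumulative = list(accumulate(weights))  # unnormalised CDF
    total = cumulative[-1]
    if method == "multinomial":
        # Independent uniform draws on [0, 1): indices are i.i.d. categorical.
        points = [rng.random() for _ in range(num_samples)]
    elif method == "stratified":
        # One uniform draw inside each of num_samples equal-width strata of [0, 1).
        points = [(i + rng.random()) / num_samples for i in range(num_samples)]
    else:
        raise ValueError(f"unknown method: {method!r}")
    last = len(cumulative) - 1
    # bisect_right never lands on a zero-weight entry; min() guards against float round-up.
    return [min(bisect_right(cumulative, u * total), last) for u in points]
```

With equal weights and num_samples equal to the number of items, the stratified scheme returns each index exactly once, whereas multinomial resampling typically repeats some indices and misses others; this lower variance is the usual reason to prefer stratified resampling when weights are available.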

To keep the chain and draw dimensions, use combined=False. This example also passes group="prior" to extract the prior group instead of the posterior.

az.extract(idata, group="prior", combined=False)
<xarray.Dataset> Size: 45kB
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 8B 0
  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    theta    (chain, draw, school) float64 32kB -8.435 24.12 ... 54.57 52.29
    tau      (chain, draw) float64 4kB 11.93 17.76 4.732 ... 2.231 3.319 93.69
    mu       (chain, draw) float64 4kB 4.714 3.853 1.709 ... -2.245 -2.435
Attributes:
    created_at:                 2025-01-19T14:32:29.212688+00:00
    arviz_version:              0.20.0
    inference_library:          pymc
    inference_library_version:  5.20.0