Based on my knowledge of the xarray/xray codebase and the issue description, I need to create a patch that adds an option to handle variables that are missing from some of the datasets being concatenated. The issue asks for behavior similar to pandas' `pd.concat`, which performs an outer join. Let me create the patch based on my understanding of the xarray codebase structure. The key changes needed are: 1. Add `data_vars` and `coords` parameters to the `concat` function, replacing the old `mode` parameter (which is kept as a deprecated alias). 2. Both parameters should accept `'all'`, `'minimal'`, or `'different'` to control which variables are concatenated. 3. When a concatenated variable is missing from some datasets, fill the missing entries with NaN. Here's the patch for the xarray `concat` function: --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,11 +1,13 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import pandas as pd from . import utils -from .variable import Variable, Coordinate -from .pycompat import OrderedDict +from .variable import Variable, Coordinate, broadcast_variables +from .pycompat import OrderedDict, iteritems -def concat(objs, dim='concat_dim', indexers=None, mode='different', +def concat(objs, dim='concat_dim', data_vars='all', coords='different', indexers=None, mode=None, concat_over=None, compat='equals'): """Concatenate xarray objects along a new or existing dimension. @@ -23,12 +25,27 @@ def concat(objs, dim='concat_dim', indexers=None, mode='different', dimension already exists, the data is indexed along that dimension before concatenation. instead. - mode : {'minimal', 'different', 'all'}, optional - Decides which variables are concatenated.
Choices are 'minimal' - in which only variables in which dimension already appears are - included, 'different' in which all variables which are not equal (not - those equal for all datasets) across all datasets are concatenated - (as well as all for which dimension already appears), and 'all' for - which all variables are concatenated. By default, mode is 'different'. + data_vars : {'minimal', 'different', 'all'} or list of str, optional + These data variables will be concatenated together: + * 'minimal': Only data variables in which the dimension already + appears are included. + * 'different': Data variables which are not equal (ignoring + attributes) across all datasets are also concatenated (as well as + all for which dimension already appears). Beware: this option may + load the data payload of data variables into memory if they are not + already loaded. + * 'all': All data variables will be concatenated. + * list of str: The listed data variables will be concatenated, in + addition to the 'minimal' data variables. + If objects are DataArrays, data_vars must be 'all'. + coords : {'minimal', 'different', 'all'} or list of str, optional + These coordinate variables will be concatenated together: + * 'minimal': Only coordinates in which the dimension already appears + are included. + * 'different': Coordinates which are not equal (ignoring attributes) + across all datasets are also concatenated (as well as all for which + dimension already appears). Beware: this option may load the data + payload of coordinate variables into memory if they are not already + loaded. + * 'all': All coordinate variables will be concatenated. + * list of str: The listed coordinate variables will be concatenated, + in addition to the 'minimal' coordinates. concat_over : None or str or iterable of str, optional - Names of additional variables to concatenate, in which "weights" would - appear in the result as the concatenation of the input variables - "weights".
By default, only variables in which `dim` appears are - included in the result. + Deprecated; use data_vars instead. compat : {'equals', 'identical'}, optional String indicating how to compare non-concatenated variables and dataset global attributes for potential conflicts. 'equals' means @@ -62,9 +79,6 @@ def concat(objs, dim='concat_dim', indexers=None, mode='different', # we've already verified that the brunt of the parameters are OK so # now it's OK to convert objects to datasets datasets = [as_dataset(ds) for ds in objs] - dim, coord = _calc_concat_dim_coord(dim) - - concat_over = set() if isinstance(dim, basestring): dim, coord = _calc_concat_dim_coord(dim) @@ -72,7 +86,19 @@ def concat(objs, dim='concat_dim', indexers=None, mode='different', dim = getattr(dim, 'name', dim) coord = dim - if mode not in ['minimal', 'different', 'all']: + # deprecation handling + if mode is not None: + import warnings + warnings.warn('the `mode` argument to `concat` has been deprecated; ' + 'please use `data_vars` and `coords` instead', + DeprecationWarning, stacklevel=2) + data_vars = mode + coords = mode + + concat_over = set() + + # determine variables to concatenate + if data_vars not in ['minimal', 'different', 'all']: raise ValueError("unexpected value for mode: %s" % mode) if concat_over is None: @@ -85,45 +111,66 @@ def concat(objs, dim='concat_dim', indexers=None, mode='different', # automatically concatenate over variables with the new dimension for ds in datasets: - concat_over.update(k for k, v in ds.variables.items() + concat_over.update(k for k, v in iteritems(ds.variables) if dim in v.dims) - # determine which variables to test for equality - equals = OrderedDict() - if mode == 'minimal': - pass - elif mode == 'different': - # variables that differ across datasets should be concatenated - for ds in datasets: - for k, v in ds.variables.items(): - if k not in concat_over: - if k in equals: - if not (equals[k] is True or v.equals(equals[k])): - concat_over.add(k) - 
equals[k] = False - else: - equals[k] = v - elif mode == 'all': - for ds in datasets: - concat_over.update(ds.data_vars) - else: - raise ValueError("unexpected value for mode: %s" % mode) - - return _concat(datasets, dim, coord, concat_over, compat) + # gather all variable names from all datasets + all_vars = set() + for ds in datasets: + all_vars.update(ds.variables) + # determine which data variables to concatenate + if isinstance(data_vars, basestring): + if data_vars == 'minimal': + pass + elif data_vars == 'different': + for ds in datasets: + for k, v in iteritems(ds.data_vars): + if k not in concat_over: + # check if variable exists and is equal in all datasets + all_equal = True + for other_ds in datasets: + if k in other_ds.variables: + if not v.equals(other_ds.variables[k]): + all_equal = False + break + else: + all_equal = False + break + if not all_equal: + concat_over.add(k) + elif data_vars == 'all': + for ds in datasets: + concat_over.update(ds.data_vars) + else: + raise ValueError("unexpected value for data_vars: %s" % data_vars) + else: + concat_over.update(data_vars) -def _concat(datasets, dim, coord, concat_over, compat): - """ - Concatenate a sequence of datasets along a new or existing dimension. - """ - from .dataset import Dataset + return _dataset_concat(datasets, dim, data_vars, coords, compat) - # Make sure we're working with datasets - datasets = [as_dataset(ds) for ds in datasets] +def _dataset_concat(datasets, dim, data_vars, coords, compat): + """ + Concatenate a sequence of datasets. 
+ """ + from .dataset import Dataset + + # determine coordinate and dimension for concatenation + if isinstance(dim, basestring): + dim, coord = _calc_concat_dim_coord(dim) + else: + coord = dim + dim = getattr(dim, 'name', dim) + # Determine which variables to include in result - # Variables that are in all datasets with same values should be included - result_vars = OrderedDict() + # Use union of all variables across all datasets (outer join) + all_data_vars = set() + all_coord_vars = set() + for ds in datasets: + all_data_vars.update(ds.data_vars) + all_coord_vars.update(ds.coords) + + # Determine which variables to concatenate vs. merge + concat_over = set() # Variables in concat_dim should be concatenated for ds in datasets: - for name, var in ds.variables.items(): - if dim in var.dims: - if name not in concat_over: - concat_over.add(name) + concat_over.update(k for k, v in iteritems(ds.variables) + if dim in v.dims) + + # Add variables based on data_vars setting + if isinstance(data_vars, basestring): + if data_vars == 'all': + concat_over.update(all_data_vars) + elif data_vars == 'different': + for k in all_data_vars: + if k not in concat_over: + # Check if variable differs across datasets + ref_var = None + for ds in datasets: + if k in ds.variables: + if ref_var is None: + ref_var = ds.variables[k] + elif not ref_var.equals(ds.variables[k]): + concat_over.add(k) + break + else: + # Variable missing from some datasets means we should concat + if any(k not in ds.variables for ds in datasets): + concat_over.add(k) + elif isinstance(data_vars, list): + concat_over.update(data_vars) + + # Concatenate variables + result_vars = OrderedDict() - # Concatenate variables that need concatenation + # Handle concatenated variables - fill missing with NaN for k in concat_over: - vars_to_concat = [ds.variables[k] for ds in datasets] + vars_to_concat = [] + for ds in datasets: + if k in ds.variables: + vars_to_concat.append(ds.variables[k]) + else: + # Create a NaN-filled 
variable with appropriate shape + # We need to get the shape from datasets that have this variable + ref_var = None + for ref_ds in datasets: + if k in ref_ds.variables: + ref_var = ref_ds.variables[k] + break + if ref_var is not None: + # Create fill variable with NaN values + import numpy as np + shape = list(ref_var.shape) + dims = ref_var.dims + # Adjust shape for the concat dimension + if dim in dims: + dim_idx = dims.index(dim) + shape[dim_idx] = ds.dims.get(dim, 1) + fill_data = np.full(shape, np.nan, dtype=ref_var.dtype) + fill_var = Variable(dims, fill_data, ref_var.attrs) + vars_to_concat.append(fill_var) result_vars[k] = Variable.concat(vars_to_concat, dim, indexers=None) - # Copy over variables that don't need concatenation + # Handle non-concatenated variables for ds in datasets: - for k, v in ds.variables.items(): + for k, v in iteritems(ds.variables): if k not in result_vars: if k not in concat_over: result_vars[k] = v - # Add coordinate if provided + # Add the concat coordinate if coord is not None: result_vars[dim] = coord - result_coord_names = set() - for ds in datasets: - result_coord_names.update(ds.coords) - if coord is not None: - result_coord_names.add(dim) - - return Dataset(result_vars, coords=result_coord_names) + return Dataset(result_vars, coords=all_coord_vars | {dim} if coord else all_coord_vars) This is getting complex. The issue mentions "xray" (xarray's old name), but the following, more focused patch targets the newer `concat` implementation, which already has `data_vars`, `coords`, `positions`, and `fill_value` parameters: --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -3,7 +3,7 @@ import pandas as pd from .
import utils from .variable import Variable, Coordinate from .pycompat import OrderedDict - +from .alignment import align def concat(objs, dim='concat_dim', data_vars='all', coords='different', compat='equals', positions=None, fill_value=dtypes.NA): @@ -31,6 +31,13 @@ def concat(objs, dim='concat_dim', data_vars='all', coords='different', * 'all': All data variables will be concatenated. * list of str: The listed data variables will be concatenated, in addition to the 'minimal' data variables. + + The ``data_vars`` argument controls how variables that are not present + in all datasets are handled. When ``data_vars='all'``, variables that + are missing from some datasets will be filled with ``fill_value`` + (default NaN). This is similar to the outer join behavior of + ``pd.concat``. + If objects are DataArrays, data_vars must be 'all'. coords : {'minimal', 'different', 'all'} or list of str, optional These coordinate variables will be concatenated together: @@ -51,6 +58,9 @@ def concat(objs, dim='concat_dim', data_vars='all', coords='different', positions : None or list of integer arrays, optional List of integer arrays which specifies the integer positions to which to assign each dataset along the concatenated dimension. + fill_value : scalar, optional + Value to use for newly missing values. Default is NaN. + Only relevant when ``data_vars='all'``. 
Returns ------- @@ -140,7 +150,7 @@ def _calc_concat_over(datasets, dim, data_vars, coords): return concat_over, equals -def _dataset_concat(datasets, dim, data_vars, coords, compat, positions): +def _dataset_concat(datasets, dim, data_vars, coords, compat, positions, fill_value=dtypes.NA): """ Concatenate a sequence of datasets along a new or existing dimension """ @@ -148,6 +158,7 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions): from .dataset import Dataset dim, coord = _calc_concat_dim_coord(dim) + datasets = list(datasets) # Determine which variables to concatenate concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords) @@ -157,6 +168,19 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions): for ds in datasets: all_vars.update(ds.variables) + # Ensure all datasets have the same variables by filling with fill_value + def ensure_variable(ds, name, ref_var, fill_value): + """Create a variable filled with fill_value matching ref_var structure.""" + import numpy as np + dims = tuple(d for d in ref_var.dims if d != dim) + shape = tuple(ds.dims.get(d, ref_var.sizes[d]) for d in dims) + dtype = ref_var.dtype + if np.issubdtype(dtype, np.integer): + dtype = float # Need float to hold NaN + data = np.full(shape, fill_value, dtype=dtype) + return Variable(dims, data, ref_var.attrs) + + # Find reference variables for each variable name result_vars = OrderedDict() result_coord_names = set() @@ -165,7 +189,24 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions): # Concatenate variables in concat_over for k in concat_over: - vars = [ds.variables[k] for ds in datasets] + # Find a reference variable (from first dataset that has it) + ref_var = None + for ds in datasets: + if k in ds.variables: + ref_var = ds.variables[k] + break + + vars = [] + for ds in datasets: + if k in ds.variables: + vars.append(ds.variables[k]) + elif ref_var is not None: + # Create a fill variable + fill_var = 
ensure_variable(ds, k, ref_var, fill_value) + vars.append(fill_var) + else: + raise ValueError("Variable %r not found" % k) combined = concat_vars(vars, dim, positions) result_vars[k] = combined if k in datasets[0].coords: