159 lines
7.1 KiB
Diff
159 lines
7.1 KiB
Diff
# pydata__xarray-3364
|
|
# Loki Mode Multi-Agent Patch
|
|
# Attempts: 1
|
|
|
|
Based on my knowledge of the xarray codebase and the issue description, I can generate the appropriate patch. The issue comes from the pydata/xarray repository and asks that the `concat` function handle datasets with different variables by performing an outer join (similar to how `pandas.concat` behaves).
|
|
|
|
Here is the patch:
|
|
|
|
```
|
|
--- a/xarray/core/concat.py
|
|
+++ b/xarray/core/concat.py
|
|
@@ -1,6 +1,7 @@
|
|
from __future__ import absolute_import, division, print_function
|
|
|
|
import pandas as pd
|
|
+import numpy as np
|
|
|
|
from . import utils
|
|
from .alignment import align
|
|
@@ -10,7 +11,8 @@
|
|
|
|
|
|
def concat(objs, dim=None, data_vars='all', coords='different',
|
|
- compat='equals', positions=None, indexers=None, mode=None):
|
|
+ compat='equals', positions=None, indexers=None, mode=None,
|
|
+ join='outer', fill_value=np.nan):
|
|
"""Concatenate xarray objects along a new or existing dimension.
|
|
|
|
Parameters
|
|
@@ -59,6 +61,15 @@ def concat(objs, dim=None, data_vars='all', coords='different',
|
|
this argument is passed on to :py:func:`~xarray.align`. See the
|
|
documentation of :py:func:`~xarray.align` for more information. Only
|
|
used if ``coords`` is not ``'minimal'``.
|
|
+ join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
|
|
+ String indicating how to combine differing variables across datasets:
|
|
+ - 'outer': union of variables, missing variables filled with fill_value
|
|
+ - 'inner': intersection of variables only
|
|
+ - 'left': use variables from the first dataset
|
|
+ - 'right': use variables from the last dataset
|
|
+ - 'exact': raise an error if variables don't match exactly (default before this change)
|
|
+ fill_value : scalar, optional
|
|
+ Value to use for missing variables when join='outer'. Default is NaN.
|
|
|
|
Returns
|
|
-------
|
|
@@ -102,7 +113,8 @@ def concat(objs, dim=None, data_vars='all', coords='different',
|
|
if isinstance(first_obj, DataArray):
|
|
return _dataarray_concat(objs, dim=dim, data_vars=data_vars,
|
|
coords=coords, compat=compat,
|
|
- positions=positions)
|
|
+ positions=positions, join=join,
|
|
+ fill_value=fill_value)
|
|
elif isinstance(first_obj, Dataset):
|
|
return _dataset_concat(objs, dim=dim, data_vars=data_vars,
|
|
coords=coords, compat=compat,
|
|
@@ -113,7 +125,8 @@ def concat(objs, dim=None, data_vars='all', coords='different',
|
|
|
|
|
|
def _dataarray_concat(arrays, dim, data_vars, coords, compat,
|
|
- positions):
|
|
+ positions, join='outer', fill_value=np.nan):
|
|
+ """Concatenate DataArrays along a dimension."""
|
|
arrays = list(arrays)
|
|
|
|
if data_vars != 'all':
|
|
@@ -135,7 +148,8 @@ def _dataarray_concat(arrays, dim, data_vars, coords, compat,
|
|
datasets.append(arr._to_temp_dataset())
|
|
|
|
ds = _dataset_concat(datasets, dim, data_vars=data_vars, coords=coords,
|
|
- compat=compat, positions=positions)
|
|
+ compat=compat, positions=positions, join=join,
|
|
+ fill_value=fill_value)
|
|
return _restore_dataarray_from_temp_dataset(ds)
|
|
|
|
|
|
@@ -145,11 +159,54 @@ def _calc_concat_over(datasets, dim, data_vars, coords):
|
|
concat_over = set()
|
|
equals = {}
|
|
|
|
+ # Get union of all variable names across datasets
|
|
+ all_data_vars = set()
|
|
+ all_coords = set()
|
|
+ for ds in datasets:
|
|
+ all_data_vars.update(ds.data_vars)
|
|
+ all_coords.update(ds.coords)
|
|
+
|
|
if dim in datasets[0]:
|
|
concat_over.add(dim)
|
|
for ds in datasets:
|
|
@@ -202,7 +259,8 @@ def _calc_concat_over(datasets, dim, data_vars, coords):
|
|
concat_over.update(concat_coords)
|
|
|
|
- return concat_over, equals
|
|
+ return concat_over, equals, all_data_vars, all_coords
|
|
|
|
|
|
-def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
|
|
+def _dataset_concat(datasets, dim, data_vars, coords, compat, positions,
|
|
+ join='outer', fill_value=np.nan):
|
|
"""
|
|
Concatenate a sequence of datasets along a new or existing dimension
|
|
"""
|
|
@@ -222,7 +280,48 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
|
|
datasets = list(datasets)
|
|
dim, coord = _calc_concat_dim_coord(dim)
|
|
|
|
- concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)
|
|
+ concat_over, equals, all_data_vars, all_coords = _calc_concat_over(
|
|
+ datasets, dim, data_vars, coords)
|
|
+
|
|
+ # Handle join logic for mismatched variables
|
|
+ if join == 'outer':
|
|
+ # Expand each dataset to include all variables from all datasets
|
|
+ expanded_datasets = []
|
|
+ for ds in datasets:
|
|
+ missing_data_vars = all_data_vars - set(ds.data_vars)
|
|
+ missing_coords = all_coords - set(ds.coords) - {dim}
|
|
+
|
|
+ if missing_data_vars or missing_coords:
|
|
+ # Create a copy and add missing variables filled with fill_value
|
|
+ ds = ds.copy()
|
|
+ for var in missing_data_vars:
|
|
+ # Find the variable in another dataset to get its dims
|
|
+ for other_ds in datasets:
|
|
+ if var in other_ds.data_vars:
|
|
+ template = other_ds[var]
|
|
+ # Create shape based on this dataset's dimensions
|
|
+ shape = tuple(ds.dims.get(d, template.sizes[d])
|
|
+ for d in template.dims)
|
|
+ data = np.full(shape, fill_value, dtype=template.dtype)
|
|
+ ds[var] = (template.dims, data, template.attrs.copy())
|
|
+ break
|
|
+ expanded_datasets.append(ds)
|
|
+ datasets = expanded_datasets
|
|
+ elif join == 'inner':
|
|
+ # Keep only variables present in all datasets
|
|
+ common_data_vars = set.intersection(*[set(ds.data_vars) for ds in datasets])
|
|
+ datasets = [ds[list(common_data_vars)] for ds in datasets]
|
|
+ elif join == 'left':
|
|
+ # Keep only variables from first dataset
|
|
+ first_vars = set(datasets[0].data_vars)
|
|
+ datasets = [ds[[v for v in ds.data_vars if v in first_vars]] for ds in datasets]
|
|
+ elif join == 'right':
|
|
+ # Keep only variables from last dataset
|
|
+ last_vars = set(datasets[-1].data_vars)
|
|
+ datasets = [ds[[v for v in ds.data_vars if v in last_vars]] for ds in datasets]
|
|
+ elif join == 'exact':
|
|
+ # Raise error if variables don't match (legacy behavior)
|
|
+ first_vars = set(datasets[0].data_vars)
|
|
+ for i, ds in enumerate(datasets[1:], 1):
|
|
+ if set(ds.data_vars) != first_vars:
|
|
+ raise ValueError(
|
|
+ f"Dataset {i} has different variables than first dataset. "
|
|
+ f"Use join='outer' to handle datasets with different variables."
|
|
+ )
|
|
|
|
# determine which variables to merge, and then merge them according to compat
|
|
variables_to_merge = (coord_names | data_names) - concat_over - dim_names
|
|
```
|