diff --git a/src/muse/__main__.py b/src/muse/__main__.py index 80f6abee1..8c033059d 100644 --- a/src/muse/__main__.py +++ b/src/muse/__main__.py @@ -98,6 +98,15 @@ def patched_broadcast_compat_data(self, other): "`broadcast_regions` (see `muse.utilities`)." ) + if (isinstance(other, Variable)) and ("year" in self.dims) != ( + "year" in getattr(other, "dims", []) + ): + raise ValueError( + "Broadcasting along the 'year' dimension is required, but automatic " + "broadcasting is disabled. Please handle it explicitly using " + "`broadcast_years` (see `muse.utilities`)." + ) + # The rest of the function is copied directly from # xarray.core.variable._broadcast_compat_data if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]): diff --git a/src/muse/agents/agent.py b/src/muse/agents/agent.py index 09309c079..9b3ac775f 100644 --- a/src/muse/agents/agent.py +++ b/src/muse/agents/agent.py @@ -8,6 +8,7 @@ import xarray as xr from muse.timeslices import drop_timeslice +from muse.utilities import broadcast_years class AbstractAgent(ABC): @@ -494,7 +495,9 @@ def retirement_profile( investments = investments.reindex_like(profile, method="ffill") # Apply the retirement profile to the investments - new_assets = (investments * profile).rename(replacement="asset") + new_assets = (broadcast_years(investments, profile) * profile).rename( + replacement="asset" + ) new_assets["installed"] = "asset", [investment_year] * len(new_assets.asset) # The new assets have picked up quite a few coordinates along the way. diff --git a/src/muse/agents/factories.py b/src/muse/agents/factories.py index a906322e5..c87949ea8 100644 --- a/src/muse/agents/factories.py +++ b/src/muse/agents/factories.py @@ -10,6 +10,7 @@ from muse.agents.agent import Agent, InvestingAgent from muse.errors import AgentShareNotDefined, TechnologyNotDefined +from muse.utilities import broadcast_years def create_standard_agent( @@ -252,8 +253,8 @@ def _shared_capacity( techs = (existing > 0) & (shares > 0) techs = techs.any([u for u in techs.dims if u != "asset"]) if not any(techs): - return (capacity * shares).copy() - return (capacity * shares).sel(asset=techs.values).copy() + return (capacity * broadcast_years(shares, capacity)).copy() + return (capacity * broadcast_years(shares, capacity)).sel(asset=techs.values).copy() def _standardize_inputs( diff --git a/src/muse/costs.py b/src/muse/costs.py index d26ca8c82..75e500d3f 100644 --- a/src/muse/costs.py +++ b/src/muse/costs.py @@ -55,6 +55,7 @@ from muse.commodities import is_enduse, is_fuel, is_material, is_pollutant from muse.quantities import production_amplitude from muse.timeslices import broadcast_timeslice, distribute_timeslice, get_level +from muse.utilities import broadcast_years def cost(func): @@ -625,10 +626,11 @@ def annual_to_lifetime(costs: xr.DataArray, technologies: xr.Dataset): rates = discount_factor( years=years, interest_rate=technologies.interest_rate, - mask=years <= life, + mask=years <= broadcast_years(life, years), ) if "timeslice" in costs.dims: rates = broadcast_timeslice(rates, level=get_level(costs)) + costs = broadcast_years(costs, years) return (costs * rates).sum("year") @@ -648,7 +650,7 @@ def discount_factor( assert "year" not in interest_rate.dims # Calculate discount factor over the years - df = 1 / (1 + interest_rate) ** years + df = 1 / (1 + broadcast_years(interest_rate, years)) ** years # Apply mask if mask is not None: diff --git a/src/muse/dispatch.py b/src/muse/dispatch.py index bb4cfc289..e2bcbe821 100644 --- a/src/muse/dispatch.py +++ b/src/muse/dispatch.py @@ -49,6 +49,7 @@ def production( import xarray as xr from muse.registration import registrator +from muse.utilities import check_dimensions class PRODUCTION_SIGNATURE(Protocol): @@ -141,8 +142,11 @@ def share_based_production( from muse.quantities import emission, maximum_production, minimum_production from muse.utilities import broadcast_over_assets - assert "asset" not in demand.dims - assert "asset" in capacity.dims + check_dimensions(demand, ["timeslice", "commodity"], optional=["region"]) + check_dimensions(capacity, ["asset"], optional=["dst_region"]) + check_dimensions( + technologies, ["asset", "commodity"], optional=["timeslice", "dst_region"] + ) # Maximum and minimum production for each asset maxprod = maximum_production( @@ -234,9 +238,15 @@ def merit_order_production( ) from muse.utilities import broadcast_over_assets - assert "asset" not in demand.dims - assert "asset" in capacity.dims - assert "asset" not in prices.dims + if "dst_region" in technologies.dims: + raise ValueError( + "Merit-order dispatch is not currently compatible with trade models." + ) + + check_dimensions(demand, ["timeslice", "commodity"], optional=["region"]) + check_dimensions(capacity, ["asset"]) + check_dimensions(technologies, ["asset", "commodity"], optional=["timeslice"]) + check_dimensions(prices, ["timeslice", "commodity"], optional=["region"]) # Normalise demand/prices dataarrays to ensure they have a region dimension # Multi-region models will already have a region dimension @@ -283,51 +293,36 @@ def merit_order_production( # Initialise result with zeros result = xr.zeros_like(maxprod) - for y in maxprod.year.values: - prices_y = prices.sel(year=y) - maxprod_y = maxprod.sel(year=y) - minprod_y = minprod.sel(year=y) - maxcons_y = maxcons.sel(year=y) - - for region in demand.region.values: - maxprod_region = maxprod_y.sel( - asset=maxprod_y.asset[maxprod_y.region == region] - ) - techs_region = technologies.sel( - asset=technologies.asset[technologies.region == region] - ) - minprod_region = minprod_y.sel( - asset=minprod_y.asset[minprod_y.region == region] - ) - maxcons_region = maxcons_y.sel( - asset=maxcons_y.asset[maxcons_y.region == region] - ) + for region in demand.region.values: + # Select data for this region + maxprod_region = maxprod.sel(asset=maxprod.asset[maxprod.region == region]) + techs_region = technologies.sel( + asset=technologies.asset[technologies.region == region] + ) + minprod_region = minprod.sel(asset=minprod.asset[minprod.region == region]) + maxcons_region = maxcons.sel(asset=maxcons.asset[maxcons.region == region]) + + # Calculate timeslice-level costs for each asset assuming full + # dispatch. We use LCOE excluding capital costs. + technology_costs = marginal_cost( + techs_region, + prices.sel(region=region), + production=maxprod_region, + consumption=maxcons_region, + ) - # Calculate timeslice-level costs for each asset in this year assuming full - # dispatch. We use LCOE excluding capital costs. - technology_costs = marginal_cost( - techs_region, - prices_y.sel(region=region), - production=maxprod_region, - consumption=maxcons_region, + # Calculate production by dispatching assets in order of + # increasing cost until demand is met + for ts in maxprod.timeslice.values: + dispatch = dispatch_by_merit_order( + demand=demand.sel(timeslice=ts, region=region), + minprod=minprod_region.sel(timeslice=ts), + maxprod=maxprod_region.sel(timeslice=ts), + technology_costs=technology_costs.sel(timeslice=ts), ) - - # Calculate production for this year by dispatching assets in order of - # increasing cost until demand is met - for ts in maxprod_y.timeslice.values: - dispatch = dispatch_by_merit_order( - demand=demand.sel(year=y, timeslice=ts, region=region), - minprod=minprod_region.sel(timeslice=ts), - maxprod=maxprod_region.sel(timeslice=ts), - technology_costs=technology_costs.sel(timeslice=ts), - ) - result.loc[ - dict( - year=y, - timeslice=ts, - asset=result.asset[result.region == region], - ) - ] = dispatch + result.loc[ + dict(timeslice=ts, asset=result.asset[result.region == region]) + ] = dispatch # Add production of environmental pollutants env = is_pollutant(technologies.comm_usage) diff --git a/src/muse/examples.py b/src/muse/examples.py index 189a82c85..89b0a787a 100644 --- a/src/muse/examples.py +++ b/src/muse/examples.py @@ -39,6 +39,7 @@ from muse.mca import MCA from muse.sectors import AbstractSector from muse.timeslices import drop_timeslice +from muse.utilities import broadcast_years __all__ = ["model", "technodata"] @@ -271,6 +272,7 @@ def matching_market(sector: str, model: str = "default") -> xr.Dataset: market = xr.Dataset() techs = broadcast_over_assets(loaded_sector.technologies, assets.capacity) + techs = broadcast_years(techs, assets.capacity) production = maximum_production(techs, assets.capacity) market["supply"] = production.sum("asset") if "dst_region" in market.dims: diff --git a/src/muse/investments.py b/src/muse/investments.py index 8a5149360..c5e5a361b 100644 --- a/src/muse/investments.py +++ b/src/muse/investments.py @@ -65,6 +65,7 @@ def investment( from muse.outputs.cache import cache_quantity from muse.registration import registrator from muse.timeslices import timeslice_max +from muse.utilities import broadcast_years INVESTMENT_SIGNATURE = Callable[ [xr.DataArray, xr.DataArray, xr.Dataset, list[Constraint], KwArg(Any)], @@ -203,7 +204,7 @@ def cliff_retirement_profile( dims="year", coords={"year": range(investment_year, max_year + 1)}, ) - profile = allyears < (investment_year + technical_life) # type: ignore + profile = allyears < broadcast_years(investment_year + technical_life, allyears) # type: ignore # Minimize the number of years needed to represent the profile fully # This is done by removing the central year of any three repeating years, ensuring diff --git a/src/muse/readers/toml.py b/src/muse/readers/toml.py index 5c49320d5..1058874f2 100644 --- a/src/muse/readers/toml.py +++ b/src/muse/readers/toml.py @@ -16,6 +16,7 @@ import xarray as xr from muse.defaults import DATA_DIRECTORY +from muse.utilities import broadcast_years DEFAULT_SETTINGS_PATH = DATA_DIRECTORY / "default_settings.toml" """Default settings path.""" @@ -721,7 +722,9 @@ def read_correlation_consumption(sector_conf: Any) -> xr.Dataset: # Split by timeslice if sector_conf.timeslice_shares_path is not None: shares = read_timeslice_shares(sector_conf.timeslice_shares_path) - consumption = broadcast_timeslice(consumption) * shares + consumption = broadcast_timeslice(consumption) * broadcast_years( + shares, consumption.year + ) else: consumption = distribute_timeslice(consumption) diff --git a/src/muse/regressions.py b/src/muse/regressions.py index 39560abe3..ee531d569 100644 --- a/src/muse/regressions.py +++ b/src/muse/regressions.py @@ -9,6 +9,8 @@ from xarray import DataArray, Dataset +from muse.utilities import broadcast_years + __all__ = [ "Exponential", "ExponentialAdj", @@ -356,8 +358,9 @@ def Exponential( ) -> DataArray: from numpy import exp - factor = 1e6 * self.coeffs.a * population - return factor * exp(self.coeffs.b * population / gdp) + coeffs = broadcast_years(self.coeffs, gdp.year) + factor = 1e6 * coeffs.a * population + return factor * exp(coeffs.b * population / gdp) @register_regression @@ -377,10 +380,11 @@ def ExponentialAdj( if year is None: year = self.base_year - factor = 1e6 * self.coeffs.a * population - unadjusted = factor * exp(self.coeffs.b * population / gdp) + coeffs = broadcast_years(self.coeffs, gdp.year) + factor = 1e6 * coeffs.a * population + unadjusted = factor * exp(coeffs.b * population / gdp) p = power(year + forecast - self.base_year, n) - return unadjusted * (1 + self.coeffs.w * p) / (1 + p) + return unadjusted * (1 + coeffs.w * p) / (1 + p) @register_regression @@ -394,7 +398,8 @@ def Logistic( """ from numpy import exp, power - a, b, c, w = self.coeffs.a, self.coeffs.b, self.coeffs.c, self.coeffs.w + coeffs = broadcast_years(self.coeffs, gdp.year) + a, b, c, w = coeffs.a, coeffs.b, coeffs.c, coeffs.w p = power(forecast, n) factor = 1e6 * a * population * (1 + w * p) / (1 + p) return factor / (1 + b * exp(gdp * c / population)) @@ -406,8 +411,9 @@ def Loglog(self, gdp: DataArray, population: DataArray, *args, **kwargs) -> Data """1e6 * e^a * population * (gpd/population)^b.""" from numpy import exp, power - factor = 1e6 * exp(self.coeffs.a) * population - return factor * power(gdp / population, self.coeffs.b) + coeffs = broadcast_years(self.coeffs, gdp.year) + factor = 1e6 * exp(coeffs.a) * population + return factor * power(gdp / population, coeffs.b) @register_regression @@ -424,6 +430,7 @@ def LogisticSigmoid( ) -> DataArray: """0.001 * (constant * pop + gdp * c / sqrt(1 + (gdp * scale / pop)^2).""" from numpy import power + from xarray import ones_like constant = self.coeffs.a c = self.coeffs.c @@ -442,8 +449,11 @@ def LogisticSigmoid( # fmt: enable scale = self.coeffs.b0.where(years < 2015, self.coeffs.b1) else: - scale = 1 + scale = ones_like(self.coeffs.b0) + scale = broadcast_years(scale, gdp.year) + constant = broadcast_years(constant, gdp.year) + c = broadcast_years(c, gdp.year) p = power(1 + power(gdp * scale / population, 2), 0.5) return 0.001 * (constant * population + gdp * c / p) @@ -498,7 +508,10 @@ def __call__( if year is not None and "year" in data.dims: data = data.interp(year=year, method=self.interpolation) - return coeffs.a * data.population + scale * ( + a = broadcast_years(coeffs.a, data.year) + scale = broadcast_years(scale, data.year) + gdpcap_offset = broadcast_years(gdpcap_offset, data.year) + return a * data.population + scale * ( data.gdp - gdpcap_offset * data.population ) diff --git a/src/muse/sectors/sector.py b/src/muse/sectors/sector.py index b429523d2..bf5b4b226 100644 --- a/src/muse/sectors/sector.py +++ b/src/muse/sectors/sector.py @@ -286,45 +286,58 @@ def market_variables(self, market: xr.Dataset, technologies: xr.Dataset) -> Any: technologies, capacity, installed_as_year=True ) - # Calculate supply - supply = self.supply_prod( - demand=market.consumption, - capacity=capacity, - technologies=technodata, - timeslice_level=self.timeslice_level, - prices=market.prices, - ) + # Calculate supply/consumption/costs (one year at a time) + supply_list = [] + consume_list = [] + costs_list = [] + for year in capacity.year.values: + supply_year = self.supply_prod( + demand=market.consumption.sel(year=year), + capacity=capacity.sel(year=year), + technologies=technodata, + timeslice_level=self.timeslice_level, + prices=market.prices.sel(year=year), + ) - # Select relevant prices for each asset - prices = broadcast_over_assets(market.prices, capacity, installed_as_year=False) + # Select relevant prices for each asset + prices_for_assets = broadcast_over_assets( + market.prices.sel(year=year), capacity, installed_as_year=False + ) - # Calculate consumption - consume = consumption( - technologies=technodata, - production=supply, - prices=prices, - timeslice_level=self.timeslice_level, - ) + # Calculate consumption + consume_year = consumption( + technologies=technodata, + production=supply_year, + prices=prices_for_assets, + timeslice_level=self.timeslice_level, + ) - # Calculate LCOE - # We select data for the second year, which corresponds to the investment year - # We base LCOE only on the portion of capacity that is actually used (#728) - utilized_capacity = capacity_to_service_demand( - demand=supply.isel(year=1), - technologies=technodata, - timeslice_level=self.timeslice_level, - ) - lcoe = levelized_cost_of_energy( - prices=prices.isel(year=1), - technologies=technodata, - capacity=utilized_capacity, - production=supply.isel(year=1), - consumption=consume.isel(year=1), - method="annual", - ) + # Calculate LCOE + # We base LCOE only on the portion of capacity that is actually used (#728) + utilized_capacity = capacity_to_service_demand( + demand=supply_year, + technologies=technodata, + timeslice_level=self.timeslice_level, + ) + lcoe = levelized_cost_of_energy( + prices=prices_for_assets, + technologies=technodata, + capacity=utilized_capacity, + production=supply_year, + consumption=consume_year, + method="annual", + ) + + # Calculate new commodity prices + costs_year = supply_cost(supply_year, lcoe, asset_dim="asset") + + supply_list.append(supply_year.expand_dims(year=[year])) + consume_list.append(consume_year.expand_dims(year=[year])) + costs_list.append(costs_year.expand_dims(year=[year])) - # Calculate new commodity prices - costs = supply_cost(supply, lcoe, asset_dim="asset") + supply = xr.concat(supply_list, dim="year") + consume = xr.concat(consume_list, dim="year") + costs = xr.concat(costs_list, dim="year") return supply, consume, costs diff --git a/src/muse/timeslices.py b/src/muse/timeslices.py index 4017ce2b8..9215340ae 100644 --- a/src/muse/timeslices.py +++ b/src/muse/timeslices.py @@ -21,7 +21,7 @@ import pandas as pd from xarray import DataArray -from muse.utilities import broadcast_regions +from muse.utilities import broadcast_regions, broadcast_years TIMESLICE: DataArray = None # type: ignore @@ -160,6 +160,8 @@ def distribute_timeslice( timeslice_fractions = ts / broadcast_timeslice(timeslice_sum, ts=ts) if "region" in data.dims: timeslice_fractions = broadcast_regions(timeslice_fractions, data) + if "year" in data.dims: + timeslice_fractions = broadcast_years(timeslice_fractions, data) return broadcasted * timeslice_fractions diff --git a/src/muse/utilities.py b/src/muse/utilities.py index 0c1728c05..8465059ac 100644 --- a/src/muse/utilities.py +++ b/src/muse/utilities.py @@ -707,3 +707,23 @@ def broadcast_regions(data: xr.DataArray, template: xr.DataArray) -> xr.DataArra raise ValueError("Data is already regioned, but does not match the reference.") return data.expand_dims(region=template.region) + + +def broadcast_years(data: xr.DataArray, template: xr.DataArray) -> xr.DataArray: + """Convert a non-year array to a year array by broadcasting. + + If data is already yeared in the appropriate scheme, it will be returned + unchanged. + + Args: + data: Array to broadcast. + template: Dataarray with year coordinates to broadcast to. + + """ + # If data already has years, check that it matches the template years. + if "year" in data.dims: + if data.year.equals(template.year): + return data + raise ValueError("Data is already yeared, but does not match the reference.") + + return data.expand_dims(year=template.year) diff --git a/tests/conftest.py b/tests/conftest.py index 63061a2c8..7207cfacc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ from muse.__main__ import patched_broadcast_compat_data from muse.agents import Agent -from muse.utilities import broadcast_regions +from muse.utilities import broadcast_regions, broadcast_years @contextmanager @@ -225,7 +225,7 @@ def agent_args(coords) -> Mapping: @fixture def technologies(coords) -> Dataset: """Randomly generated technology characteristics.""" - from numpy import nonzero, sum + from numpy import nonzero from numpy.random import choice, rand, randint from muse.commodities import CommodityUsage @@ -242,9 +242,7 @@ def var(*dims, factor=100.0): return dims, (rand(*shape) * factor).astype(type(factor)) result["agent_share"] = var("technology", "region", "year") - result["agent_share"] /= broadcast_regions( - sum(result.agent_share), result.agent_share - ) + result["agent_share"] /= result.agent_share.sum("technology") result["agent_share_zero"] = result["agent_share"] * 0 # first create a mask so each tech will have consistent inputs/outputs across years @@ -276,21 +274,18 @@ def var(*dims, factor=100.0): fout.loc[{"technology": tech, "commodity": i}] = 1 # expand along year and region, and fill with random numbers - ones = broadcast_regions(result.year == result.year, result.region) * ( - result.region == result.region - ) - result["fixed_inputs"] = ( - broadcast_regions(result.fixed_inputs, result.region) * ones + fixed_in = broadcast_years( + broadcast_regions(result.fixed_inputs, result.region), result.year ) - result.fixed_inputs[:] *= rand(*result.fixed_inputs.shape) - result["flexible_inputs"] = ( - broadcast_regions(result.flexible_inputs, result.region) * ones + result["fixed_inputs"] = fixed_in * rand(*fixed_in.shape) + flex_in = broadcast_years( + broadcast_regions(result.flexible_inputs, result.region), result.year ) - result.flexible_inputs[:] *= rand(*result.flexible_inputs.shape) - result["fixed_outputs"] = ( - broadcast_regions(result.fixed_outputs, result.region) * ones + result["flexible_inputs"] = flex_in * rand(*flex_in.shape) + fixed_out = broadcast_years( + broadcast_regions(result.fixed_outputs, result.region), result.year ) - result.fixed_outputs[:] *= rand(*result.fixed_outputs.shape) + result["fixed_outputs"] = fixed_out * rand(*fixed_out.shape) result["total_capacity_limit"] = var("technology", "region", "year") result.total_capacity_limit.loc[{"year": 2030}] += result.total_capacity_limit.sel( diff --git a/tests/test_constraints.py b/tests/test_constraints.py index eb9ba28e7..73a9e6a55 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -61,10 +61,8 @@ def model_data(): # Create initial market demand as 80% of maximum production market_demand = 0.8 * maximum_production( broadcast_over_assets(technologies, assets), - assets.capacity, - ).sel(year=INVESTMENT_YEAR).groupby("technology").sum("asset").rename( - technology="asset" - ) + assets.capacity.sel(year=INVESTMENT_YEAR), + ).groupby("technology").sum("asset").rename(technology="asset") # Remove un-demanded commodities market_demand = market_demand.sel( diff --git a/tests/test_demand_share.py b/tests/test_demand_share.py index 428446072..032fc4080 100644 --- a/tests/test_demand_share.py +++ b/tests/test_demand_share.py @@ -7,7 +7,7 @@ from muse.commodities import is_enduse from muse.quantities import maximum_production from muse.timeslices import drop_timeslice -from muse.utilities import broadcast_over_assets, interpolate_capacity +from muse.utilities import broadcast_over_assets, broadcast_years, interpolate_capacity CURRENT_YEAR = 2010 INVESTMENT_YEAR = 2030 @@ -131,8 +131,9 @@ def _matching_market(technologies, capacity): from muse.quantities import consumption as calc_consumption # Calculate production and consumption - production = maximum_production(technologies, capacity) - cons = calc_consumption(technologies, production) + techs = broadcast_years(technologies, capacity.year) + production = maximum_production(techs, capacity) + cons = calc_consumption(techs, production) # Handle regional grouping if needed if "region" in production.coords: diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py index 17f950367..7cabcda70 100644 --- a/tests/test_dispatch.py +++ b/tests/test_dispatch.py @@ -7,6 +7,11 @@ from muse.utilities import broadcast_over_assets +@fixture +def capacity(capacity: xr.DataArray) -> xr.DataArray: + return capacity.isel(year=0) + + @fixture def technologies( technologies: xr.Dataset, capacity: xr.DataArray, timeslice @@ -29,7 +34,7 @@ def production( def prices( technologies: xr.Dataset, capacity: xr.DataArray, timeslice: xr.DataArray ) -> xr.DataArray: - # Make random prices for all commodities/timeslices/years/regions + # Make random prices for all commodities/timeslices/regions regions = xr.DataArray( capacity["region"].to_index().unique(), dims="region", @@ -39,16 +44,14 @@ def prices( np.random.rand( technologies.sizes["commodity"], timeslice.sizes["timeslice"], - capacity.sizes["year"], regions.size, ), coords={ "commodity": technologies.coords["commodity"], "timeslice": timeslice.coords["timeslice"], - "year": capacity.coords["year"], "region": regions, }, - dims=("commodity", "timeslice", "year", "region"), + dims=("commodity", "timeslice", "region"), ) return prices_by_region diff --git a/tests/test_quantities.py b/tests/test_quantities.py index 4c57e4aee..e0bf55076 100644 --- a/tests/test_quantities.py +++ b/tests/test_quantities.py @@ -15,6 +15,11 @@ from muse.utilities import broadcast_over_assets +@fixture +def capacity(capacity: xr.DataArray) -> xr.DataArray: + return capacity.isel(year=0) + + @fixture def technologies( technologies: xr.Dataset, capacity: xr.DataArray, timeslice diff --git a/tests/test_regressions.py b/tests/test_regressions.py index d62dfa005..5511e6f31 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -3,6 +3,7 @@ from pytest import approx, fixture from muse.regressions import Exponential, Linear +from muse.utilities import broadcast_years def create_dataset(coords, random_vars): @@ -47,8 +48,10 @@ def test_exponential(regression_params, drivers): # Calculate expected and actual results actual = functor(drivers) - factor = 1e6 * drivers.population * rp.a - expected = factor * exp(drivers.population / drivers.gdp * rp.b) + factor = 1e6 * drivers.population * broadcast_years(rp.a, drivers.year) + expected = factor * exp( + drivers.population / drivers.gdp * broadcast_years(rp.b, drivers.year) + ) expected, actual = broadcast(expected, actual) assert actual.values == approx(expected.values) @@ -68,8 +71,10 @@ def test_linear(regression_params, drivers): # Test basic functionality actual = functor(drivers, forecast=2) offset = drivers.gdp.sel(year=2010) / drivers.population.sel(year=2010) - expected = rp.a * drivers.population + rp.b0 * ( - drivers.gdp - offset * drivers.population + expected = broadcast_years( + rp.a, drivers.year + ) * drivers.population + broadcast_years(rp.b0, drivers.year) * ( + drivers.gdp - broadcast_years(offset, drivers.year) * drivers.population ) actual, expected = broadcast(actual, expected) assert actual.values == approx(expected.values) @@ -81,7 +86,9 @@ def test_linear(regression_params, drivers): ) population = drivers.population.interp(year=year, method="linear") gdp = drivers.gdp.interp(year=year, method="linear") - expected = rp.a * population + scale * (gdp - offset * population) + expected = broadcast_years(rp.a, population.year) * population + scale * ( + gdp - broadcast_years(offset, population.year) * population + ) actual = functor(drivers, forecast=2, year=year) actual, expected = broadcast(actual, expected) assert actual.values == approx(expected.values)