From 86020d0a988dbefb045c1b0ff01b53ac374ce6e8 Mon Sep 17 00:00:00 2001 From: Christopher Lupp Date: Wed, 8 Apr 2026 16:08:53 -0400 Subject: [PATCH 1/8] feat: add discrete variable support Implement discrete variable handling across the full stack: - Discipline base class: add_discrete_input/add_discrete_output methods with VariableMetaData using kDiscreteInput/kDiscreteOutput types - Server: process_inputs demuxes VariableMessage oneof into continuous and discrete dicts; yields VariableMessage wrappers on output - Client: _assemble_input_messages wraps both Array and DiscreteVariable in VariableMessage; _recover_outputs handles discrete responses - OpenMDAO bindings: auto-discover discrete vars from server metadata, declare them via add_discrete_input/output, and forward through compute/apply_nonlinear/linearize calls Discrete values use google.protobuf.Value for language-interoperable serialization of scalars, lists, and nested structures. Updates proto submodule to feature/discrete_vars (DiscreteVariable message, VariableMessage wrapper, stream VariableMessage RPCs). Refs: MDO-Standards/Philote-Python#54 --- CHANGELOG.md | 9 + philote_mdo/general/discipline.py | 42 +++++ philote_mdo/general/discipline_client.py | 158 +++++++++++++----- philote_mdo/general/discipline_server.py | 140 ++++++++++++++-- philote_mdo/general/explicit_client.py | 36 +++- philote_mdo/general/explicit_server.py | 75 ++++++--- philote_mdo/general/implicit_client.py | 28 +++- philote_mdo/general/implicit_server.py | 116 +++++++++---- philote_mdo/generated/data_pb2.py | 16 +- philote_mdo/generated/data_pb2.pyi | 22 +++ philote_mdo/generated/disciplines_pb2.py | 8 +- philote_mdo/generated/disciplines_pb2_grpc.py | 30 ++-- philote_mdo/openmdao/explicit.py | 34 +++- philote_mdo/openmdao/implicit.py | 81 ++++----- philote_mdo/openmdao/utils.py | 27 ++- proto | 2 +- 16 files changed, 608 insertions(+), 216 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 763a855..a6b3237 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Features +- Added discrete variable support throughout the stack. Disciplines can + now declare discrete inputs/outputs via `add_discrete_input` / + `add_discrete_output`. Discrete data is serialized as + `google.protobuf.Value` (supporting scalars, lists, and nested + structures) and multiplexed alongside continuous `Array` chunks in + the new `VariableMessage` wrapper. The OpenMDAO bindings + (`RemoteExplicitComponent`, `RemoteImplicitComponent`) automatically + discover and forward discrete variables. + ### Bug Fixes - Fixed `SellarMDA` promoted-input ambiguity that newer OpenMDAO releases diff --git a/philote_mdo/general/discipline.py b/philote_mdo/general/discipline.py index d00ef8b..f5930cf 100644 --- a/philote_mdo/general/discipline.py +++ b/philote_mdo/general/discipline.py @@ -44,6 +44,9 @@ def __init__(self): # variable metadata self._var_meta = [] + # discrete variable metadata (name → default value) + self._discrete_var_meta = [] + # partials metadata self._partials_meta = [] @@ -90,6 +93,44 @@ def add_input(self, name, shape=(1,), units=""): meta.units = units self._var_meta += [meta] + def add_discrete_input(self, name, default=None): + """ + Define a discrete input. + + Discrete inputs can hold any value that is representable as a + ``google.protobuf.Value`` (scalars, lists, or nested dicts). + + Parameters + ---------- + name : string + the name of the discrete input variable + default : object, optional + the default value for the discrete input + """ + meta = data.VariableMetaData() + meta.type = data.VariableType.kDiscreteInput + meta.name = name + self._discrete_var_meta += [meta] + + def add_discrete_output(self, name, default=None): + """ + Define a discrete output. + + Discrete outputs can hold any value that is representable as a + ``google.protobuf.Value`` (scalars, lists, or nested dicts). + + Parameters + ---------- + name : string + the name of the discrete output variable + default : object, optional + the default value for the discrete output + """ + meta = data.VariableMetaData() + meta.type = data.VariableType.kDiscreteOutput + meta.name = name + self._discrete_var_meta += [meta] + def add_output(self, name, shape=(1,), units=""): """ Defines a continuous output. @@ -179,4 +220,5 @@ def _clear_data(self): This function is invoked from the Setup function of the server. """ self._var_meta = [] + self._discrete_var_meta = [] self._partials_meta = [] diff --git a/philote_mdo/general/discipline_client.py b/philote_mdo/general/discipline_client.py index 4630cca..385d077 100644 --- a/philote_mdo/general/discipline_client.py +++ b/philote_mdo/general/discipline_client.py @@ -32,6 +32,7 @@ import philote_mdo.generated.data_pb2 as data import philote_mdo.generated.disciplines_pb2_grpc as disc import philote_mdo.utils as utils +from philote_mdo.general.discipline_server import _python_to_value, _value_to_python class DisciplineClient: @@ -59,6 +60,7 @@ def __init__(self, channel): # variable and partials metadata self._var_meta = [] + self._discrete_var_meta = [] self._partials_meta = [] # list of available options @@ -123,9 +125,18 @@ def run_setup(self): def get_variable_definitions(self): """ Requests the input and output metadata from the server. + + Both continuous and discrete variable metadata are stored in their + respective lists. """ for message in self._disc_stub.GetVariableDefinitions(empty.Empty()): - self._var_meta += [message] + if message.type in ( + data.VariableType.kDiscreteInput, + data.VariableType.kDiscreteOutput, + ): + self._discrete_var_meta += [message] + else: + self._var_meta += [message] def get_partials_definitions(self): """ @@ -135,52 +146,92 @@ def get_partials_definitions(self): if message.name not in self._partials_meta: self._partials_meta += [message] - def _assemble_input_messages(self, inputs, outputs=None): + def _assemble_input_messages( + self, inputs, outputs=None, discrete_inputs=None, discrete_outputs=None + ): """ Assembles the messages for transmitting the input variables to the server. + + Both continuous and discrete inputs are wrapped in ``VariableMessage`` + envelopes. """ messages = [] + # Continuous inputs for input_name, value in inputs.items(): for b, e in utils.get_chunk_indices( value.size, self._stream_options.num_double ): messages += [ - data.Array( - name=input_name, - start=b, - end=e - 1, - type=data.VariableType.kInput, - data=value.ravel()[b:e], + data.VariableMessage( + continuous=data.Array( + name=input_name, + start=b, + end=e - 1, + type=data.VariableType.kInput, + data=value.ravel()[b:e], + ) ) ] + # Continuous outputs (for implicit disciplines) if outputs: for output_name, value in outputs.items(): for b, e in utils.get_chunk_indices( value.size, self._stream_options.num_double ): messages += [ - data.Array( - name=output_name, - start=b, - end=e - 1, - type=data.VariableType.kOutput, - data=value.ravel()[b:e], + data.VariableMessage( + continuous=data.Array( + name=output_name, + start=b, + end=e - 1, + type=data.VariableType.kOutput, + data=value.ravel()[b:e], + ) ) ] + # Discrete inputs + if discrete_inputs: + for name, value in discrete_inputs.items(): + messages += [ + data.VariableMessage( + discrete=data.DiscreteVariable( + name=name, + type=data.VariableType.kDiscreteInput, + value=_python_to_value(value), + ) + ) + ] + + # Discrete outputs (for implicit disciplines) + if discrete_outputs: + for name, value in discrete_outputs.items(): + messages += [ + data.VariableMessage( + discrete=data.DiscreteVariable( + name=name, + type=data.VariableType.kDiscreteOutput, + value=_python_to_value(value), + ) + ) + ] + return messages def _recover_outputs(self, responses): """ Recovers the outputs from the stream of responses. + + Returns both continuous outputs and discrete outputs. """ outputs = {} flat_outputs = {} + discrete_outputs = {} - # preallocate + # preallocate continuous outputs for out in self._var_meta: if out.type == data.kOutput: name = out.name @@ -188,16 +239,27 @@ def _recover_outputs(self, responses): flat_outputs[name] = utils.get_flattened_view(outputs[name]) for message in responses: - if message.type == data.kOutput: - b = message.start - e = message.end + 1 - if len(message.data) > 0: - flat_outputs[message.name][b:e] = message.data - else: - raise ValueError( - "Expected continuous variables, but array is empty." - ) + variant = message.WhichOneof("payload") + + if variant == "continuous": + arr = message.continuous + if arr.type == data.kOutput: + b = arr.start + e = arr.end + 1 + if len(arr.data) > 0: + flat_outputs[arr.name][b:e] = arr.data + else: + raise ValueError( + "Expected continuous variables, but array is empty." + ) + elif variant == "discrete": + dv = message.discrete + if dv.type == data.VariableType.kDiscreteOutput: + discrete_outputs[dv.name] = _value_to_python(dv.value) + + if discrete_outputs: + return outputs, discrete_outputs return outputs def _recover_residuals(self, responses): @@ -215,15 +277,19 @@ def _recover_residuals(self, responses): flat_residuals[name] = utils.get_flattened_view(residuals[name]) for message in responses: - if message.type == data.kResidual: - b = message.start - e = message.end + 1 - if len(message.data) > 0: - flat_residuals[message.name][b:e] = message.data - else: - raise ValueError( - "Expected continuous variables, but array is empty." - ) + variant = message.WhichOneof("payload") + + if variant == "continuous": + arr = message.continuous + if arr.type == data.kResidual: + b = arr.start + e = arr.end + 1 + if len(arr.data) > 0: + flat_residuals[arr.name][b:e] = arr.data + else: + raise ValueError( + "Expected continuous variables, but array is empty." + ) return residuals @@ -257,16 +323,20 @@ def _recover_partials(self, responses): ) for message in responses: - b = message.start - e = message.end + 1 - - if message.type == data.kPartial: - if len(message.data) > 0: - flat_p[(message.name, message.subname)][b:e] = message.data - else: - raise ValueError( - "Expected continuous outputs for the " - "partials, but array was empty." - ) + variant = message.WhichOneof("payload") + + if variant == "continuous": + arr = message.continuous + b = arr.start + e = arr.end + 1 + + if arr.type == data.kPartial: + if len(arr.data) > 0: + flat_p[(arr.name, arr.subname)][b:e] = arr.data + else: + raise ValueError( + "Expected continuous outputs for the " + "partials, but array was empty." + ) return partials diff --git a/philote_mdo/general/discipline_server.py b/philote_mdo/general/discipline_server.py index 5af98df..5e57904 100644 --- a/philote_mdo/general/discipline_server.py +++ b/philote_mdo/general/discipline_server.py @@ -32,6 +32,7 @@ import philote_mdo.generated.data_pb2 as data import philote_mdo.generated.disciplines_pb2_grpc as disc from google.protobuf.empty_pb2 import Empty +from google.protobuf import struct_pb2 from philote_mdo.utils import PairDict, get_flattened_view @@ -128,10 +129,15 @@ def Setup(self, request, context): def GetVariableDefinitions(self, request, context): """ Transmits variable metadata about the analysis discipline to the client. + + Both continuous and discrete variable metadata are streamed. """ for var in self._discipline._var_meta: yield var + for var in self._discipline._discrete_var_meta: + yield var + def GetPartialDefinitions(self, request, context): """ Transmits partials metadata about the analysis discipline to the client. @@ -193,27 +199,127 @@ def preallocate_partials(self): return jac - def process_inputs(self, request_iterator, flat_inputs, flat_outputs=None): + def process_inputs( + self, + request_iterator, + flat_inputs, + flat_outputs=None, + discrete_inputs=None, + discrete_outputs=None, + ): """ Processes the message inputs from a gRPC stream. + The stream consists of ``VariableMessage`` wrappers, each of which + contains either a continuous ``Array`` or a ``DiscreteVariable``. + Note, for implicit disciplines, the function values are considered inputs to evaluate the residuals and the partials of the residuals. """ - # process inputs + if discrete_inputs is None: + discrete_inputs = {} + if discrete_outputs is None: + discrete_outputs = {} + for message in request_iterator: - # start and end indices for the array chunk - b = message.start - e = message.end - - # assign either continuous or discrete data - if len(message.data) > 0: - if message.type == data.VariableType.kInput: - flat_inputs[message.name][b : e + 1] = message.data - elif message.type == data.VariableType.kOutput: - flat_outputs[message.name][b : e + 1] = message.data - else: - raise ValueError( - "Expected continuous variables but arrays were" - " empty for variable %s." % (message.name) - ) + variant = message.WhichOneof("payload") + + if variant == "continuous": + arr = message.continuous + b = arr.start + e = arr.end + + if len(arr.data) > 0: + if arr.type == data.VariableType.kInput: + flat_inputs[arr.name][b : e + 1] = arr.data + elif arr.type == data.VariableType.kOutput: + flat_outputs[arr.name][b : e + 1] = arr.data + else: + raise ValueError( + "Expected continuous variables but arrays were" + " empty for variable %s." % (arr.name) + ) + + elif variant == "discrete": + dv = message.discrete + # Convert protobuf Value to native Python type + native_value = _value_to_python(dv.value) + + if dv.type == data.VariableType.kDiscreteInput: + discrete_inputs[dv.name] = native_value + elif dv.type == data.VariableType.kDiscreteOutput: + discrete_outputs[dv.name] = native_value + + return discrete_inputs, discrete_outputs + + +def _value_to_python(value): + """ + Converts a ``google.protobuf.Value`` to a native Python object. + + Parameters + ---------- + value : google.protobuf.Value + protobuf Value message + + Returns + ------- + object + Native Python equivalent (None, bool, int/float, str, list, or dict) + """ + kind = value.WhichOneof("kind") + + if kind == "null_value": + return None + elif kind == "bool_value": + return value.bool_value + elif kind == "number_value": + # protobuf stores all numbers as doubles; return int if lossless + num = value.number_value + if num == int(num): + return int(num) + return num + elif kind == "string_value": + return value.string_value + elif kind == "list_value": + return [_value_to_python(v) for v in value.list_value.values] + elif kind == "struct_value": + return {k: _value_to_python(v) for k, v in value.struct_value.fields.items()} + else: + return None + + +def _python_to_value(obj): + """ + Converts a native Python object to a ``google.protobuf.Value``. + + Parameters + ---------- + obj : object + A Python scalar, list, or dict + + Returns + ------- + google.protobuf.Value + protobuf Value message + """ + val = struct_pb2.Value() + + if obj is None: + val.null_value = 0 + elif isinstance(obj, bool): + val.bool_value = obj + elif isinstance(obj, (int, float)): + val.number_value = float(obj) + elif isinstance(obj, str): + val.string_value = obj + elif isinstance(obj, (list, tuple)): + for item in obj: + val.list_value.values.append(_python_to_value(item)) + elif isinstance(obj, dict): + for k, v in obj.items(): + val.struct_value.fields[str(k)].CopyFrom(_python_to_value(v)) + else: + val.string_value = str(obj) + + return val diff --git a/philote_mdo/general/explicit_client.py b/philote_mdo/general/explicit_client.py index d0f9ef9..5b278cb 100644 --- a/philote_mdo/general/explicit_client.py +++ b/philote_mdo/general/explicit_client.py @@ -41,23 +41,45 @@ def __init__(self, channel): super().__init__(channel) self._expl_stub = disc.ExplicitServiceStub(channel) - def run_compute(self, inputs): + def run_compute(self, inputs, discrete_inputs=None): """ Requests and receives the function evaluation from the analysis server for a set of inputs (sent to the server). + + Parameters + ---------- + inputs : dict + Continuous input values. + discrete_inputs : dict, optional + Discrete input values. + + Returns + ------- + dict or tuple(dict, dict) + Continuous outputs, or (continuous outputs, discrete outputs) when + the server returns discrete output data. """ - messages = self._assemble_input_messages(inputs) + messages = self._assemble_input_messages( + inputs, discrete_inputs=discrete_inputs + ) responses = self._expl_stub.ComputeFunction(iter(messages)) - outputs = self._recover_outputs(responses) + return self._recover_outputs(responses) - return outputs - - def run_compute_partials(self, inputs): + def run_compute_partials(self, inputs, discrete_inputs=None): """ Requests and receives the gradient evaluation from the analysis server for a set of inputs (sent to the server). + + Parameters + ---------- + inputs : dict + Continuous input values. + discrete_inputs : dict, optional + Discrete input values. """ - messages = self._assemble_input_messages(inputs) + messages = self._assemble_input_messages( + inputs, discrete_inputs=discrete_inputs + ) responses = self._expl_stub.ComputeGradient(iter(messages)) partials = self._recover_partials(responses) diff --git a/philote_mdo/general/explicit_server.py b/philote_mdo/general/explicit_server.py index 6f6f329..049e0a2 100644 --- a/philote_mdo/general/explicit_server.py +++ b/philote_mdo/general/explicit_server.py @@ -29,7 +29,10 @@ # control over the information you may find at these locations. import philote_mdo.generated.disciplines_pb2_grpc as disc import philote_mdo.generated.data_pb2 as data -from philote_mdo.general.discipline_server import DisciplineServer +from philote_mdo.general.discipline_server import ( + DisciplineServer, + _python_to_value, +) from philote_mdo.utils import get_chunk_indices @@ -55,21 +58,44 @@ def ComputeFunction(self, request_iterator, context): inputs = {} flat_inputs = {} outputs = {} + discrete_inputs = {} + discrete_outputs = {} self.preallocate_inputs(inputs, flat_inputs) - self.process_inputs(request_iterator, flat_inputs) - self._discipline.compute(inputs, outputs) + discrete_inputs, _ = self.process_inputs( + request_iterator, flat_inputs, discrete_inputs=discrete_inputs + ) + # Call compute with discrete data when discrete variables are present + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.compute( + inputs, outputs, discrete_inputs, discrete_outputs + ) + else: + self._discipline.compute(inputs, outputs) + + # Stream continuous outputs for output_name, value in outputs.items(): - # iterate through all chunks needed for the current output for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.Array( - name=output_name, - type=data.kOutput, - start=b, - end=e - 1, - data=value.ravel()[b:e], + yield data.VariableMessage( + continuous=data.Array( + name=output_name, + type=data.kOutput, + start=b, + end=e - 1, + data=value.ravel()[b:e], + ) + ) + + # Stream discrete outputs + for name, value in discrete_outputs.items(): + yield data.VariableMessage( + discrete=data.DiscreteVariable( + name=name, + type=data.VariableType.kDiscreteOutput, + value=_python_to_value(value), ) + ) def ComputeGradient(self, request_iterator, context): """ @@ -77,19 +103,28 @@ def ComputeGradient(self, request_iterator, context): """ inputs = {} flat_inputs = {} + discrete_inputs = {} + self.preallocate_inputs(inputs, flat_inputs) jac = self.preallocate_partials() - self.process_inputs(request_iterator, flat_inputs) - self._discipline.compute_partials(inputs, jac) + discrete_inputs, _ = self.process_inputs( + request_iterator, flat_inputs, discrete_inputs=discrete_inputs + ) + + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.compute_partials(inputs, jac, discrete_inputs) + else: + self._discipline.compute_partials(inputs, jac) for jac, value in jac.items(): - # iterate through all chunks needed for the current partials for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.Array( - name=jac[0], - subname=jac[1], - type=data.kPartial, - start=b, - end=e - 1, - data=value.ravel()[b:e], + yield data.VariableMessage( + continuous=data.Array( + name=jac[0], + subname=jac[1], + type=data.kPartial, + start=b, + end=e - 1, + data=value.ravel()[b:e], + ) ) diff --git a/philote_mdo/general/implicit_client.py b/philote_mdo/general/implicit_client.py index d62c2b1..81fb6f3 100644 --- a/philote_mdo/general/implicit_client.py +++ b/philote_mdo/general/implicit_client.py @@ -119,7 +119,9 @@ def __init__(self, channel): super().__init__(channel=channel) self._impl_stub = disc.ImplicitServiceStub(channel) - def run_compute_residuals(self, inputs, outputs): + def run_compute_residuals( + self, inputs, outputs, discrete_inputs=None, discrete_outputs=None + ): """ Compute residuals R(inputs, outputs) by calling the remote server. @@ -167,13 +169,18 @@ def run_compute_residuals(self, inputs, outputs): - This is typically used for residual evaluation during Newton iterations """ # Assemble input messages and call server - messages = self._assemble_input_messages(inputs, outputs) + messages = self._assemble_input_messages( + inputs, + outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, + ) responses = self._impl_stub.ComputeResiduals(iter(messages)) residuals = self._recover_residuals(responses) return residuals - def run_solve_residuals(self, inputs): + def run_solve_residuals(self, inputs, discrete_inputs=None): """ Solve implicit equations R(inputs, outputs) = 0 by calling the remote server. @@ -225,12 +232,16 @@ def run_solve_residuals(self, inputs): - Solution quality depends on the server's implementation and input conditioning """ # Assemble input messages and call server - messages = self._assemble_input_messages(inputs) + messages = self._assemble_input_messages( + inputs, discrete_inputs=discrete_inputs + ) responses = self._impl_stub.SolveResiduals(iter(messages)) outputs = self._recover_outputs(responses) return outputs - def run_residual_gradients(self, inputs, outputs): + def run_residual_gradients( + self, inputs, outputs, discrete_inputs=None, discrete_outputs=None + ): """ Compute Jacobian of residuals dR/d[inputs,outputs] by calling the remote server. @@ -289,7 +300,12 @@ def run_residual_gradients(self, inputs, outputs): - For large problems, consider matrix-free methods if available """ # Assemble input messages and call server - messages = self._assemble_input_messages(inputs, outputs) + messages = self._assemble_input_messages( + inputs, + outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, + ) responses = self._impl_stub.ComputeResidualGradients(iter(messages)) partials = self._recover_partials(responses) return partials diff --git a/philote_mdo/general/implicit_server.py b/philote_mdo/general/implicit_server.py index 21d45e3..973cfc8 100644 --- a/philote_mdo/general/implicit_server.py +++ b/philote_mdo/general/implicit_server.py @@ -30,6 +30,7 @@ import philote_mdo.generated.disciplines_pb2_grpc as disc import philote_mdo.generated.data_pb2 as data import philote_mdo.general as pmdo +from philote_mdo.general.discipline_server import _python_to_value from philote_mdo.utils import get_chunk_indices @@ -136,15 +137,15 @@ def ComputeResiduals(self, request_iterator, context): Parameters ---------- - request_iterator : Iterator[data.Array] - Stream of input and output arrays from the client + request_iterator : Iterator[data.VariableMessage] + Stream of input and output variable messages from the client context : grpc.ServicerContext gRPC context for the request Yields ------ - data.Array - Stream of residual arrays back to the client + data.VariableMessage + Stream of residual variable messages back to the client Notes ----- @@ -159,21 +160,36 @@ def ComputeResiduals(self, request_iterator, context): outputs = {} flat_outputs = {} residuals = {} + discrete_inputs = {} + discrete_outputs = {} self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) - self.process_inputs(request_iterator, flat_inputs, flat_outputs) + discrete_inputs, discrete_outputs = self.process_inputs( + request_iterator, + flat_inputs, + flat_outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, + ) # Call the user-defined compute_residuals function - self._discipline.compute_residuals(inputs, outputs, residuals) + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.compute_residuals( + inputs, outputs, residuals, discrete_inputs, discrete_outputs + ) + else: + self._discipline.compute_residuals(inputs, outputs, residuals) for res_name, value in residuals.items(): for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.Array( - name=res_name, - start=b, - end=e, - type=data.kResidual, - data=value.ravel()[b:e], + yield data.VariableMessage( + continuous=data.Array( + name=res_name, + start=b, + end=e, + type=data.kResidual, + data=value.ravel()[b:e], + ) ) def SolveResiduals(self, request_iterator, context): @@ -186,15 +202,15 @@ def SolveResiduals(self, request_iterator, context): Parameters ---------- - request_iterator : Iterator[data.Array] - Stream of input arrays from the client + request_iterator : Iterator[data.VariableMessage] + Stream of input variable messages from the client context : grpc.ServicerContext gRPC context for the request Yields ------ - data.Array - Stream of solved output arrays back to the client + data.VariableMessage + Stream of solved output variable messages back to the client Notes ----- @@ -208,21 +224,32 @@ def SolveResiduals(self, request_iterator, context): flat_inputs = {} outputs = {} flat_outputs = {} + discrete_inputs = {} self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) - self.process_inputs(request_iterator, flat_inputs, flat_outputs) + discrete_inputs, _ = self.process_inputs( + request_iterator, + flat_inputs, + flat_outputs, + discrete_inputs=discrete_inputs, + ) # Call the user-defined solve function - self._discipline.solve_residuals(inputs, outputs) + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.solve_residuals(inputs, outputs, discrete_inputs) + else: + self._discipline.solve_residuals(inputs, outputs) for output_name, value in outputs.items(): for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.Array( - name=output_name, - start=b, - end=e, - type=data.kOutput, - data=value.ravel()[b:e], + yield data.VariableMessage( + continuous=data.Array( + name=output_name, + start=b, + end=e, + type=data.kOutput, + data=value.ravel()[b:e], + ) ) def ComputeResidualGradients(self, request_iterator, context): @@ -235,15 +262,15 @@ def ComputeResidualGradients(self, request_iterator, context): Parameters ---------- - request_iterator : Iterator[data.Array] - Stream of input and output arrays from the client + request_iterator : Iterator[data.VariableMessage] + Stream of input and output variable messages from the client context : grpc.ServicerContext gRPC context for the request Yields ------ - data.Array - Stream of partial derivative arrays back to the client + data.VariableMessage + Stream of partial derivative variable messages back to the client Notes ----- @@ -257,23 +284,38 @@ def ComputeResidualGradients(self, request_iterator, context): flat_inputs = {} outputs = {} flat_outputs = {} + discrete_inputs = {} + discrete_outputs = {} self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) jac = self.preallocate_partials() - self.process_inputs(request_iterator, flat_inputs, flat_outputs) + discrete_inputs, discrete_outputs = self.process_inputs( + request_iterator, + flat_inputs, + flat_outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, + ) # Call the user-defined residual partials function - self._discipline.residual_partials(inputs, outputs, jac) + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.residual_partials( + inputs, outputs, jac, discrete_inputs, discrete_outputs + ) + else: + self._discipline.residual_partials(inputs, outputs, jac) for jac, value in jac.items(): for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.Array( - name=jac[0], - subname=jac[1], - type=data.kPartial, - start=b, - end=e, - data=value.ravel()[b:e], + yield data.VariableMessage( + continuous=data.Array( + name=jac[0], + subname=jac[1], + type=data.kPartial, + start=b, + end=e, + data=value.ravel()[b:e], + ) ) # def MatrixFreeGradients(self, request_iterator, context): diff --git a/philote_mdo/generated/data_pb2.py b/philote_mdo/generated/data_pb2.py index 64720fe..8334dc1 100644 --- a/philote_mdo/generated/data_pb2.py +++ b/philote_mdo/generated/data_pb2.py @@ -7,17 +7,17 @@ _runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 5, 27, 2, '', 'data.proto') _sym_db = _symbol_database.Default() from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x07philote\x1a\x1cgoogle/protobuf/struct.proto"}\n\x14DisciplineProperties\x12\x12\n\ncontinuous\x18\x01 \x01(\x08\x12\x16\n\x0edifferentiable\x18\x02 \x01(\x08\x12\x1a\n\x12provides_gradients\x18\x03 \x01(\x08\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x0f\n\x07version\x18\x05 \x01(\t"#\n\rStreamOptions\x12\x12\n\nnum_double\x18\x01 \x01(\x03"?\n\x0bOptionsList\x12\x0f\n\x07options\x18\x01 \x03(\t\x12\x1f\n\x04type\x18\x02 \x03(\x0e2\x11.philote.DataType"=\n\x11DisciplineOptions\x12(\n\x07options\x18\x01 \x01(\x0b2\x17.google.protobuf.Struct"c\n\x10VariableMetaData\x12#\n\x04type\x18\x01 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05units\x18\x05 \x01(\t"@\n\x10PartialsMetaData\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03"u\n\x05Array\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05start\x18\x03 \x01(\x03\x12\x0b\n\x03end\x18\x04 \x01(\x03\x12#\n\x04type\x18\x05 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04data\x18\x06 \x03(\x01*9\n\x08DataType\x12\t\n\x05kBool\x10\x00\x12\x08\n\x04kInt\x10\x01\x12\x0b\n\x07kDouble\x10\x02\x12\x0b\n\x07kString\x10\x03*m\n\x0cVariableType\x12\n\n\x06kInput\x10\x00\x12\x12\n\x0ekDiscreteInput\x10\x01\x12\r\n\tkResidual\x10\x02\x12\x0b\n\x07kOutput\x10\x03\x12\x13\n\x0fkDiscreteOutput\x10\x04\x12\x0c\n\x08kPartial\x10\x05B\x11\n\x0forg.philote.mdob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x07philote\x1a\x1cgoogle/protobuf/struct.proto"}\n\x14DisciplineProperties\x12\x12\n\ncontinuous\x18\x01 \x01(\x08\x12\x16\n\x0edifferentiable\x18\x02 \x01(\x08\x12\x1a\n\x12provides_gradients\x18\x03 \x01(\x08\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x0f\n\x07version\x18\x05 \x01(\t"#\n\rStreamOptions\x12\x12\n\nnum_double\x18\x01 \x01(\x03"?\n\x0bOptionsList\x12\x0f\n\x07options\x18\x01 \x03(\t\x12\x1f\n\x04type\x18\x02 \x03(\x0e2\x11.philote.DataType"=\n\x11DisciplineOptions\x12(\n\x07options\x18\x01 \x01(\x0b2\x17.google.protobuf.Struct"c\n\x10VariableMetaData\x12#\n\x04type\x18\x01 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05units\x18\x05 \x01(\t"@\n\x10PartialsMetaData\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03"u\n\x05Array\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05start\x18\x03 \x01(\x03\x12\x0b\n\x03end\x18\x04 \x01(\x03\x12#\n\x04type\x18\x05 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04data\x18\x06 \x03(\x01"l\n\x10DiscreteVariable\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x04type\x18\x02 \x01(\x0e2\x15.philote.VariableType\x12%\n\x05value\x18\x03 \x01(\x0b2\x16.google.protobuf.Value"q\n\x0fVariableMessage\x12$\n\ncontinuous\x18\x01 \x01(\x0b2\x0e.philote.ArrayH\x00\x12-\n\x08discrete\x18\x02 \x01(\x0b2\x19.philote.DiscreteVariableH\x00B\t\n\x07payload*9\n\x08DataType\x12\t\n\x05kBool\x10\x00\x12\x08\n\x04kInt\x10\x01\x12\x0b\n\x07kDouble\x10\x02\x12\x0b\n\x07kString\x10\x03*m\n\x0cVariableType\x12\n\n\x06kInput\x10\x00\x12\x12\n\x0ekDiscreteInput\x10\x01\x12\r\n\tkResidual\x10\x02\x12\x0b\n\x07kOutput\x10\x03\x12\x13\n\x0fkDiscreteOutput\x10\x04\x12\x0c\n\x08kPartial\x10\x05B\x11\n\x0forg.philote.mdob\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\x0forg.philote.mdo' - _globals['_DATATYPE']._serialized_start = 631 - _globals['_DATATYPE']._serialized_end = 688 - _globals['_VARIABLETYPE']._serialized_start = 690 - _globals['_VARIABLETYPE']._serialized_end = 799 + _globals['_DATATYPE']._serialized_start = 856 + _globals['_DATATYPE']._serialized_end = 913 + _globals['_VARIABLETYPE']._serialized_start = 915 + _globals['_VARIABLETYPE']._serialized_end = 1024 _globals['_DISCIPLINEPROPERTIES']._serialized_start = 53 _globals['_DISCIPLINEPROPERTIES']._serialized_end = 178 _globals['_STREAMOPTIONS']._serialized_start = 180 @@ -31,4 +31,8 @@ _globals['_PARTIALSMETADATA']._serialized_start = 446 _globals['_PARTIALSMETADATA']._serialized_end = 510 _globals['_ARRAY']._serialized_start = 512 - _globals['_ARRAY']._serialized_end = 629 \ No newline at end of file + _globals['_ARRAY']._serialized_end = 629 + _globals['_DISCRETEVARIABLE']._serialized_start = 631 + _globals['_DISCRETEVARIABLE']._serialized_end = 739 + _globals['_VARIABLEMESSAGE']._serialized_start = 741 + _globals['_VARIABLEMESSAGE']._serialized_end = 854 \ No newline at end of file diff --git a/philote_mdo/generated/data_pb2.pyi b/philote_mdo/generated/data_pb2.pyi index 6d9cdd1..a699396 100644 --- a/philote_mdo/generated/data_pb2.pyi +++ b/philote_mdo/generated/data_pb2.pyi @@ -116,4 +116,26 @@ class Array(_message.Message): data: _containers.RepeatedScalarFieldContainer[float] def __init__(self, name: _Optional[str]=..., subname: _Optional[str]=..., start: _Optional[int]=..., end: _Optional[int]=..., type: _Optional[_Union[VariableType, str]]=..., data: _Optional[_Iterable[float]]=...) -> None: + ... + +class DiscreteVariable(_message.Message): + __slots__ = ('name', 'type', 'value') + NAME_FIELD_NUMBER: _ClassVar[int] + TYPE_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + name: str + type: VariableType + value: _struct_pb2.Value + + def __init__(self, name: _Optional[str]=..., type: _Optional[_Union[VariableType, str]]=..., value: _Optional[_Union[_struct_pb2.Value, _Mapping]]=...) -> None: + ... + +class VariableMessage(_message.Message): + __slots__ = ('continuous', 'discrete') + CONTINUOUS_FIELD_NUMBER: _ClassVar[int] + DISCRETE_FIELD_NUMBER: _ClassVar[int] + continuous: Array + discrete: DiscreteVariable + + def __init__(self, continuous: _Optional[_Union[Array, _Mapping]]=..., discrete: _Optional[_Union[DiscreteVariable, _Mapping]]=...) -> None: ... \ No newline at end of file diff --git a/philote_mdo/generated/disciplines_pb2.py b/philote_mdo/generated/disciplines_pb2.py index aa06069..a78cc88 100644 --- a/philote_mdo/generated/disciplines_pb2.py +++ b/philote_mdo/generated/disciplines_pb2.py @@ -8,7 +8,7 @@ _sym_db = _symbol_database.Default() from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 from . import data_pb2 as data__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11disciplines.proto\x12\x07philote\x1a\x1bgoogle/protobuf/empty.proto\x1a\ndata.proto2\x84\x04\n\x11DisciplineService\x12B\n\x07GetInfo\x12\x16.google.protobuf.Empty\x1a\x1d.philote.DisciplineProperties"\x00\x12D\n\x10SetStreamOptions\x12\x16.philote.StreamOptions\x1a\x16.google.protobuf.Empty"\x00\x12E\n\x13GetAvailableOptions\x12\x16.google.protobuf.Empty\x1a\x14.philote.OptionsList"\x00\x12B\n\nSetOptions\x12\x1a.philote.DisciplineOptions\x1a\x16.google.protobuf.Empty"\x00\x129\n\x05Setup\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty"\x00\x12O\n\x16GetVariableDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.VariableMetaData"\x000\x01\x12N\n\x15GetPartialDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.PartialsMetaData"\x000\x012\x83\x01\n\x0fExplicitService\x127\n\x0fComputeFunction\x12\x0e.philote.Array\x1a\x0e.philote.Array"\x00(\x010\x01\x127\n\x0fComputeGradient\x12\x0e.philote.Array\x1a\x0e.philote.Array"\x00(\x010\x012\xc5\x01\n\x0fImplicitService\x128\n\x10ComputeResiduals\x12\x0e.philote.Array\x1a\x0e.philote.Array"\x00(\x010\x01\x126\n\x0eSolveResiduals\x12\x0e.philote.Array\x1a\x0e.philote.Array"\x00(\x010\x01\x12@\n\x18ComputeResidualGradients\x12\x0e.philote.Array\x1a\x0e.philote.Array"\x00(\x010\x01B\x11\n\x0forg.philote.mdob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11disciplines.proto\x12\x07philote\x1a\x1bgoogle/protobuf/empty.proto\x1a\ndata.proto2\x84\x04\n\x11DisciplineService\x12B\n\x07GetInfo\x12\x16.google.protobuf.Empty\x1a\x1d.philote.DisciplineProperties"\x00\x12D\n\x10SetStreamOptions\x12\x16.philote.StreamOptions\x1a\x16.google.protobuf.Empty"\x00\x12E\n\x13GetAvailableOptions\x12\x16.google.protobuf.Empty\x1a\x14.philote.OptionsList"\x00\x12B\n\nSetOptions\x12\x1a.philote.DisciplineOptions\x1a\x16.google.protobuf.Empty"\x00\x129\n\x05Setup\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty"\x00\x12O\n\x16GetVariableDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.VariableMetaData"\x000\x01\x12N\n\x15GetPartialDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.PartialsMetaData"\x000\x012\xab\x01\n\x0fExplicitService\x12K\n\x0fComputeFunction\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12K\n\x0fComputeGradient\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x012\x81\x02\n\x0fImplicitService\x12L\n\x10ComputeResiduals\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12J\n\x0eSolveResiduals\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12T\n\x18ComputeResidualGradients\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01B\x11\n\x0forg.philote.mdob\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'disciplines_pb2', _globals) @@ -18,6 +18,6 @@ _globals['_DISCIPLINESERVICE']._serialized_start = 72 _globals['_DISCIPLINESERVICE']._serialized_end = 588 _globals['_EXPLICITSERVICE']._serialized_start = 591 - _globals['_EXPLICITSERVICE']._serialized_end = 722 - _globals['_IMPLICITSERVICE']._serialized_start = 725 - _globals['_IMPLICITSERVICE']._serialized_end = 922 \ No newline at end of file + _globals['_EXPLICITSERVICE']._serialized_end = 762 + _globals['_IMPLICITSERVICE']._serialized_start = 765 + _globals['_IMPLICITSERVICE']._serialized_end = 1022 \ No newline at end of file diff --git a/philote_mdo/generated/disciplines_pb2_grpc.py b/philote_mdo/generated/disciplines_pb2_grpc.py index 443d5da..1b23738 100644 --- a/philote_mdo/generated/disciplines_pb2_grpc.py +++ b/philote_mdo/generated/disciplines_pb2_grpc.py @@ -142,8 +142,8 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ - self.ComputeFunction = channel.stream_stream('/philote.ExplicitService/ComputeFunction', request_serializer=data__pb2.Array.SerializeToString, response_deserializer=data__pb2.Array.FromString, _registered_method=True) - self.ComputeGradient = channel.stream_stream('/philote.ExplicitService/ComputeGradient', request_serializer=data__pb2.Array.SerializeToString, response_deserializer=data__pb2.Array.FromString, _registered_method=True) + self.ComputeFunction = channel.stream_stream('/philote.ExplicitService/ComputeFunction', request_serializer=data__pb2.VariableMessage.SerializeToString, response_deserializer=data__pb2.VariableMessage.FromString, _registered_method=True) + self.ComputeGradient = channel.stream_stream('/philote.ExplicitService/ComputeGradient', request_serializer=data__pb2.VariableMessage.SerializeToString, response_deserializer=data__pb2.VariableMessage.FromString, _registered_method=True) class ExplicitServiceServicer(object): """Definition of the generic Explicit Component RPC @@ -164,7 +164,7 @@ def ComputeGradient(self, request_iterator, context): raise NotImplementedError('Method not implemented!') def add_ExplicitServiceServicer_to_server(servicer, server): - rpc_method_handlers = {'ComputeFunction': grpc.stream_stream_rpc_method_handler(servicer.ComputeFunction, request_deserializer=data__pb2.Array.FromString, response_serializer=data__pb2.Array.SerializeToString), 'ComputeGradient': grpc.stream_stream_rpc_method_handler(servicer.ComputeGradient, request_deserializer=data__pb2.Array.FromString, response_serializer=data__pb2.Array.SerializeToString)} + rpc_method_handlers = {'ComputeFunction': grpc.stream_stream_rpc_method_handler(servicer.ComputeFunction, request_deserializer=data__pb2.VariableMessage.FromString, response_serializer=data__pb2.VariableMessage.SerializeToString), 'ComputeGradient': grpc.stream_stream_rpc_method_handler(servicer.ComputeGradient, request_deserializer=data__pb2.VariableMessage.FromString, response_serializer=data__pb2.VariableMessage.SerializeToString)} generic_handler = grpc.method_handlers_generic_handler('philote.ExplicitService', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) server.add_registered_method_handlers('philote.ExplicitService', rpc_method_handlers) @@ -175,14 +175,14 @@ class ExplicitService(object): @staticmethod def ComputeFunction(request_iterator, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/philote.ExplicitService/ComputeFunction', data__pb2.Array.SerializeToString, data__pb2.Array.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) + return grpc.experimental.stream_stream(request_iterator, target, '/philote.ExplicitService/ComputeFunction', data__pb2.VariableMessage.SerializeToString, data__pb2.VariableMessage.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) @staticmethod def ComputeGradient(request_iterator, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/philote.ExplicitService/ComputeGradient', data__pb2.Array.SerializeToString, data__pb2.Array.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) + return grpc.experimental.stream_stream(request_iterator, target, '/philote.ExplicitService/ComputeGradient', data__pb2.VariableMessage.SerializeToString, data__pb2.VariableMessage.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) class ImplicitServiceStub(object): - """Definition of the generic Explicit Discipline RPC + """Definition of the generic Implicit Discipline RPC """ def __init__(self, channel): @@ -191,12 +191,12 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ - self.ComputeResiduals = channel.stream_stream('/philote.ImplicitService/ComputeResiduals', request_serializer=data__pb2.Array.SerializeToString, response_deserializer=data__pb2.Array.FromString, _registered_method=True) - self.SolveResiduals = channel.stream_stream('/philote.ImplicitService/SolveResiduals', request_serializer=data__pb2.Array.SerializeToString, response_deserializer=data__pb2.Array.FromString, _registered_method=True) - self.ComputeResidualGradients = channel.stream_stream('/philote.ImplicitService/ComputeResidualGradients', request_serializer=data__pb2.Array.SerializeToString, response_deserializer=data__pb2.Array.FromString, _registered_method=True) + self.ComputeResiduals = channel.stream_stream('/philote.ImplicitService/ComputeResiduals', request_serializer=data__pb2.VariableMessage.SerializeToString, response_deserializer=data__pb2.VariableMessage.FromString, _registered_method=True) + self.SolveResiduals = channel.stream_stream('/philote.ImplicitService/SolveResiduals', request_serializer=data__pb2.VariableMessage.SerializeToString, response_deserializer=data__pb2.VariableMessage.FromString, _registered_method=True) + self.ComputeResidualGradients = channel.stream_stream('/philote.ImplicitService/ComputeResidualGradients', request_serializer=data__pb2.VariableMessage.SerializeToString, response_deserializer=data__pb2.VariableMessage.FromString, _registered_method=True) class ImplicitServiceServicer(object): - """Definition of the generic Explicit Discipline RPC + """Definition of the generic Implicit Discipline RPC """ def ComputeResiduals(self, request_iterator, context): @@ -221,23 +221,23 @@ def ComputeResidualGradients(self, request_iterator, context): raise NotImplementedError('Method not implemented!') def add_ImplicitServiceServicer_to_server(servicer, server): - rpc_method_handlers = {'ComputeResiduals': grpc.stream_stream_rpc_method_handler(servicer.ComputeResiduals, request_deserializer=data__pb2.Array.FromString, response_serializer=data__pb2.Array.SerializeToString), 'SolveResiduals': grpc.stream_stream_rpc_method_handler(servicer.SolveResiduals, request_deserializer=data__pb2.Array.FromString, response_serializer=data__pb2.Array.SerializeToString), 'ComputeResidualGradients': grpc.stream_stream_rpc_method_handler(servicer.ComputeResidualGradients, request_deserializer=data__pb2.Array.FromString, response_serializer=data__pb2.Array.SerializeToString)} + rpc_method_handlers = {'ComputeResiduals': grpc.stream_stream_rpc_method_handler(servicer.ComputeResiduals, request_deserializer=data__pb2.VariableMessage.FromString, response_serializer=data__pb2.VariableMessage.SerializeToString), 'SolveResiduals': grpc.stream_stream_rpc_method_handler(servicer.SolveResiduals, request_deserializer=data__pb2.VariableMessage.FromString, response_serializer=data__pb2.VariableMessage.SerializeToString), 'ComputeResidualGradients': grpc.stream_stream_rpc_method_handler(servicer.ComputeResidualGradients, request_deserializer=data__pb2.VariableMessage.FromString, response_serializer=data__pb2.VariableMessage.SerializeToString)} generic_handler = grpc.method_handlers_generic_handler('philote.ImplicitService', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) server.add_registered_method_handlers('philote.ImplicitService', rpc_method_handlers) class ImplicitService(object): - """Definition of the generic Explicit Discipline RPC + """Definition of the generic Implicit Discipline RPC """ @staticmethod def ComputeResiduals(request_iterator, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/philote.ImplicitService/ComputeResiduals', data__pb2.Array.SerializeToString, data__pb2.Array.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) + return grpc.experimental.stream_stream(request_iterator, target, '/philote.ImplicitService/ComputeResiduals', data__pb2.VariableMessage.SerializeToString, data__pb2.VariableMessage.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) @staticmethod def SolveResiduals(request_iterator, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/philote.ImplicitService/SolveResiduals', data__pb2.Array.SerializeToString, data__pb2.Array.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) + return grpc.experimental.stream_stream(request_iterator, target, '/philote.ImplicitService/SolveResiduals', data__pb2.VariableMessage.SerializeToString, data__pb2.VariableMessage.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) @staticmethod def ComputeResidualGradients(request_iterator, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/philote.ImplicitService/ComputeResidualGradients', data__pb2.Array.SerializeToString, data__pb2.Array.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) \ No newline at end of file + return grpc.experimental.stream_stream(request_iterator, target, '/philote.ImplicitService/ComputeResidualGradients', data__pb2.VariableMessage.SerializeToString, data__pb2.VariableMessage.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) \ No newline at end of file diff --git a/philote_mdo/openmdao/explicit.py b/philote_mdo/openmdao/explicit.py index b22c045..458b7e6 100644 --- a/philote_mdo/openmdao/explicit.py +++ b/philote_mdo/openmdao/explicit.py @@ -219,12 +219,12 @@ def compute(self, inputs, outputs, discrete_inputs=None, discrete_outputs=None): ---------- inputs : dict Dictionary of input values with variable names as keys - outputs : dict + outputs : dict Dictionary to store computed output values with variable names as keys discrete_inputs : dict, optional - Dictionary of discrete input values (currently unused), by default None + Dictionary of discrete input values, by default None discrete_outputs : dict, optional - Dictionary of discrete output values (currently unused), by default None + Dictionary of discrete output values, by default None Notes ----- @@ -234,8 +234,19 @@ def compute(self, inputs, outputs, discrete_inputs=None, discrete_outputs=None): - This method is called automatically by OpenMDAO during model execution """ local_inputs = utils.create_local_inputs(inputs, self._client._var_meta) - out = self._client.run_compute(local_inputs) - utils.assign_global_outputs(out, outputs) + local_discrete = utils.create_local_discrete_inputs( + discrete_inputs, self._client._discrete_var_meta + ) + result = self._client.run_compute(local_inputs, discrete_inputs=local_discrete) + + # run_compute returns (outputs, discrete_outputs) when discrete data exists + if isinstance(result, tuple): + out, d_out = result + utils.assign_global_outputs(out, outputs) + if discrete_outputs is not None: + utils.assign_global_outputs(d_out, discrete_outputs) + else: + utils.assign_global_outputs(result, outputs) def compute_partials(self, inputs, partials, discrete_inputs=None, discrete_outputs=None): """ @@ -250,12 +261,12 @@ def compute_partials(self, inputs, partials, discrete_inputs=None, discrete_outp inputs : dict Dictionary of input values with variable names as keys partials : dict - Dictionary to store computed partial derivatives with (output, input) + Dictionary to store computed partial derivatives with (output, input) tuples as keys discrete_inputs : dict, optional - Dictionary of discrete input values (currently unused), by default None + Dictionary of discrete input values, by default None discrete_outputs : dict, optional - Dictionary of discrete output values (currently unused), by default None + Dictionary of discrete output values, by default None Notes ----- @@ -266,5 +277,10 @@ def compute_partials(self, inputs, partials, discrete_inputs=None, discrete_outp - Sparsity patterns from the server are preserved in the OpenMDAO component """ local_inputs = utils.create_local_inputs(inputs, self._client._var_meta) - jac = self._client.run_compute_partials(local_inputs) + local_discrete = utils.create_local_discrete_inputs( + discrete_inputs, self._client._discrete_var_meta + ) + jac = self._client.run_compute_partials( + local_inputs, discrete_inputs=local_discrete + ) utils.assign_global_outputs(jac, partials) diff --git a/philote_mdo/openmdao/implicit.py b/philote_mdo/openmdao/implicit.py index e4a862f..1aa8843 100644 --- a/philote_mdo/openmdao/implicit.py +++ b/philote_mdo/openmdao/implicit.py @@ -220,11 +220,6 @@ def apply_nonlinear(self, inputs, outputs, residuals, discrete_inputs=None, disc """ Compute residual evaluation by calling the remote Philote server. - This method transfers both input and output values to the server, requests a - residual evaluation, and transfers the computed residuals back to the OpenMDAO - component. The residuals represent R(inputs, outputs) where the goal is to - find outputs such that R = 0. - Parameters ---------- inputs : dict @@ -234,33 +229,32 @@ def apply_nonlinear(self, inputs, outputs, residuals, discrete_inputs=None, disc residuals : dict Dictionary to store computed residual values with variable names as keys discrete_inputs : dict, optional - Dictionary of discrete input values (currently unused), by default None + Dictionary of discrete input values, by default None discrete_outputs : dict, optional - Dictionary of discrete output values (currently unused), by default None - - Notes - ----- - - Both inputs and outputs are sent to the server for residual computation - - Residuals are computed as R(inputs, outputs) at the current point - - The goal is to find outputs where residuals are zero - - This method is called automatically by OpenMDAO during residual evaluation + Dictionary of discrete output values, by default None """ local_inputs = utils.create_local_inputs(inputs, self._client._var_meta) local_outputs = utils.create_local_inputs( outputs, self._client._var_meta, data.kOutput ) + local_di = utils.create_local_discrete_inputs( + discrete_inputs, self._client._discrete_var_meta + ) + local_do = utils.create_local_discrete_inputs( + discrete_outputs, + self._client._discrete_var_meta, + data.VariableType.kDiscreteOutput, + ) - res = self._client.run_compute_residuals(local_inputs, local_outputs) + res = self._client.run_compute_residuals( + local_inputs, local_outputs, local_di, local_do + ) utils.assign_global_outputs(res, residuals) def solve_nonlinear(self, inputs, outputs, discrete_inputs=None, discrete_outputs=None): """ Solve the implicit equations by calling the remote Philote server. - This method transfers input values to the server, requests the server to solve - the implicit equations R(inputs, outputs) = 0 for the outputs, and transfers - the solved output values back to the OpenMDAO component. - Parameters ---------- inputs : dict @@ -268,31 +262,21 @@ def solve_nonlinear(self, inputs, outputs, discrete_inputs=None, discrete_output outputs : dict Dictionary to store solved output values with variable names as keys discrete_inputs : dict, optional - Dictionary of discrete input values (currently unused), by default None + Dictionary of discrete input values, by default None discrete_outputs : dict, optional - Dictionary of discrete output values (currently unused), by default None - - Notes - ----- - - The server performs the nonlinear solve internally - - Server may use Newton's method, fixed-point iteration, or other solvers - - Convergence criteria are controlled by server options - - This method is called by OpenMDAO's nonlinear solvers - - Output values are updated with the converged solution + Dictionary of discrete output values, by default None """ local_inputs = utils.create_local_inputs(inputs, self._client._var_meta) - out = self._client.run_solve_residuals(local_inputs) + local_di = utils.create_local_discrete_inputs( + discrete_inputs, self._client._discrete_var_meta + ) + out = self._client.run_solve_residuals(local_inputs, local_di) utils.assign_global_outputs(out, outputs) def linearize(self, inputs, outputs, partials, discrete_inputs=None, discrete_outputs=None): """ Compute partial derivatives of residuals by calling the remote Philote server. - This method transfers both input and output values to the server, requests - computation of the residual Jacobian (dR/dinputs and dR/doutputs), and transfers - the computed partial derivatives back to the OpenMDAO component. These derivatives - are used by OpenMDAO's linear solvers and optimization algorithms. - Parameters ---------- inputs : dict @@ -300,24 +284,25 @@ def linearize(self, inputs, outputs, partials, discrete_inputs=None, discrete_ou outputs : dict Dictionary of output values with variable names as keys partials : dict - Dictionary to store computed partial derivatives with (residual, variable) - tuples as keys + Dictionary to store computed partial derivatives discrete_inputs : dict, optional - Dictionary of discrete input values (currently unused), by default None + Dictionary of discrete input values, by default None discrete_outputs : dict, optional - Dictionary of discrete output values (currently unused), by default None - - Notes - ----- - - Computes both dR/dinputs and dR/doutputs partial derivatives - - Derivatives are computed at the current (inputs, outputs) point - - Server determines whether to use analytic or finite difference derivatives - - This method is called automatically by OpenMDAO when derivatives are needed - - Results are used by linear solvers and optimization algorithms + Dictionary of discrete output values, by default None """ local_inputs = utils.create_local_inputs(inputs, self._client._var_meta) local_outputs = utils.create_local_inputs( outputs, self._client._var_meta, data.kOutput ) - jac = self._client.run_residual_gradients(local_inputs, local_outputs) + local_di = utils.create_local_discrete_inputs( + discrete_inputs, self._client._discrete_var_meta + ) + local_do = utils.create_local_discrete_inputs( + discrete_outputs, + self._client._discrete_var_meta, + data.VariableType.kDiscreteOutput, + ) + jac = self._client.run_residual_gradients( + local_inputs, local_outputs, local_di, local_do + ) utils.assign_global_outputs(jac, partials) diff --git a/philote_mdo/openmdao/utils.py b/philote_mdo/openmdao/utils.py index ed626a9..c11f881 100644 --- a/philote_mdo/openmdao/utils.py +++ b/philote_mdo/openmdao/utils.py @@ -53,13 +53,14 @@ def client_setup(comp): Sets up the OpenMDAO component with all required inputs and outputs. This function will call the required RPCs to obtain the variables - from the remote discipline server. + from the remote discipline server. Both continuous and discrete + variables are declared. """ # set up the remote discipline and get the variable definitions comp._client.run_setup() comp._client.get_variable_definitions() - # define inputs and outputs based on the discipline metadata + # define continuous inputs and outputs based on the discipline metadata for var in comp._client._var_meta: if not var.units: units = None @@ -72,6 +73,14 @@ def client_setup(comp): if var.type == data.kOutput: comp.add_output(var.name, shape=tuple(var.shape), units=units) + # define discrete inputs and outputs + for var in comp._client._discrete_var_meta: + if var.type == data.VariableType.kDiscreteInput: + comp.add_discrete_input(var.name, val=None) + + if var.type == data.VariableType.kDiscreteOutput: + comp.add_discrete_output(var.name, val=None) + def client_setup_partials(comp): """ @@ -88,6 +97,20 @@ def client_setup_partials(comp): comp.declare_partials(partial.name, partial.subname) +def create_local_discrete_inputs(discrete_inputs, discrete_var_meta, type=data.VariableType.kDiscreteInput): + """ + Creates a Philote-Python local discrete inputs dictionary from OpenMDAO + discrete inputs. + """ + if discrete_inputs is None: + return None + local = {} + for var in discrete_var_meta: + if var.type == type: + local[var.name] = discrete_inputs[var.name] + return local if local else None + + def create_local_inputs(inputs, var_meta, type=data.kInput): """ Creates a Philote-Python local inputs dictionary from OpenMDAO inputs. diff --git a/proto b/proto index 5979d87..a5a21b8 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 5979d8713a4528348ebd3672a1740a6e33e7742a +Subproject commit a5a21b8c151c34f9f7b6f8fb717da76bb7e18011 From 3bfce13ce90845c8dce1e420c4e7133be1386166 Mon Sep 17 00:00:00 2001 From: Christopher Lupp Date: Wed, 8 Apr 2026 18:01:31 -0400 Subject: [PATCH 2/8] fix: update tests for VariableMessage wrapper Wrap all raw data.Array messages in data.VariableMessage in test fixtures to match the new streaming protocol. Update OpenMDAO test assertions to include discrete_inputs parameters in expected calls. --- tests/test_discipline_client.py | 122 ++++++++++++++++--------- tests/test_discipline_server.py | 56 +++++++----- tests/test_edge_cases.py | 58 ++++++------ tests/test_explicit_client.py | 34 ++++--- tests/test_explicit_server.py | 48 +++++----- tests/test_implicit_client.py | 46 ++++++---- tests/test_implicit_server.py | 90 +++++++++--------- tests/test_openmdao_explicit_client.py | 6 +- tests/test_openmdao_implicit_client.py | 10 +- tests/test_openmdao_utils.py | 1 + 10 files changed, 271 insertions(+), 200 deletions(-) diff --git a/tests/test_discipline_client.py b/tests/test_discipline_client.py index 75cc6a4..1e24d9b 100644 --- a/tests/test_discipline_client.py +++ b/tests/test_discipline_client.py @@ -257,21 +257,29 @@ def test_assemble_input_messages(self): } expected_messages = [ - data.Array( - name="x", start=0, end=1, type=data.VariableType.kInput, data=[1.0, 2.0] + data.VariableMessage( + continuous=data.Array( + name="x", start=0, end=1, type=data.VariableType.kInput, data=[1.0, 2.0] + ) ), - data.Array( - name="x", start=2, end=3, type=data.VariableType.kInput, data=[3.0, 4.0] + data.VariableMessage( + continuous=data.Array( + name="x", start=2, end=3, type=data.VariableType.kInput, data=[3.0, 4.0] + ) ), - data.Array( - name="f", - start=0, - end=1, - type=data.VariableType.kOutput, - data=[5.0, 6.0], + data.VariableMessage( + continuous=data.Array( + name="f", + start=0, + end=1, + type=data.VariableType.kOutput, + data=[5.0, 6.0], + ) ), - data.Array( - name="f", start=2, end=2, type=data.VariableType.kOutput, data=[7.0] + data.VariableMessage( + continuous=data.Array( + name="f", start=2, end=2, type=data.VariableType.kOutput, data=[7.0] + ) ), ] @@ -294,14 +302,20 @@ def test_recover_outputs(self): data.VariableMetaData(name="g", type=data.kOutput, shape=(3,)), ] - response1 = data.Array( - name="f", start=0, end=1, type=data.kOutput, data=[1.0, 2.0] + response1 = data.VariableMessage( + continuous=data.Array( + name="f", start=0, end=1, type=data.kOutput, data=[1.0, 2.0] + ) ) - response2 = data.Array( - name="f", start=2, end=3, type=data.kOutput, data=[3.0, 4.0] + response2 = data.VariableMessage( + continuous=data.Array( + name="f", start=2, end=3, type=data.kOutput, data=[3.0, 4.0] + ) ) - response3 = data.Array( - name="g", start=0, end=2, type=data.kOutput, data=[4.0, 5.0, 6.0] + response3 = data.VariableMessage( + continuous=data.Array( + name="g", start=0, end=2, type=data.kOutput, data=[4.0, 5.0, 6.0] + ) ) mock_responses = [response1, response2, response3] @@ -330,14 +344,20 @@ def test_recover_residuals(self): data.VariableMetaData(name="g", type=data.kResidual, shape=(3,)), ] - response1 = data.Array( - name="f", start=0, end=1, type=data.kResidual, data=[1.0, 2.0] + response1 = data.VariableMessage( + continuous=data.Array( + name="f", start=0, end=1, type=data.kResidual, data=[1.0, 2.0] + ) ) - response2 = data.Array( - name="f", start=2, end=3, type=data.kResidual, data=[3.0, 4.0] + response2 = data.VariableMessage( + continuous=data.Array( + name="f", start=2, end=3, type=data.kResidual, data=[3.0, 4.0] + ) ) - response3 = data.Array( - name="g", start=0, end=2, type=data.kResidual, data=[4.0, 5.0, 6.0] + response3 = data.VariableMessage( + continuous=data.Array( + name="g", start=0, end=2, type=data.kResidual, data=[4.0, 5.0, 6.0] + ) ) mock_responses = [response1, response2, response3] @@ -373,19 +393,25 @@ def test_recover_partials(self): client._partials_meta = [partial_metadata1, partial_metadata2] # Define mock responses - response1 = data.Array( - name="f", subname="x", type=data.kPartial, start=0, end=1, data=[1.0, 2.0] + response1 = data.VariableMessage( + continuous=data.Array( + name="f", subname="x", type=data.kPartial, start=0, end=1, data=[1.0, 2.0] + ) ) - response2 = data.Array( - name="f", subname="x", type=data.kPartial, start=2, end=3, data=[3.0, 4.0] + response2 = data.VariableMessage( + continuous=data.Array( + name="f", subname="x", type=data.kPartial, start=2, end=3, data=[3.0, 4.0] + ) ) - response3 = data.Array( - name="g", - subname="y", - type=data.kPartial, - start=0, - end=2, - data=[4.0, 5.0, 6.0], + response3 = data.VariableMessage( + continuous=data.Array( + name="g", + subname="y", + type=data.kPartial, + start=0, + end=2, + data=[4.0, 5.0, 6.0], + ) ) mock_responses = [response1, response2, response3] @@ -420,14 +446,16 @@ def test_recover_outputs_empty_array_raises_error(self): ] # Create response with empty data array - response_empty = data.Array( - name="f", start=0, end=1, type=data.kOutput, data=[] + response_empty = data.VariableMessage( + continuous=data.Array( + name="f", start=0, end=1, type=data.kOutput, data=[] + ) ) mock_responses = [response_empty] with self.assertRaises(ValueError) as context: client._recover_outputs(mock_responses) - + self.assertIn("Expected continuous variables, but array is empty", str(context.exception)) def test_recover_residuals_empty_array_raises_error(self): @@ -442,14 +470,16 @@ def test_recover_residuals_empty_array_raises_error(self): ] # Create response with empty data array - response_empty = data.Array( - name="f", start=0, end=1, type=data.kResidual, data=[] + response_empty = data.VariableMessage( + continuous=data.Array( + name="f", start=0, end=1, type=data.kResidual, data=[] + ) ) mock_responses = [response_empty] with self.assertRaises(ValueError) as context: client._recover_residuals(mock_responses) - + self.assertIn("Expected continuous variables, but array is empty", str(context.exception)) def test_recover_partials_empty_array_raises_error(self): @@ -463,20 +493,22 @@ def test_recover_partials_empty_array_raises_error(self): data.VariableMetaData(name="f", type=data.kOutput, shape=(2,)), data.VariableMetaData(name="x", type=data.kInput, shape=(2,)), ] - + client._partials_meta = [ data.PartialsMetaData(name="f", subname="x"), ] # Create response with empty data array - response_empty = data.Array( - name="f", subname="x", start=0, end=1, type=data.kPartial, data=[] + response_empty = data.VariableMessage( + continuous=data.Array( + name="f", subname="x", start=0, end=1, type=data.kPartial, data=[] + ) ) mock_responses = [response_empty] with self.assertRaises(ValueError) as context: client._recover_partials(mock_responses) - + self.assertIn("Expected continuous outputs for the partials, but array was empty", str(context.exception)) diff --git a/tests/test_discipline_server.py b/tests/test_discipline_server.py index db89c98..f9b2550 100644 --- a/tests/test_discipline_server.py +++ b/tests/test_discipline_server.py @@ -325,22 +325,28 @@ def test_preallocate_partials(self): def test_process_inputs(self): # create a mock request_iterator request_iterator = [ - data.Array( - start=0, - end=2, - data=[1.0, 2.0, 3.0], - type=data.VariableType.kInput, - name="x", + data.VariableMessage( + continuous=data.Array( + start=0, + end=2, + data=[1.0, 2.0, 3.0], + type=data.VariableType.kInput, + name="x", + ) ), - data.Array( - start=3, end=4, data=[4.0, 5.0], type=data.VariableType.kInput, name="x" + data.VariableMessage( + continuous=data.Array( + start=3, end=4, data=[4.0, 5.0], type=data.VariableType.kInput, name="x" + ) ), - data.Array( - start=0, - end=1, - data=[0.1, 0.2], - type=data.VariableType.kOutput, - name="f", + data.VariableMessage( + continuous=data.Array( + start=0, + end=1, + data=[0.1, 0.2], + type=data.VariableType.kOutput, + name="f", + ) ), ] @@ -379,24 +385,26 @@ def test_process_inputs_empty_array_raises_error(self): Tests that process_inputs raises ValueError when array data is empty. """ server = DisciplineServer() - + # Create request with empty data array request_iterator = [ - data.Array( - start=0, - end=2, - data=[], # Empty data array - type=data.VariableType.kInput, - name="x", + data.VariableMessage( + continuous=data.Array( + start=0, + end=2, + data=[], # Empty data array + type=data.VariableType.kInput, + name="x", + ) ), ] - + flat_inputs = {"x": np.zeros(3)} flat_outputs = {} - + with self.assertRaises(ValueError) as context: server.process_inputs(request_iterator, flat_inputs, flat_outputs) - + self.assertIn("Expected continuous variables but arrays were empty for variable x", str(context.exception)) diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index ecda48a..72b5e40 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -98,36 +98,39 @@ def test_get_available_options_with_invalid_type(self): def test_process_inputs_with_empty_continuous_data(self): """ - Test process_inputs with empty continuous data arrays (line 216). + Test process_inputs with empty continuous data arrays. """ server = DisciplineServer() discipline = Mock() - + # Set up discipline with continuous variables discipline._is_continuous = True discipline._var_meta = [Mock()] discipline._var_meta[0].name = "test_var" discipline._var_meta[0].shape = [2] discipline._var_meta[0].type = data.kInput - + server.attach_discipline(discipline) - - # Create a message with empty data - message = Mock() - message.name = "test_var" - message.type = data.VariableType.kInput - message.start = 0 - message.end = 1 - message.data = [] # Empty data array - + + # Create a VariableMessage wrapping an Array with empty data + message = data.VariableMessage( + continuous=data.Array( + name="test_var", + type=data.VariableType.kInput, + start=0, + end=1, + data=[], + ) + ) + # Create mock for flat_inputs and flat_outputs flat_inputs = {"test_var": [0.0, 0.0]} flat_outputs = {} - + # This should raise a ValueError with self.assertRaises(ValueError) as context: server.process_inputs([message], flat_inputs, flat_outputs) - + self.assertIn("Expected continuous variables but arrays were empty", str(context.exception)) @@ -138,30 +141,33 @@ class TestDisciplineClientEdgeCases(unittest.TestCase): def test_recover_outputs_with_empty_data(self): """ - Test _recover_outputs with empty data arrays (line 197). + Test _recover_outputs with empty data arrays. """ # Create a mock channel channel = Mock() client = DisciplineClient(channel) - + # Set up outputs structure client._var_meta = [Mock()] client._var_meta[0].name = "test_output" client._var_meta[0].shape = [2] client._var_meta[0].type = data.kOutput - - # Create a response message with empty data - message = Mock() - message.name = "test_output" - message.type = data.kOutput - message.start = 0 - message.end = 1 - message.data = [] # Empty data array - + + # Create a VariableMessage wrapping an Array with empty data + message = data.VariableMessage( + continuous=data.Array( + name="test_output", + type=data.kOutput, + start=0, + end=1, + data=[], + ) + ) + # This should raise a ValueError with self.assertRaises(ValueError) as context: client._recover_outputs([message]) - + self.assertIn("Expected continuous variables, but array is empty", str(context.exception)) # NOTE: Other client edge case tests are more complex to set up properly diff --git a/tests/test_explicit_client.py b/tests/test_explicit_client.py index 4a4be6c..59bd35b 100644 --- a/tests/test_explicit_client.py +++ b/tests/test_explicit_client.py @@ -60,11 +60,15 @@ def test_compute(self, mock_explicit_stub): "x": np.array([1.0, 2.0, 3.0, 4.0]).reshape(2, 2), } - response1 = data.Array( - name="f", type=data.kOutput, start=0, end=2, data=[5.0, 6.0, 7.0] + response1 = data.VariableMessage( + continuous=data.Array( + name="f", type=data.kOutput, start=0, end=2, data=[5.0, 6.0, 7.0] + ) ) - response2 = data.Array( - name="g", type=data.kOutput, start=0, end=2, data=[8.0, 9.0, 10.0] + response2 = data.VariableMessage( + continuous=data.Array( + name="g", type=data.kOutput, start=0, end=2, data=[8.0, 9.0, 10.0] + ) ) mock_responses = [response1, response2] @@ -101,16 +105,20 @@ def test_compute_partials(self, mock_explicit_stub): "x": np.array([1.0, 2.0, 3.0, 4.0]).reshape(2, 2), } - response1 = data.Array( - name="f", - subname="x", - type=data.kPartial, - start=0, - end=2, - data=[5.0, 6.0, 7.0], + response1 = data.VariableMessage( + continuous=data.Array( + name="f", + subname="x", + type=data.kPartial, + start=0, + end=2, + data=[5.0, 6.0, 7.0], + ) ) - response2 = data.Array( - name="f", subname="x", type=data.kPartial, start=3, end=3, data=[4.0] + response2 = data.VariableMessage( + continuous=data.Array( + name="f", subname="x", type=data.kPartial, start=3, end=3, data=[4.0] + ) ) mock_responses = [response1, response2] diff --git a/tests/test_explicit_server.py b/tests/test_explicit_server.py index 8ff60bf..0416ec0 100644 --- a/tests/test_explicit_server.py +++ b/tests/test_explicit_server.py @@ -56,15 +56,17 @@ def test_compute_function(self): context = Mock() request_iterator = [ - data.Array( - start=0, - end=2, - data=[0.5, 1.5, 3.5], - type=data.VariableType.kInput, - name="x", + data.VariableMessage( + continuous=data.Array( + start=0, end=2, data=[0.5, 1.5, 3.5], + type=data.VariableType.kInput, name="x", + ) ), - data.Array( - start=3, end=4, data=[4.5, 5.5], type=data.VariableType.kInput, name="x" + data.VariableMessage( + continuous=data.Array( + start=3, end=4, data=[4.5, 5.5], + type=data.VariableType.kInput, name="x", + ) ), ] @@ -81,8 +83,8 @@ def compute(inputs, outputs): # check that there is only one response self.assertEqual(len(responses), 1) - # check the function value - response = responses[0] + # check the function value (unwrap VariableMessage) + response = responses[0].continuous self.assertEqual(response.name, "f") self.assertEqual(response.start, 0) self.assertEqual(response.end, 1) @@ -102,15 +104,17 @@ def test_compute_gradient(self): context = Mock() request_iterator = [ - data.Array( - start=0, - end=2, - data=[0.5, 1.5, 3.5], - type=data.VariableType.kInput, - name="x", + data.VariableMessage( + continuous=data.Array( + start=0, end=2, data=[0.5, 1.5, 3.5], + type=data.VariableType.kInput, name="x", + ) ), - data.Array( - start=3, end=4, data=[4.5, 5.5], type=data.VariableType.kInput, name="x" + data.VariableMessage( + continuous=data.Array( + start=3, end=4, data=[4.5, 5.5], + type=data.VariableType.kInput, name="x", + ) ), ] @@ -124,18 +128,18 @@ def compute_partials(inputs, jac): response_generator = server.ComputeGradient(request_iterator, context) responses = list(response_generator) - # check that there is only one response + # check that there are two responses self.assertEqual(len(responses), 2) - # check the function value - response = responses[0] + # check the function value (unwrap VariableMessage) + response = responses[0].continuous self.assertEqual(response.name, "f") self.assertEqual(response.subname, "x") self.assertEqual(response.start, 0) self.assertEqual(response.end, 2) grad = np.array(response.data) - response = responses[1] + response = responses[1].continuous grad = np.append(grad, np.array(response.data)) self.assertTrue( np.array_equal(grad, np.array([-251.0, -499.0, 11105.0, 25007.0, -2950.0])) diff --git a/tests/test_implicit_client.py b/tests/test_implicit_client.py index 72e267f..3ec34a9 100644 --- a/tests/test_implicit_client.py +++ b/tests/test_implicit_client.py @@ -67,11 +67,15 @@ def test_compute_residuals(self, mock_implicit_stub): "g": np.array([7.0, 6.0, 5.0]), } - response1 = data.Array( - name="f", type=data.kResidual, start=0, end=2, data=[5.0, 6.0, 7.0] + response1 = data.VariableMessage( + continuous=data.Array( + name="f", type=data.kResidual, start=0, end=2, data=[5.0, 6.0, 7.0] + ) ) - response2 = data.Array( - name="g", type=data.kResidual, start=0, end=2, data=[8.0, 9.0, 10.0] + response2 = data.VariableMessage( + continuous=data.Array( + name="g", type=data.kResidual, start=0, end=2, data=[8.0, 9.0, 10.0] + ) ) mock_responses = [response1, response2] @@ -114,11 +118,15 @@ def test_solve_residuals(self, mock_implicit_stub): "f": np.array([5.0, 6.0, 7.0]), } - response1 = data.Array( - name="f", type=data.kOutput, start=0, end=2, data=[5.0, 6.0, 7.0] + response1 = data.VariableMessage( + continuous=data.Array( + name="f", type=data.kOutput, start=0, end=2, data=[5.0, 6.0, 7.0] + ) ) - response2 = data.Array( - name="g", type=data.kOutput, start=0, end=2, data=[8.0, 9.0, 10.0] + response2 = data.VariableMessage( + continuous=data.Array( + name="g", type=data.kOutput, start=0, end=2, data=[8.0, 9.0, 10.0] + ) ) mock_responses = [response1, response2] @@ -162,16 +170,20 @@ def test_residual_partials(self, mock_implicit_stub): "f": np.array([5.0, 6.0]), } - response1 = data.Array( - name="f", - subname="x", - type=data.kPartial, - start=0, - end=2, - data=[5.0, 6.0, 7.0], + response1 = data.VariableMessage( + continuous=data.Array( + name="f", + subname="x", + type=data.kPartial, + start=0, + end=2, + data=[5.0, 6.0, 7.0], + ) ) - response2 = data.Array( - name="f", subname="x", type=data.kPartial, start=3, end=3, data=[4.0] + response2 = data.VariableMessage( + continuous=data.Array( + name="f", subname="x", type=data.kPartial, start=3, end=3, data=[4.0] + ) ) mock_responses = [response1, response2] diff --git a/tests/test_implicit_server.py b/tests/test_implicit_server.py index 9a2e341..21e7c82 100644 --- a/tests/test_implicit_server.py +++ b/tests/test_implicit_server.py @@ -58,9 +58,9 @@ def test_compute_residuals(self): # mock request iterator mock_request_iterator = [ - data.Array(name="x", start=0, end=2, type=data.kInput, data=[1.0, 2.0]), - data.Array(name="y", start=0, end=2, type=data.kInput, data=[3.0, 4.0]), - data.Array(name="f", start=0, end=2, type=data.kOutput, data=[5.0, 6.0]), + data.VariableMessage(continuous=data.Array(name="x", start=0, end=2, type=data.kInput, data=[1.0, 2.0])), + data.VariableMessage(continuous=data.Array(name="y", start=0, end=2, type=data.kInput, data=[3.0, 4.0])), + data.VariableMessage(continuous=data.Array(name="f", start=0, end=2, type=data.kOutput, data=[5.0, 6.0])), ] # mock inputs, outputs, and residuals @@ -81,19 +81,17 @@ def compute_residuals(inputs, outputs, residuals): # assert that the expected residual messages were yielded expected_result = [ - data.Array( - name="f", - start=0, - end=1, - type=data.VariableType.kResidual, - data=[7.0], + data.VariableMessage( + continuous=data.Array( + name="f", start=0, end=1, + type=data.VariableType.kResidual, data=[7.0], + ) ), - data.Array( - name="f", - start=1, - end=2, - type=data.VariableType.kResidual, - data=[8.0], + data.VariableMessage( + continuous=data.Array( + name="f", start=1, end=2, + type=data.VariableType.kResidual, data=[8.0], + ) ), ] self.assertEqual(result, expected_result) @@ -115,9 +113,9 @@ def test_solve_residuals(self): # mock request iterator mock_request_iterator = [ - data.Array(name="x", start=0, end=2, type=data.kInput, data=[1.0, 2.0]), - data.Array(name="y", start=0, end=2, type=data.kInput, data=[3.0, 4.0]), - data.Array(name="f", start=0, end=2, type=data.kOutput, data=[5.0, 6.0]), + data.VariableMessage(continuous=data.Array(name="x", start=0, end=2, type=data.kInput, data=[1.0, 2.0])), + data.VariableMessage(continuous=data.Array(name="y", start=0, end=2, type=data.kInput, data=[3.0, 4.0])), + data.VariableMessage(continuous=data.Array(name="f", start=0, end=2, type=data.kOutput, data=[5.0, 6.0])), ] # mock inputs, outputs, and residuals @@ -132,25 +130,23 @@ def solve_residuals(inputs, outputs): server._discipline.solve_residuals = solve_residuals - # call the ComputeResiduals method + # call the SolveResiduals method response_generator = server.SolveResiduals(mock_request_iterator, None) result = list(response_generator) - # assert that the expected residual messages were yielded + # assert that the expected output messages were yielded expected_result = [ - data.Array( - name="f", - start=0, - end=1, - type=data.VariableType.kOutput, - data=[7.0], + data.VariableMessage( + continuous=data.Array( + name="f", start=0, end=1, + type=data.VariableType.kOutput, data=[7.0], + ) ), - data.Array( - name="f", - start=1, - end=2, - type=data.VariableType.kOutput, - data=[8.0], + data.VariableMessage( + continuous=data.Array( + name="f", start=1, end=2, + type=data.VariableType.kOutput, data=[8.0], + ) ), ] self.assertEqual(result, expected_result) @@ -168,19 +164,17 @@ def test_residual_gradients(self): context = Mock() request_iterator = [ - data.Array( - start=0, - end=2, - data=[0.5, 1.5, 3.5], - type=data.VariableType.kInput, - name="x", + data.VariableMessage( + continuous=data.Array( + start=0, end=2, data=[0.5, 1.5, 3.5], + type=data.VariableType.kInput, name="x", + ) ), - data.Array( - start=3, - end=4, - data=[4.5, 5.5], - type=data.VariableType.kInput, - name="x", + data.VariableMessage( + continuous=data.Array( + start=3, end=4, data=[4.5, 5.5], + type=data.VariableType.kInput, name="x", + ) ), ] @@ -194,18 +188,18 @@ def residual_partials(inputs, residuals, jac): response_generator = server.ComputeResidualGradients(request_iterator, context) responses = list(response_generator) - # check that there is only one response + # check that there are two responses self.assertEqual(len(responses), 2) - # check the function value - response = responses[0] + # check the function value (unwrap VariableMessage) + response = responses[0].continuous self.assertEqual(response.name, "f") self.assertEqual(response.subname, "x") self.assertEqual(response.start, 0) self.assertEqual(response.end, 3) grad = np.array(response.data) - response = responses[1] + response = responses[1].continuous grad = np.append(grad, np.array(response.data)) self.assertTrue( np.array_equal(grad, np.array([-251.0, -499.0, 11105.0, 25007.0, -2950.0])) diff --git a/tests/test_openmdao_explicit_client.py b/tests/test_openmdao_explicit_client.py index bc02160..c8339e5 100644 --- a/tests/test_openmdao_explicit_client.py +++ b/tests/test_openmdao_explicit_client.py @@ -224,7 +224,9 @@ def test_compute(self, om_explicit_component_patch): instance.compute(inputs, outputs, discrete_inputs, discrete_outputs) # Asserting that the method calls are made correctly - client_mock.run_compute.assert_called_once_with({"input1": 10, "input2": 20}) + client_mock.run_compute.assert_called_once_with( + {"input1": 10, "input2": 20}, discrete_inputs=None + ) self.assertEqual(outputs["output1"], 30) self.assertEqual(outputs["output2"], 40) @@ -281,7 +283,7 @@ def test_compute_partials(self, om_explicit_component_patch): # Asserting that the method calls are made correctly client_mock.run_compute_partials.assert_called_once_with( - {"input1": 10, "input2": 20} + {"input1": 10, "input2": 20}, discrete_inputs=None ) self.assertEqual(partials["output1"]["input1"], 1) self.assertEqual(partials["output1"]["input2"], 2) diff --git a/tests/test_openmdao_implicit_client.py b/tests/test_openmdao_implicit_client.py index 96cfc4e..2895e7f 100644 --- a/tests/test_openmdao_implicit_client.py +++ b/tests/test_openmdao_implicit_client.py @@ -228,7 +228,9 @@ def test_apply_nonlinear(self, mock): ) # asserting that the method calls are made correctly - client_mock.run_compute_residuals.assert_called_once_with(inputs, outputs) + client_mock.run_compute_residuals.assert_called_once_with( + inputs, outputs, None, None + ) for res_name, expected_data in expected_residuals.items(): self.assertTrue(res_name in residuals) np.testing.assert_array_equal(residuals[res_name], expected_data) @@ -277,7 +279,7 @@ def test_solve_nonlinear(self, mock): comp.solve_nonlinear(inputs, outputs, discrete_inputs, discrete_outputs) # asserting that the method calls are made correctly - client_mock.run_solve_residuals.assert_called_once_with(inputs) + client_mock.run_solve_residuals.assert_called_once_with(inputs, None) for output_name, expected_data in expected_outputs.items(): self.assertTrue(output_name in outputs) np.testing.assert_array_equal(outputs[output_name], expected_data) @@ -343,7 +345,9 @@ def test_linearize(self, om_explicit_component_patch): comp.linearize(inputs, outputs, partials, discrete_inputs, discrete_outputs) # asserting that the method calls are made correctly - client_mock.run_residual_gradients.assert_called_once_with(inputs, outputs) + client_mock.run_residual_gradients.assert_called_once_with( + inputs, outputs, None, None + ) for jac_key, expected_data in expected_jac.items(): self.assertTrue(jac_key in partials) np.testing.assert_array_equal(partials[jac_key], expected_data) diff --git a/tests/test_openmdao_utils.py b/tests/test_openmdao_utils.py index 8011ffb..d928faa 100644 --- a/tests/test_openmdao_utils.py +++ b/tests/test_openmdao_utils.py @@ -55,6 +55,7 @@ def test_openmdao_client_setup(self): var2.shape = [1] comp._client._var_meta = [var1, var2] + comp._client._discrete_var_meta = [] utils.client_setup(comp) From 610fb833ea51179b7b93b15336242448875b774a Mon Sep 17 00:00:00 2001 From: Christopher Lupp Date: Wed, 8 Apr 2026 18:24:49 -0400 Subject: [PATCH 3/8] test: add unit tests for discrete variable support Add comprehensive test suite covering all new discrete variable code paths: value conversion helpers, discipline base class, server/client discrete message handling, explicit/implicit server discrete dispatch, client discrete assembly/recovery, and OpenMDAO bindings. Coverage restored to 99% (remaining misses are import guards and an unreachable defensive branch). --- tests/test_discrete_variables.py | 687 +++++++++++++++++++++++++++++++ 1 file changed, 687 insertions(+) create mode 100644 tests/test_discrete_variables.py diff --git a/tests/test_discrete_variables.py b/tests/test_discrete_variables.py new file mode 100644 index 0000000..6b3811a --- /dev/null +++ b/tests/test_discrete_variables.py @@ -0,0 +1,687 @@ +# Philote-Python +# +# Copyright 2022-2025 Christopher A. Lupp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# This work has been cleared for public release, distribution unlimited, case +# number: AFRL-2023-5713. +# +# The views expressed are those of the authors and do not reflect the +# official guidance or position of the United States Government, the +# Department of Defense or of the United States Air Force. +# +# Statement from DoD: The Appearance of external hyperlinks does not +# constitute endorsement by the United States Department of Defense (DoD) of +# the linked websites, of the information, products, or services contained +# therein. The DoD does not exercise any editorial, security, or other +# control over the information you may find at these locations. +""" +Unit tests for discrete variable support across the Philote stack. +""" +import unittest +from unittest.mock import Mock, MagicMock, patch + +import numpy as np +from google.protobuf import struct_pb2 + +import philote_mdo.generated.data_pb2 as data +from philote_mdo.general import ( + Discipline, + DisciplineClient, + DisciplineServer, + ExplicitDiscipline, + ExplicitServer, + ExplicitClient, + ImplicitDiscipline, + ImplicitServer, + ImplicitClient, +) +from philote_mdo.general.discipline_server import _value_to_python, _python_to_value +import philote_mdo.openmdao.utils as om_utils + + +# --------------------------------------------------------------------------- +# Value conversion helpers +# --------------------------------------------------------------------------- +class TestValueConversion(unittest.TestCase): + """Tests for _value_to_python and _python_to_value round-trip conversion.""" + + def test_none(self): + val = _python_to_value(None) + self.assertIsNone(_value_to_python(val)) + + def test_bool_true(self): + val = _python_to_value(True) + self.assertIs(_value_to_python(val), True) + + def test_bool_false(self): + val = _python_to_value(False) + self.assertIs(_value_to_python(val), False) + + def test_int(self): + val = _python_to_value(42) + result = _value_to_python(val) + self.assertEqual(result, 42) + self.assertIsInstance(result, int) + + def test_float(self): + val = _python_to_value(3.14) + result = _value_to_python(val) + self.assertAlmostEqual(result, 3.14) + self.assertIsInstance(result, float) + + def test_string(self): + val = _python_to_value("hello") + self.assertEqual(_value_to_python(val), "hello") + + def test_list(self): + original = [1, "two", 3.0, True] + val = _python_to_value(original) + result = _value_to_python(val) + self.assertEqual(result, [1, "two", 3, True]) + + def test_tuple_converts_to_list(self): + val = _python_to_value((1, 2)) + result = _value_to_python(val) + self.assertEqual(result, [1, 2]) + + def test_dict(self): + original = {"key": "value", "n": 5} + val = _python_to_value(original) + result = _value_to_python(val) + self.assertEqual(result, original) + + def test_nested_structure(self): + original = {"mesh": "coarse", "params": [1, 2, 3], "opts": {"tol": 1e-6}} + val = _python_to_value(original) + result = _value_to_python(val) + self.assertEqual(result["mesh"], "coarse") + self.assertEqual(result["params"], [1, 2, 3]) + + def test_unsupported_type_becomes_string(self): + """Unsupported types are serialized via str().""" + val = _python_to_value(object) + result = _value_to_python(val) + self.assertIsInstance(result, str) + + +# --------------------------------------------------------------------------- +# Discipline base class +# --------------------------------------------------------------------------- +class TestDisciplineDiscreteVars(unittest.TestCase): + """Tests for add_discrete_input / add_discrete_output on Discipline.""" + + def test_add_discrete_input(self): + d = Discipline() + d.add_discrete_input("mode") + self.assertEqual(len(d._discrete_var_meta), 1) + meta = d._discrete_var_meta[0] + self.assertEqual(meta.name, "mode") + self.assertEqual(meta.type, data.VariableType.kDiscreteInput) + + def test_add_discrete_output(self): + d = Discipline() + d.add_discrete_output("status") + self.assertEqual(len(d._discrete_var_meta), 1) + meta = d._discrete_var_meta[0] + self.assertEqual(meta.name, "status") + self.assertEqual(meta.type, data.VariableType.kDiscreteOutput) + + def test_clear_data_resets_discrete_meta(self): + d = Discipline() + d.add_discrete_input("x") + d.add_discrete_output("y") + d._clear_data() + self.assertEqual(len(d._discrete_var_meta), 0) + + +# --------------------------------------------------------------------------- +# DisciplineServer – discrete metadata streaming and process_inputs +# --------------------------------------------------------------------------- +class TestDisciplineServerDiscrete(unittest.TestCase): + """Tests for discrete variable handling in DisciplineServer.""" + + def test_get_variable_definitions_includes_discrete(self): + """GetVariableDefinitions should stream both continuous and discrete metadata.""" + server = DisciplineServer() + server._discipline = Discipline() + server._discipline.add_input("x", shape=(1,)) + server._discipline.add_discrete_input("mode") + + responses = list(server.GetVariableDefinitions(None, None)) + self.assertEqual(len(responses), 2) + types = [r.type for r in responses] + self.assertIn(data.VariableType.kInput, types) + self.assertIn(data.VariableType.kDiscreteInput, types) + + def test_process_inputs_with_discrete(self): + """process_inputs should demux continuous and discrete messages.""" + server = DisciplineServer() + + flat_inputs = {"x": np.zeros(2)} + discrete_inputs = {} + discrete_outputs = {} + + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + name="x", start=0, end=1, + type=data.VariableType.kInput, data=[1.0, 2.0], + ) + ), + data.VariableMessage( + discrete=data.DiscreteVariable( + name="mode", + type=data.VariableType.kDiscreteInput, + value=_python_to_value("forward"), + ) + ), + data.VariableMessage( + discrete=data.DiscreteVariable( + name="status", + type=data.VariableType.kDiscreteOutput, + value=_python_to_value(0), + ) + ), + ] + + di, do = server.process_inputs( + request_iterator, flat_inputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, + ) + + np.testing.assert_array_equal(flat_inputs["x"], [1.0, 2.0]) + self.assertEqual(di["mode"], "forward") + self.assertEqual(do["status"], 0) + + +# --------------------------------------------------------------------------- +# DisciplineClient – discrete variable definitions, assembly, recovery +# --------------------------------------------------------------------------- +class TestDisciplineClientDiscrete(unittest.TestCase): + """Tests for discrete variable handling in DisciplineClient.""" + + @patch("philote_mdo.generated.disciplines_pb2_grpc.DisciplineServiceStub") + def test_get_variable_definitions_separates_discrete(self, mock_stub_cls): + mock_channel = Mock() + mock_stub = mock_stub_cls.return_value + mock_stub.GetVariableDefinitions.return_value = [ + data.VariableMetaData(name="x", type=data.kInput, shape=[1]), + data.VariableMetaData( + name="mode", type=data.VariableType.kDiscreteInput + ), + data.VariableMetaData(name="f", type=data.kOutput, shape=[1]), + data.VariableMetaData( + name="status", type=data.VariableType.kDiscreteOutput + ), + ] + + client = DisciplineClient(mock_channel) + client.get_variable_definitions() + + self.assertEqual(len(client._var_meta), 2) + self.assertEqual(len(client._discrete_var_meta), 2) + + def test_assemble_input_messages_with_discrete(self): + """_assemble_input_messages should include discrete messages.""" + mock_channel = Mock() + client = DisciplineClient(mock_channel) + client._stream_options.num_double = 10 + + inputs = {"x": np.array([1.0])} + discrete_inputs = {"mode": "forward", "order": 3} + + messages = client._assemble_input_messages( + inputs, discrete_inputs=discrete_inputs + ) + + # 1 continuous + 2 discrete = 3 messages + self.assertEqual(len(messages), 3) + + # Check that discrete messages are present + discrete_msgs = [ + m for m in messages if m.WhichOneof("payload") == "discrete" + ] + self.assertEqual(len(discrete_msgs), 2) + + names = {m.discrete.name for m in discrete_msgs} + self.assertEqual(names, {"mode", "order"}) + + def test_assemble_input_messages_with_discrete_outputs(self): + """_assemble_input_messages should include discrete output messages.""" + mock_channel = Mock() + client = DisciplineClient(mock_channel) + client._stream_options.num_double = 10 + + inputs = {"x": np.array([1.0])} + discrete_outputs = {"status": 0} + + messages = client._assemble_input_messages( + inputs, discrete_outputs=discrete_outputs + ) + + # 1 continuous + 1 discrete output + self.assertEqual(len(messages), 2) + discrete_msgs = [ + m for m in messages if m.WhichOneof("payload") == "discrete" + ] + self.assertEqual(len(discrete_msgs), 1) + self.assertEqual(discrete_msgs[0].discrete.name, "status") + self.assertEqual( + discrete_msgs[0].discrete.type, data.VariableType.kDiscreteOutput + ) + + def test_recover_outputs_with_discrete(self): + """_recover_outputs should return (outputs, discrete_outputs) tuple.""" + mock_channel = Mock() + client = DisciplineClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="f", type=data.kOutput, shape=(1,)), + ] + + responses = [ + data.VariableMessage( + continuous=data.Array( + name="f", type=data.kOutput, start=0, end=0, data=[42.0] + ) + ), + data.VariableMessage( + discrete=data.DiscreteVariable( + name="status", + type=data.VariableType.kDiscreteOutput, + value=_python_to_value("converged"), + ) + ), + ] + + result = client._recover_outputs(responses) + + # Should be a tuple (continuous_outputs, discrete_outputs) + self.assertIsInstance(result, tuple) + outputs, d_outputs = result + np.testing.assert_array_equal(outputs["f"], [42.0]) + self.assertEqual(d_outputs["status"], "converged") + + +# --------------------------------------------------------------------------- +# ExplicitServer – discrete data flow +# --------------------------------------------------------------------------- +class TestExplicitServerDiscrete(unittest.TestCase): + """Tests for discrete variable handling in ExplicitServer.""" + + def test_compute_function_with_discrete(self): + """ComputeFunction should pass discrete data to discipline.compute.""" + server = ExplicitServer() + discipline = ExplicitDiscipline() + discipline.add_input("x", shape=(1,)) + discipline.add_output("f", shape=(1,)) + discipline.add_discrete_input("mode") + discipline.add_discrete_output("status") + server._discipline = discipline + server._stream_opts.num_double = 10 + + captured = {} + + def compute(inputs, outputs, discrete_inputs, discrete_outputs): + captured["discrete_inputs"] = dict(discrete_inputs) + outputs["f"] = inputs["x"] * 2 + discrete_outputs["status"] = "ok" + + discipline.compute = compute + + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + name="x", start=0, end=0, + type=data.VariableType.kInput, data=[3.0], + ) + ), + data.VariableMessage( + discrete=data.DiscreteVariable( + name="mode", + type=data.VariableType.kDiscreteInput, + value=_python_to_value("fast"), + ) + ), + ] + + responses = list(server.ComputeFunction(request_iterator, None)) + + # Should have received the discrete input + self.assertEqual(captured["discrete_inputs"]["mode"], "fast") + + # Should have continuous and discrete output responses + continuous_responses = [ + r for r in responses if r.WhichOneof("payload") == "continuous" + ] + discrete_responses = [ + r for r in responses if r.WhichOneof("payload") == "discrete" + ] + + self.assertEqual(len(continuous_responses), 1) + self.assertEqual(continuous_responses[0].continuous.data[0], 6.0) + + self.assertEqual(len(discrete_responses), 1) + self.assertEqual(discrete_responses[0].discrete.name, "status") + + def test_compute_gradient_with_discrete(self): + """ComputeGradient should pass discrete data to compute_partials.""" + server = ExplicitServer() + discipline = ExplicitDiscipline() + discipline.add_input("x", shape=(1,)) + discipline.add_output("f", shape=(1,)) + discipline.add_discrete_input("mode") + discipline.declare_partials("f", "x") + server._discipline = discipline + server._stream_opts.num_double = 10 + + captured = {} + + def compute_partials(inputs, jac, discrete_inputs): + captured["discrete_inputs"] = dict(discrete_inputs) + jac["f", "x"] = np.array([2.0]) + + discipline.compute_partials = compute_partials + + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + name="x", start=0, end=0, + type=data.VariableType.kInput, data=[3.0], + ) + ), + data.VariableMessage( + discrete=data.DiscreteVariable( + name="mode", + type=data.VariableType.kDiscreteInput, + value=_python_to_value("fast"), + ) + ), + ] + + responses = list(server.ComputeGradient(request_iterator, None)) + + self.assertEqual(captured["discrete_inputs"]["mode"], "fast") + self.assertEqual(len(responses), 1) + self.assertEqual(responses[0].continuous.data[0], 2.0) + + +# --------------------------------------------------------------------------- +# ImplicitServer – discrete data flow +# --------------------------------------------------------------------------- +class TestImplicitServerDiscrete(unittest.TestCase): + """Tests for discrete variable handling in ImplicitServer.""" + + def _make_discipline(self): + discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,)) + discipline.add_output("f", shape=(1,)) + discipline.add_discrete_input("mode") + discipline.declare_partials("f", "x") + return discipline + + def _make_request(self, x_val=1.0, f_val=0.0, mode_val="fast"): + return [ + data.VariableMessage( + continuous=data.Array( + name="x", start=0, end=0, + type=data.VariableType.kInput, data=[x_val], + ) + ), + data.VariableMessage( + continuous=data.Array( + name="f", start=0, end=0, + type=data.VariableType.kOutput, data=[f_val], + ) + ), + data.VariableMessage( + discrete=data.DiscreteVariable( + name="mode", + type=data.VariableType.kDiscreteInput, + value=_python_to_value(mode_val), + ) + ), + ] + + def test_compute_residuals_with_discrete(self): + server = ImplicitServer() + discipline = self._make_discipline() + server._discipline = discipline + server._stream_opts.num_double = 10 + + captured = {} + + def compute_residuals(inputs, outputs, residuals, di, do): + captured["mode"] = di["mode"] + residuals["f"] = outputs["f"] - inputs["x"] + + discipline.compute_residuals = compute_residuals + + responses = list( + server.ComputeResiduals(self._make_request(), None) + ) + + self.assertEqual(captured["mode"], "fast") + self.assertGreater(len(responses), 0) + + def test_solve_residuals_with_discrete(self): + server = ImplicitServer() + discipline = self._make_discipline() + server._discipline = discipline + server._stream_opts.num_double = 10 + + captured = {} + + def solve_residuals(inputs, outputs, di): + captured["mode"] = di["mode"] + outputs["f"] = inputs["x"] + + discipline.solve_residuals = solve_residuals + + responses = list( + server.SolveResiduals(self._make_request(), None) + ) + + self.assertEqual(captured["mode"], "fast") + self.assertGreater(len(responses), 0) + + def test_compute_residual_gradients_with_discrete(self): + server = ImplicitServer() + discipline = self._make_discipline() + server._discipline = discipline + server._stream_opts.num_double = 10 + + captured = {} + + def residual_partials(inputs, outputs, jac, di, do): + captured["mode"] = di["mode"] + jac["f", "x"] = np.array([1.0]) + + discipline.residual_partials = residual_partials + + responses = list( + server.ComputeResidualGradients(self._make_request(), None) + ) + + self.assertEqual(captured["mode"], "fast") + self.assertGreater(len(responses), 0) + + +# --------------------------------------------------------------------------- +# ExplicitClient – discrete round-trip +# --------------------------------------------------------------------------- +class TestExplicitClientDiscrete(unittest.TestCase): + """Tests for discrete inputs through ExplicitClient.""" + + @patch("philote_mdo.generated.disciplines_pb2_grpc.ExplicitServiceStub") + def test_run_compute_with_discrete(self, mock_stub_cls): + mock_channel = Mock() + mock_stub = mock_stub_cls.return_value + client = ExplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="f", type=data.kOutput, shape=(1,)), + ] + + mock_stub.ComputeFunction.return_value = [ + data.VariableMessage( + continuous=data.Array( + name="f", type=data.kOutput, start=0, end=0, data=[6.0] + ) + ), + ] + + result = client.run_compute( + {"x": np.array([3.0])}, + discrete_inputs={"mode": "fast"}, + ) + + # Should have been called with VariableMessage including discrete + call_args = mock_stub.ComputeFunction.call_args + messages = list(call_args[0][0]) + discrete_msgs = [ + m for m in messages if m.WhichOneof("payload") == "discrete" + ] + self.assertEqual(len(discrete_msgs), 1) + + @patch("philote_mdo.generated.disciplines_pb2_grpc.ExplicitServiceStub") + def test_run_compute_partials_with_discrete(self, mock_stub_cls): + mock_channel = Mock() + mock_stub = mock_stub_cls.return_value + client = ExplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="f", type=data.kOutput, shape=(1,)), + data.VariableMetaData(name="x", type=data.kInput, shape=(1,)), + ] + client._partials_meta = [ + data.PartialsMetaData(name="f", subname="x"), + ] + + mock_stub.ComputeGradient.return_value = [ + data.VariableMessage( + continuous=data.Array( + name="f", subname="x", type=data.kPartial, + start=0, end=0, data=[2.0], + ) + ), + ] + + partials = client.run_compute_partials( + {"x": np.array([3.0])}, + discrete_inputs={"mode": "fast"}, + ) + + np.testing.assert_array_equal(partials[("f", "x")], [2.0]) + + +# --------------------------------------------------------------------------- +# OpenMDAO utils – discrete variable setup and extraction +# --------------------------------------------------------------------------- +class TestOpenMdaoUtilsDiscrete(unittest.TestCase): + """Tests for discrete variable support in OpenMDAO utility functions.""" + + def test_client_setup_declares_discrete_vars(self): + comp = MagicMock() + comp._client._var_meta = [ + data.VariableMetaData(name="x", type=data.kInput, shape=[1], units="m"), + ] + comp._client._discrete_var_meta = [ + data.VariableMetaData( + name="mode", type=data.VariableType.kDiscreteInput + ), + data.VariableMetaData( + name="status", type=data.VariableType.kDiscreteOutput + ), + ] + + om_utils.client_setup(comp) + + comp.add_discrete_input.assert_called_once_with("mode", val=None) + comp.add_discrete_output.assert_called_once_with("status", val=None) + + def test_create_local_discrete_inputs(self): + discrete_inputs = {"mode": "forward", "extra": 99} + meta = [ + data.VariableMetaData( + name="mode", type=data.VariableType.kDiscreteInput + ), + ] + + result = om_utils.create_local_discrete_inputs(discrete_inputs, meta) + self.assertEqual(result, {"mode": "forward"}) + + def test_create_local_discrete_inputs_none(self): + result = om_utils.create_local_discrete_inputs(None, []) + self.assertIsNone(result) + + def test_create_local_discrete_inputs_empty_returns_none(self): + """When no matching vars, should return None.""" + discrete_inputs = {"other": 1} + meta = [ + data.VariableMetaData( + name="mode", type=data.VariableType.kDiscreteOutput + ), + ] + result = om_utils.create_local_discrete_inputs(discrete_inputs, meta) + self.assertIsNone(result) + + +# --------------------------------------------------------------------------- +# OpenMDAO Explicit – discrete tuple return handling +# --------------------------------------------------------------------------- +@patch("openmdao.api.ExplicitComponent.__init__") +class TestOpenMdaoExplicitDiscrete(unittest.TestCase): + """Tests for discrete data flow through RemoteExplicitComponent.""" + + def test_compute_with_discrete_tuple_result(self, om_patch): + from philote_mdo.openmdao import RemoteExplicitComponent + + mock_channel = Mock() + comp = RemoteExplicitComponent(channel=mock_channel) + + client_mock = MagicMock() + client_mock._var_meta = [ + data.VariableMetaData(name="x", type=data.kInput, shape=[1]), + data.VariableMetaData(name="f", type=data.kOutput, shape=[1]), + ] + client_mock._discrete_var_meta = [ + data.VariableMetaData( + name="mode", type=data.VariableType.kDiscreteInput + ), + data.VariableMetaData( + name="status", type=data.VariableType.kDiscreteOutput + ), + ] + # Simulate server returning tuple (outputs, discrete_outputs) + client_mock.run_compute.return_value = ( + {"f": np.array([6.0])}, + {"status": "ok"}, + ) + comp._client = client_mock + comp.name = "test" + + inputs = {"x": np.array([3.0])} + outputs = {"f": np.zeros(1)} + discrete_inputs = {"mode": "fast"} + discrete_outputs = {"status": None} + + comp.compute(inputs, outputs, discrete_inputs, discrete_outputs) + + np.testing.assert_array_equal(outputs["f"], [6.0]) + self.assertEqual(discrete_outputs["status"], "ok") + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 1c60fe32f1dc6f7df4767d97059904a7422a7b1f Mon Sep 17 00:00:00 2001 From: Christopher Lupp Date: Wed, 8 Apr 2026 18:39:10 -0400 Subject: [PATCH 4/8] ci: enforce 95% code coverage threshold Add codecov.yml requiring 95% coverage on both project total and patch (new/changed lines). Add fail_under = 95 to .coveragerc for local enforcement via coverage report. --- .coveragerc | 3 +++ codecov.yml | 10 ++++++++++ 2 files changed, 13 insertions(+) create mode 100644 codecov.yml diff --git a/.coveragerc b/.coveragerc index 49c8e0b..d5bf0ce 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,6 +2,9 @@ source = philote_mdo [report] +# Fail if total coverage drops below 95% +fail_under = 95 + # Exclude generated protobuf files from coverage reports exclude_lines = pragma: no cover diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..e70d445 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,10 @@ +coverage: + status: + project: + default: + target: 95% + threshold: 0% + patch: + default: + target: 95% + threshold: 0% From 5bcb45e42211b267b6353905d0f3fe4052984aa8 Mon Sep 17 00:00:00 2001 From: Christopher Lupp Date: Wed, 8 Apr 2026 18:46:35 -0400 Subject: [PATCH 5/8] chore: mark unreachable branches with pragma no cover Add coverage exclusion pragmas to three unreachable code paths: - _value_to_python else fallback (all protobuf Value kinds handled) - openmdao/__init__.py ImportError guard (requires uninstalling OpenMDAO) - examples/__init__.py ImportError guard (same) Also fixes bare except to except ImportError in examples/__init__.py. --- philote_mdo/examples/__init__.py | 2 +- philote_mdo/general/discipline_server.py | 2 +- philote_mdo/openmdao/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/philote_mdo/examples/__init__.py b/philote_mdo/examples/__init__.py index 075fd5a..414b3e5 100644 --- a/philote_mdo/examples/__init__.py +++ b/philote_mdo/examples/__init__.py @@ -34,5 +34,5 @@ try: import openmdao.api from .sellar import SellarGroup -except: +except ImportError: # pragma: no cover pass diff --git a/philote_mdo/general/discipline_server.py b/philote_mdo/general/discipline_server.py index 5e57904..bde20a4 100644 --- a/philote_mdo/general/discipline_server.py +++ b/philote_mdo/general/discipline_server.py @@ -285,7 +285,7 @@ def _value_to_python(value): return [_value_to_python(v) for v in value.list_value.values] elif kind == "struct_value": return {k: _value_to_python(v) for k, v in value.struct_value.fields.items()} - else: + else: # pragma: no cover – all protobuf Value kinds are handled above return None diff --git a/philote_mdo/openmdao/__init__.py b/philote_mdo/openmdao/__init__.py index 13cbbb4..34919f4 100644 --- a/philote_mdo/openmdao/__init__.py +++ b/philote_mdo/openmdao/__init__.py @@ -31,7 +31,7 @@ import openmdao.api as om omdao_installed = True -except ImportError: +except ImportError: # pragma: no cover omdao_installed = False om = None From 028c3d6fde00f39dd1506ad8eb339c2aad6a420f Mon Sep 17 00:00:00 2001 From: Christopher Lupp Date: Wed, 8 Apr 2026 18:53:07 -0400 Subject: [PATCH 6/8] test: add integration tests for discrete variable round-trip Add end-to-end integration tests using a ScaledParaboloid discipline that uses a discrete input (mode) to switch scaling behavior and a discrete output (label) to report the mode used. Tests cover: - Raw ExplicitClient compute with discrete inputs/outputs - Raw ExplicitClient compute_partials with discrete inputs - OpenMDAO RemoteExplicitComponent auto-discovery and forwarding of discrete variables through a real gRPC server --- tests/test_discrete_integration.py | 254 +++++++++++++++++++++++++++++ 1 file changed, 254 insertions(+) create mode 100644 tests/test_discrete_integration.py diff --git a/tests/test_discrete_integration.py b/tests/test_discrete_integration.py new file mode 100644 index 0000000..79e1316 --- /dev/null +++ b/tests/test_discrete_integration.py @@ -0,0 +1,254 @@ +# Philote-Python +# +# Copyright 2022-2025 Christopher A. Lupp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# This work has been cleared for public release, distribution unlimited, case +# number: AFRL-2023-5713. +# +# The views expressed are those of the authors and do not reflect the +# official guidance or position of the United States Government, the +# Department of Defense or of the United States Air Force. +# +# Statement from DoD: The Appearance of external hyperlinks does not +# constitute endorsement by the United States Department of Defense (DoD) of +# the linked websites, of the information, products, or services contained +# therein. The DoD does not exercise any editorial, security, or other +# control over the information you may find at these locations. +""" +Integration tests for disciplines with discrete variables. + +These tests spin up a real gRPC server and client and exercise the full +discrete variable round-trip: declaration, metadata discovery, message +serialization, discipline evaluation, and result recovery. +""" +from concurrent import futures +import unittest + +import grpc +import numpy as np +import openmdao.api as om + +import philote_mdo.general as pmdo +import philote_mdo.openmdao as pmdo_om + + +# --------------------------------------------------------------------------- +# Example discipline with discrete inputs/outputs +# --------------------------------------------------------------------------- +class ScaledParaboloid(pmdo.ExplicitDiscipline): + """ + Paraboloid whose output is scaled by a discrete mode flag. + + Continuous inputs: x, y (scalars) + Discrete input: mode ("double" or "half") + Continuous output: f_xy (scalar) + Discrete output: label (string describing the mode used) + + f_xy = scale * ((x - 3)^2 + x*y + (y + 4)^2 - 3) + + where scale = 2.0 for "double" and 0.5 for "half". + """ + + def setup(self): + self.add_input("x", shape=(1,), units="m") + self.add_input("y", shape=(1,), units="m") + self.add_output("f_xy", shape=(1,), units="m**2") + + self.add_discrete_input("mode") + self.add_discrete_output("label") + + def setup_partials(self): + self.declare_partials("f_xy", "x") + self.declare_partials("f_xy", "y") + + def compute(self, inputs, outputs, discrete_inputs=None, discrete_outputs=None): + x = inputs["x"] + y = inputs["y"] + base = (x - 3.0) ** 2 + x * y + (y + 4.0) ** 2 - 3.0 + + mode = (discrete_inputs or {}).get("mode", None) or "double" + scale = 2.0 if mode == "double" else 0.5 + + outputs["f_xy"] = scale * base + + if discrete_outputs is not None: + discrete_outputs["label"] = f"scaled_{mode}" + + def compute_partials(self, inputs, partials, discrete_inputs=None): + x = inputs["x"] + y = inputs["y"] + + mode = (discrete_inputs or {}).get("mode", None) or "double" + scale = 2.0 if mode == "double" else 0.5 + + partials["f_xy", "x"] = scale * (2.0 * x - 6.0 + y) + partials["f_xy", "y"] = scale * (2.0 * y + 8.0 + x) + + +# --------------------------------------------------------------------------- +# Integration tests +# --------------------------------------------------------------------------- +class TestDiscreteIntegration(unittest.TestCase): + """ + End-to-end integration tests for disciplines with discrete variables. + """ + + def _start_server(self, discipline, port): + """Helper to start a gRPC server with the given discipline.""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) + explicit_server = pmdo.ExplicitServer(discipline=discipline) + explicit_server.attach_to_server(server) + server.add_insecure_port(f"[::]:{port}") + server.start() + return server + + # ------------------------------------------------------------------ + # Raw client tests (no OpenMDAO) + # ------------------------------------------------------------------ + def test_client_compute_with_discrete_inputs(self): + """ + Test that a raw ExplicitClient can send discrete inputs and + receive both continuous and discrete outputs. + """ + port = 50061 + server = self._start_server(ScaledParaboloid(), port) + + try: + channel = grpc.insecure_channel(f"localhost:{port}") + client = pmdo.ExplicitClient(channel) + + # setup handshake + client.send_stream_options() + client.run_setup() + client.get_variable_definitions() + client.get_partials_definitions() + + # verify discrete metadata was discovered + self.assertTrue(len(client._discrete_var_meta) > 0) + discrete_names = {v.name for v in client._discrete_var_meta} + self.assertIn("mode", discrete_names) + self.assertIn("label", discrete_names) + + # run compute with mode="double" + inputs = {"x": np.array([1.0]), "y": np.array([2.0])} + result = client.run_compute(inputs, discrete_inputs={"mode": "double"}) + + # Should return (outputs, discrete_outputs) tuple + self.assertIsInstance(result, tuple) + outputs, discrete_outputs = result + + # base paraboloid value at (1, 2) = (1-3)^2 + 1*2 + (2+4)^2 - 3 = 39 + # scaled by 2.0 => 78.0 + np.testing.assert_almost_equal(outputs["f_xy"][0], 78.0) + self.assertEqual(discrete_outputs["label"], "scaled_double") + + # run compute with mode="half" + result = client.run_compute(inputs, discrete_inputs={"mode": "half"}) + outputs, discrete_outputs = result + + # 39 * 0.5 = 19.5 + np.testing.assert_almost_equal(outputs["f_xy"][0], 19.5) + self.assertEqual(discrete_outputs["label"], "scaled_half") + + finally: + server.stop(0) + + def test_client_compute_partials_with_discrete_inputs(self): + """ + Test that a raw ExplicitClient can send discrete inputs for + gradient evaluation. + """ + port = 50062 + server = self._start_server(ScaledParaboloid(), port) + + try: + channel = grpc.insecure_channel(f"localhost:{port}") + client = pmdo.ExplicitClient(channel) + + client.send_stream_options() + client.run_setup() + client.get_variable_definitions() + client.get_partials_definitions() + + inputs = {"x": np.array([1.0]), "y": np.array([2.0])} + + # mode="double", scale=2.0 + # df/dx = 2*(2*1 - 6 + 2) = 2*(-2) = -4.0 + # df/dy = 2*(2*2 + 8 + 1) = 2*(13) = 26.0 + partials = client.run_compute_partials( + inputs, discrete_inputs={"mode": "double"} + ) + + np.testing.assert_almost_equal(partials[("f_xy", "x")][0], -4.0) + np.testing.assert_almost_equal(partials[("f_xy", "y")][0], 26.0) + + # mode="half", scale=0.5 + # df/dx = 0.5*(-2) = -1.0 + # df/dy = 0.5*(13) = 6.5 + partials = client.run_compute_partials( + inputs, discrete_inputs={"mode": "half"} + ) + + np.testing.assert_almost_equal(partials[("f_xy", "x")][0], -1.0) + np.testing.assert_almost_equal(partials[("f_xy", "y")][0], 6.5) + + finally: + server.stop(0) + + # ------------------------------------------------------------------ + # OpenMDAO integration test + # ------------------------------------------------------------------ + def test_openmdao_compute_with_discrete(self): + """ + Test the full OpenMDAO integration: RemoteExplicitComponent + auto-discovers discrete variables from the server and returns + correct results. + """ + port = 50063 + server = self._start_server(ScaledParaboloid(), port) + + try: + channel = grpc.insecure_channel(f"localhost:{port}") + + prob = om.Problem() + comp = pmdo_om.RemoteExplicitComponent(channel=channel) + prob.model.add_subsystem("scaled", comp) + + prob.setup() + + # verify discrete variables were discovered + discrete_names = {v.name for v in comp._client._discrete_var_meta} + self.assertIn("mode", discrete_names) + self.assertIn("label", discrete_names) + + # set continuous inputs + prob.set_val("scaled.x", 1.0) + prob.set_val("scaled.y", 2.0) + + # run model – mode defaults to "double" when not set (scale=2.0) + prob.run_model() + + # base value at (1,2) = 39.0, scale = 2.0 => 78.0 + np.testing.assert_almost_equal( + prob.get_val("scaled.f_xy")[0], 78.0 + ) + + finally: + server.stop(0) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From a9ef2f766032211d04949825c6acb8ca98e162e3 Mon Sep 17 00:00:00 2001 From: Christopher Lupp Date: Wed, 8 Apr 2026 19:26:22 -0400 Subject: [PATCH 7/8] docs: update changelog with all discrete variable changes Add entries for coverage enforcement, pragma no cover annotations, and the bare except fix under the appropriate sections. --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a6b3237..1885188 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Bug Fixes +- Fixed bare `except` to `except ImportError` in `examples/__init__.py`. - Fixed `SellarMDA` promoted-input ambiguity that newer OpenMDAO releases reject during `final_setup`. The `x` and `z` defaults were being set on the inner `cycle` subgroup, but `obj_cmp` promoted the same variables @@ -33,6 +34,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Documentation & Infrastructure +- Added Codecov configuration (`codecov.yml`) requiring 95% coverage on + both project total and patch (new/changed lines). +- Added `fail_under = 95` to `.coveragerc` for local coverage enforcement. +- Marked unreachable import guards and defensive branches with + `pragma: no cover`. - Updated installation instructions to reflect PyPI install option. - Added documentation for implicit disciplines. - Added documentation for OpenMDAO clients From 3cd30c7b0ca4f66bba999f7f4d69defc85e7c64f Mon Sep 17 00:00:00 2001 From: Christopher Lupp Date: Wed, 8 Apr 2026 21:24:25 -0400 Subject: [PATCH 8/8] feat: add support for struct (dict) discipline options (#49) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add kStruct DataType to the protocol and wire it through the Python stack so disciplines can declare and receive complex nested data as options. Changes: - Update proto submodule (kStruct = 4 in DataType enum) - Regenerate gRPC stubs - Map "dict" ↔ kStruct in server and client type mappings - Support dict type in OpenMDAO declare_options utility - Add tests for struct option round-trip - Update CHANGELOG --- CHANGELOG.md | 3 ++ philote_mdo/general/discipline.py | 2 +- philote_mdo/general/discipline_client.py | 2 + philote_mdo/general/discipline_server.py | 2 + philote_mdo/generated/data_pb2.py | 8 ++-- philote_mdo/generated/data_pb2.pyi | 2 + philote_mdo/openmdao/utils.py | 2 + proto | 2 +- tests/test_discipline_client.py | 45 +++++++++++++++++++++ tests/test_discipline_server.py | 50 ++++++++++++++++++++++++ tests/test_edge_cases.py | 20 ++++++++++ 11 files changed, 132 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1885188..e160b23 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Features +- Added support for struct (dict) options via the new `kStruct` DataType enum + value, enabling complex nested data to be declared and passed as discipline + options (#49). - Added discrete variable support throughout the stack. Disciplines can now declare discrete inputs/outputs via `add_discrete_input` / `add_discrete_output`. Discrete data is serialized as diff --git a/philote_mdo/general/discipline.py b/philote_mdo/general/discipline.py index f5930cf..ae1f020 100644 --- a/philote_mdo/general/discipline.py +++ b/philote_mdo/general/discipline.py @@ -69,7 +69,7 @@ def add_option(self, name, type): the name of the option being added type : string the data type of the option. acceptable types are 'bool', 'int', - 'float' + 'float', 'str', 'dict' """ self.options_list[name] = type diff --git a/philote_mdo/general/discipline_client.py b/philote_mdo/general/discipline_client.py index 385d077..13b206f 100644 --- a/philote_mdo/general/discipline_client.py +++ b/philote_mdo/general/discipline_client.py @@ -97,6 +97,8 @@ def get_available_options(self): type_str = "float" if val == data.kString: type_str = "str" + if val == data.kStruct: + type_str = "dict" self.options_list[name] = type_str def send_options(self, options): diff --git a/philote_mdo/general/discipline_server.py b/philote_mdo/general/discipline_server.py index bde20a4..d784b18 100644 --- a/philote_mdo/general/discipline_server.py +++ b/philote_mdo/general/discipline_server.py @@ -100,6 +100,8 @@ def GetAvailableOptions(self, request, context): type = data.kDouble elif val == "str": type = data.kString + elif val == "dict": + type = data.kStruct else: raise ValueError( "Invalid value for discipline option '{}'".format(name) diff --git a/philote_mdo/generated/data_pb2.py b/philote_mdo/generated/data_pb2.py index 8334dc1..b6029b2 100644 --- a/philote_mdo/generated/data_pb2.py +++ b/philote_mdo/generated/data_pb2.py @@ -7,7 +7,7 @@ _runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 5, 27, 2, '', 'data.proto') _sym_db = _symbol_database.Default() from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x07philote\x1a\x1cgoogle/protobuf/struct.proto"}\n\x14DisciplineProperties\x12\x12\n\ncontinuous\x18\x01 \x01(\x08\x12\x16\n\x0edifferentiable\x18\x02 \x01(\x08\x12\x1a\n\x12provides_gradients\x18\x03 \x01(\x08\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x0f\n\x07version\x18\x05 \x01(\t"#\n\rStreamOptions\x12\x12\n\nnum_double\x18\x01 \x01(\x03"?\n\x0bOptionsList\x12\x0f\n\x07options\x18\x01 \x03(\t\x12\x1f\n\x04type\x18\x02 \x03(\x0e2\x11.philote.DataType"=\n\x11DisciplineOptions\x12(\n\x07options\x18\x01 \x01(\x0b2\x17.google.protobuf.Struct"c\n\x10VariableMetaData\x12#\n\x04type\x18\x01 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05units\x18\x05 \x01(\t"@\n\x10PartialsMetaData\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03"u\n\x05Array\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05start\x18\x03 \x01(\x03\x12\x0b\n\x03end\x18\x04 \x01(\x03\x12#\n\x04type\x18\x05 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04data\x18\x06 \x03(\x01"l\n\x10DiscreteVariable\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x04type\x18\x02 \x01(\x0e2\x15.philote.VariableType\x12%\n\x05value\x18\x03 \x01(\x0b2\x16.google.protobuf.Value"q\n\x0fVariableMessage\x12$\n\ncontinuous\x18\x01 \x01(\x0b2\x0e.philote.ArrayH\x00\x12-\n\x08discrete\x18\x02 \x01(\x0b2\x19.philote.DiscreteVariableH\x00B\t\n\x07payload*9\n\x08DataType\x12\t\n\x05kBool\x10\x00\x12\x08\n\x04kInt\x10\x01\x12\x0b\n\x07kDouble\x10\x02\x12\x0b\n\x07kString\x10\x03*m\n\x0cVariableType\x12\n\n\x06kInput\x10\x00\x12\x12\n\x0ekDiscreteInput\x10\x01\x12\r\n\tkResidual\x10\x02\x12\x0b\n\x07kOutput\x10\x03\x12\x13\n\x0fkDiscreteOutput\x10\x04\x12\x0c\n\x08kPartial\x10\x05B\x11\n\x0forg.philote.mdob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x07philote\x1a\x1cgoogle/protobuf/struct.proto"}\n\x14DisciplineProperties\x12\x12\n\ncontinuous\x18\x01 \x01(\x08\x12\x16\n\x0edifferentiable\x18\x02 \x01(\x08\x12\x1a\n\x12provides_gradients\x18\x03 \x01(\x08\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x0f\n\x07version\x18\x05 \x01(\t"#\n\rStreamOptions\x12\x12\n\nnum_double\x18\x01 \x01(\x03"?\n\x0bOptionsList\x12\x0f\n\x07options\x18\x01 \x03(\t\x12\x1f\n\x04type\x18\x02 \x03(\x0e2\x11.philote.DataType"=\n\x11DisciplineOptions\x12(\n\x07options\x18\x01 \x01(\x0b2\x17.google.protobuf.Struct"c\n\x10VariableMetaData\x12#\n\x04type\x18\x01 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05units\x18\x05 \x01(\t"@\n\x10PartialsMetaData\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03"u\n\x05Array\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05start\x18\x03 \x01(\x03\x12\x0b\n\x03end\x18\x04 \x01(\x03\x12#\n\x04type\x18\x05 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04data\x18\x06 \x03(\x01"l\n\x10DiscreteVariable\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x04type\x18\x02 \x01(\x0e2\x15.philote.VariableType\x12%\n\x05value\x18\x03 \x01(\x0b2\x16.google.protobuf.Value"q\n\x0fVariableMessage\x12$\n\ncontinuous\x18\x01 \x01(\x0b2\x0e.philote.ArrayH\x00\x12-\n\x08discrete\x18\x02 \x01(\x0b2\x19.philote.DiscreteVariableH\x00B\t\n\x07payload*F\n\x08DataType\x12\t\n\x05kBool\x10\x00\x12\x08\n\x04kInt\x10\x01\x12\x0b\n\x07kDouble\x10\x02\x12\x0b\n\x07kString\x10\x03\x12\x0b\n\x07kStruct\x10\x04*m\n\x0cVariableType\x12\n\n\x06kInput\x10\x00\x12\x12\n\x0ekDiscreteInput\x10\x01\x12\r\n\tkResidual\x10\x02\x12\x0b\n\x07kOutput\x10\x03\x12\x13\n\x0fkDiscreteOutput\x10\x04\x12\x0c\n\x08kPartial\x10\x05B\x11\n\x0forg.philote.mdob\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', _globals) @@ -15,9 +15,9 @@ _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\x0forg.philote.mdo' _globals['_DATATYPE']._serialized_start = 856 - _globals['_DATATYPE']._serialized_end = 913 - _globals['_VARIABLETYPE']._serialized_start = 915 - _globals['_VARIABLETYPE']._serialized_end = 1024 + _globals['_DATATYPE']._serialized_end = 926 + _globals['_VARIABLETYPE']._serialized_start = 928 + _globals['_VARIABLETYPE']._serialized_end = 1037 _globals['_DISCIPLINEPROPERTIES']._serialized_start = 53 _globals['_DISCIPLINEPROPERTIES']._serialized_end = 178 _globals['_STREAMOPTIONS']._serialized_start = 180 diff --git a/philote_mdo/generated/data_pb2.pyi b/philote_mdo/generated/data_pb2.pyi index a699396..13c800f 100644 --- a/philote_mdo/generated/data_pb2.pyi +++ b/philote_mdo/generated/data_pb2.pyi @@ -12,6 +12,7 @@ class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): kInt: _ClassVar[DataType] kDouble: _ClassVar[DataType] kString: _ClassVar[DataType] + kStruct: _ClassVar[DataType] class VariableType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): __slots__ = () @@ -25,6 +26,7 @@ kBool: DataType kInt: DataType kDouble: DataType kString: DataType +kStruct: DataType kInput: VariableType kDiscreteInput: VariableType kResidual: VariableType diff --git a/philote_mdo/openmdao/utils.py b/philote_mdo/openmdao/utils.py index c11f881..3b95870 100644 --- a/philote_mdo/openmdao/utils.py +++ b/philote_mdo/openmdao/utils.py @@ -44,6 +44,8 @@ def declare_options(opt_list, options): opt_type = float elif type_str == "str": opt_type = str + elif type_str == "dict": + opt_type = dict options.declare(name, types=opt_type) diff --git a/proto b/proto index a5a21b8..73c7d76 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit a5a21b8c151c34f9f7b6f8fb717da76bb7e18011 +Subproject commit 73c7d76ddf6b2e72c65cc39ee79e2ffbfa3f6cc5 diff --git a/tests/test_discipline_client.py b/tests/test_discipline_client.py index 1e24d9b..d7be559 100644 --- a/tests/test_discipline_client.py +++ b/tests/test_discipline_client.py @@ -130,6 +130,51 @@ def test_get_available_options(self, mock_discipline_stub): } self.assertEqual(instance.options_list, expected_options_list) + @patch("philote_mdo.generated.disciplines_pb2_grpc.DisciplineServiceStub") + def test_get_available_options_with_dict_type(self, mock_discipline_stub): + """ + Tests that get_available_options correctly maps kStruct to 'dict'. + """ + mock_channel = Mock() + mock_stub = mock_discipline_stub.return_value + + instance = DisciplineClient(channel=mock_channel) + + mock_options = MagicMock() + mock_options.options = ["config", "flag"] + mock_options.type = [data.kStruct, data.kBool] + instance._disc_stub.GetAvailableOptions.return_value = mock_options + + instance.get_available_options() + + expected_options_list = { + "config": "dict", + "flag": "bool", + } + self.assertEqual(instance.options_list, expected_options_list) + + def test_send_options_with_nested_dict(self): + """ + Tests that send_options correctly serializes nested dict values via Struct. + """ + mock_stub = Mock() + mock_channel = Mock() + mock_channel.stub.return_value = mock_stub + + client = DisciplineClient(channel=mock_channel) + client._disc_stub = mock_stub + + options = { + "config": {"solver": "newton", "tol": 1e-6}, + "name": "test", + } + + client.send_options(options) + + expected_proto_options = data.DisciplineOptions() + expected_proto_options.options.update(options) + mock_stub.SetOptions.assert_called_once_with(expected_proto_options) + def test_send_options(self): # mock gRPC stub and channel mock_stub = Mock() diff --git a/tests/test_discipline_server.py b/tests/test_discipline_server.py index f9b2550..4fa1329 100644 --- a/tests/test_discipline_server.py +++ b/tests/test_discipline_server.py @@ -362,6 +362,56 @@ def test_process_inputs(self): self.assertEqual(flat_inputs["x"].tolist(), [1.0, 2.0, 3.0, 4.0, 5.0, 0.0]) self.assertEqual(flat_outputs["f"].tolist(), [0.1, 0.2, 0.0]) + def test_get_available_options_with_dict_type(self): + """ + Tests that GetAvailableOptions correctly maps dict options to kStruct. + """ + server = DisciplineServer() + + request_mock = Mock() + context_mock = None + + server._discipline = Discipline() + server._discipline.options_list = { + "config": "dict", + "flag": "bool", + } + + results = server.GetAvailableOptions(request_mock, context_mock) + + expected_options = ["config", "flag"] + expected_types = [data.kStruct, data.kBool] + + self.assertEqual(results.options, expected_options) + self.assertEqual(results.type, expected_types) + + def test_set_options_with_nested_dict(self): + """ + Tests that SetOptions correctly passes nested dict values through. + """ + server = DisciplineServer() + + request_mock = Mock() + context_mock = Mock() + + # set nested dict options in the request + request_mock.options = { + "config": {"solver": "newton", "tol": 1e-6, "nested": {"a": 1}}, + "name": "test", + } + + discipline_mock = Mock() + server._discipline = discipline_mock + + server.SetOptions(request_mock, context_mock) + + server._discipline.set_options.assert_called_once_with( + { + "config": {"solver": "newton", "tol": 1e-6, "nested": {"a": 1}}, + "name": "test", + } + ) + def test_get_available_options_invalid_type_raises_error(self): """ Tests that GetAvailableOptions raises ValueError for invalid option types. diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 72b5e40..0afc37b 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -73,6 +73,26 @@ def test_get_available_options_with_str_type(self): # The method should complete without error and return options self.assertIsNotNone(result) + def test_get_available_options_with_dict_type(self): + """ + Test GetAvailableOptions with dict option type (covers kStruct mapping). + """ + server = DisciplineServer() + discipline = Mock() + + discipline.options_list = {"config": "dict"} + + server.attach_discipline(discipline) + + request = Mock() + context = Mock() + + result = server.GetAvailableOptions(request, context) + + self.assertIsNotNone(result) + self.assertEqual(list(result.options), ["config"]) + self.assertEqual(list(result.type), [data.kStruct]) + def test_get_available_options_with_invalid_type(self): """ Test GetAvailableOptions with invalid option type (covers lines 100-103).