diff --git a/CHANGELOG.md b/CHANGELOG.md index 1885188..e160b23 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Features +- Added support for struct (dict) options via the new `kStruct` DataType enum + value, enabling complex nested data to be declared and passed as discipline + options (#49). - Added discrete variable support throughout the stack. Disciplines can now declare discrete inputs/outputs via `add_discrete_input` / `add_discrete_output`. Discrete data is serialized as diff --git a/philote_mdo/general/discipline.py b/philote_mdo/general/discipline.py index f5930cf..ae1f020 100644 --- a/philote_mdo/general/discipline.py +++ b/philote_mdo/general/discipline.py @@ -69,7 +69,7 @@ def add_option(self, name, type): the name of the option being added type : string the data type of the option. acceptable types are 'bool', 'int', - 'float' + 'float', 'str', 'dict' """ self.options_list[name] = type diff --git a/philote_mdo/general/discipline_client.py b/philote_mdo/general/discipline_client.py index 385d077..13b206f 100644 --- a/philote_mdo/general/discipline_client.py +++ b/philote_mdo/general/discipline_client.py @@ -97,6 +97,8 @@ def get_available_options(self): type_str = "float" if val == data.kString: type_str = "str" + if val == data.kStruct: + type_str = "dict" self.options_list[name] = type_str def send_options(self, options): diff --git a/philote_mdo/general/discipline_server.py b/philote_mdo/general/discipline_server.py index bde20a4..d784b18 100644 --- a/philote_mdo/general/discipline_server.py +++ b/philote_mdo/general/discipline_server.py @@ -100,6 +100,8 @@ def GetAvailableOptions(self, request, context): type = data.kDouble elif val == "str": type = data.kString + elif val == "dict": + type = data.kStruct else: raise ValueError( "Invalid value for discipline option '{}'".format(name) diff --git a/philote_mdo/generated/data_pb2.py b/philote_mdo/generated/data_pb2.py index 8334dc1..b6029b2 100644 --- a/philote_mdo/generated/data_pb2.py +++ b/philote_mdo/generated/data_pb2.py @@ -7,7 +7,7 @@ _runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 5, 27, 2, '', 'data.proto') _sym_db = _symbol_database.Default() from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x07philote\x1a\x1cgoogle/protobuf/struct.proto"}\n\x14DisciplineProperties\x12\x12\n\ncontinuous\x18\x01 \x01(\x08\x12\x16\n\x0edifferentiable\x18\x02 \x01(\x08\x12\x1a\n\x12provides_gradients\x18\x03 \x01(\x08\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x0f\n\x07version\x18\x05 \x01(\t"#\n\rStreamOptions\x12\x12\n\nnum_double\x18\x01 \x01(\x03"?\n\x0bOptionsList\x12\x0f\n\x07options\x18\x01 \x03(\t\x12\x1f\n\x04type\x18\x02 \x03(\x0e2\x11.philote.DataType"=\n\x11DisciplineOptions\x12(\n\x07options\x18\x01 \x01(\x0b2\x17.google.protobuf.Struct"c\n\x10VariableMetaData\x12#\n\x04type\x18\x01 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05units\x18\x05 \x01(\t"@\n\x10PartialsMetaData\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03"u\n\x05Array\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05start\x18\x03 \x01(\x03\x12\x0b\n\x03end\x18\x04 \x01(\x03\x12#\n\x04type\x18\x05 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04data\x18\x06 \x03(\x01"l\n\x10DiscreteVariable\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x04type\x18\x02 \x01(\x0e2\x15.philote.VariableType\x12%\n\x05value\x18\x03 \x01(\x0b2\x16.google.protobuf.Value"q\n\x0fVariableMessage\x12$\n\ncontinuous\x18\x01 \x01(\x0b2\x0e.philote.ArrayH\x00\x12-\n\x08discrete\x18\x02 \x01(\x0b2\x19.philote.DiscreteVariableH\x00B\t\n\x07payload*9\n\x08DataType\x12\t\n\x05kBool\x10\x00\x12\x08\n\x04kInt\x10\x01\x12\x0b\n\x07kDouble\x10\x02\x12\x0b\n\x07kString\x10\x03*m\n\x0cVariableType\x12\n\n\x06kInput\x10\x00\x12\x12\n\x0ekDiscreteInput\x10\x01\x12\r\n\tkResidual\x10\x02\x12\x0b\n\x07kOutput\x10\x03\x12\x13\n\x0fkDiscreteOutput\x10\x04\x12\x0c\n\x08kPartial\x10\x05B\x11\n\x0forg.philote.mdob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x07philote\x1a\x1cgoogle/protobuf/struct.proto"}\n\x14DisciplineProperties\x12\x12\n\ncontinuous\x18\x01 \x01(\x08\x12\x16\n\x0edifferentiable\x18\x02 \x01(\x08\x12\x1a\n\x12provides_gradients\x18\x03 \x01(\x08\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x0f\n\x07version\x18\x05 \x01(\t"#\n\rStreamOptions\x12\x12\n\nnum_double\x18\x01 \x01(\x03"?\n\x0bOptionsList\x12\x0f\n\x07options\x18\x01 \x03(\t\x12\x1f\n\x04type\x18\x02 \x03(\x0e2\x11.philote.DataType"=\n\x11DisciplineOptions\x12(\n\x07options\x18\x01 \x01(\x0b2\x17.google.protobuf.Struct"c\n\x10VariableMetaData\x12#\n\x04type\x18\x01 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05units\x18\x05 \x01(\t"@\n\x10PartialsMetaData\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03"u\n\x05Array\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05start\x18\x03 \x01(\x03\x12\x0b\n\x03end\x18\x04 \x01(\x03\x12#\n\x04type\x18\x05 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04data\x18\x06 \x03(\x01"l\n\x10DiscreteVariable\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x04type\x18\x02 \x01(\x0e2\x15.philote.VariableType\x12%\n\x05value\x18\x03 \x01(\x0b2\x16.google.protobuf.Value"q\n\x0fVariableMessage\x12$\n\ncontinuous\x18\x01 \x01(\x0b2\x0e.philote.ArrayH\x00\x12-\n\x08discrete\x18\x02 \x01(\x0b2\x19.philote.DiscreteVariableH\x00B\t\n\x07payload*F\n\x08DataType\x12\t\n\x05kBool\x10\x00\x12\x08\n\x04kInt\x10\x01\x12\x0b\n\x07kDouble\x10\x02\x12\x0b\n\x07kString\x10\x03\x12\x0b\n\x07kStruct\x10\x04*m\n\x0cVariableType\x12\n\n\x06kInput\x10\x00\x12\x12\n\x0ekDiscreteInput\x10\x01\x12\r\n\tkResidual\x10\x02\x12\x0b\n\x07kOutput\x10\x03\x12\x13\n\x0fkDiscreteOutput\x10\x04\x12\x0c\n\x08kPartial\x10\x05B\x11\n\x0forg.philote.mdob\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', _globals) @@ -15,9 +15,9 @@ _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\x0forg.philote.mdo' _globals['_DATATYPE']._serialized_start = 856 - _globals['_DATATYPE']._serialized_end = 913 - _globals['_VARIABLETYPE']._serialized_start = 915 - _globals['_VARIABLETYPE']._serialized_end = 1024 + _globals['_DATATYPE']._serialized_end = 926 + _globals['_VARIABLETYPE']._serialized_start = 928 + _globals['_VARIABLETYPE']._serialized_end = 1037 _globals['_DISCIPLINEPROPERTIES']._serialized_start = 53 _globals['_DISCIPLINEPROPERTIES']._serialized_end = 178 _globals['_STREAMOPTIONS']._serialized_start = 180 diff --git a/philote_mdo/generated/data_pb2.pyi b/philote_mdo/generated/data_pb2.pyi index a699396..13c800f 100644 --- a/philote_mdo/generated/data_pb2.pyi +++ b/philote_mdo/generated/data_pb2.pyi @@ -12,6 +12,7 @@ class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): kInt: _ClassVar[DataType] kDouble: _ClassVar[DataType] kString: _ClassVar[DataType] + kStruct: _ClassVar[DataType] class VariableType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): __slots__ = () @@ -25,6 +26,7 @@ kBool: DataType kInt: DataType kDouble: DataType kString: DataType +kStruct: DataType kInput: VariableType kDiscreteInput: VariableType kResidual: VariableType diff --git a/philote_mdo/openmdao/utils.py b/philote_mdo/openmdao/utils.py index c11f881..3b95870 100644 --- a/philote_mdo/openmdao/utils.py +++ b/philote_mdo/openmdao/utils.py @@ -44,6 +44,8 @@ def declare_options(opt_list, options): opt_type = float elif type_str == "str": opt_type = str + elif type_str == "dict": + opt_type = dict options.declare(name, types=opt_type) diff --git a/proto b/proto index a5a21b8..300d686 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit a5a21b8c151c34f9f7b6f8fb717da76bb7e18011 +Subproject commit 300d6865d4d5b394853d06604b73c4f8e3a0c4a1 diff --git a/tests/test_discipline_client.py b/tests/test_discipline_client.py index 1e24d9b..d7be559 100644 --- a/tests/test_discipline_client.py +++ b/tests/test_discipline_client.py @@ -130,6 +130,51 @@ def test_get_available_options(self, mock_discipline_stub): } self.assertEqual(instance.options_list, expected_options_list) + @patch("philote_mdo.generated.disciplines_pb2_grpc.DisciplineServiceStub") + def test_get_available_options_with_dict_type(self, mock_discipline_stub): + """ + Tests that get_available_options correctly maps kStruct to 'dict'. + """ + mock_channel = Mock() + mock_stub = mock_discipline_stub.return_value + + instance = DisciplineClient(channel=mock_channel) + + mock_options = MagicMock() + mock_options.options = ["config", "flag"] + mock_options.type = [data.kStruct, data.kBool] + instance._disc_stub.GetAvailableOptions.return_value = mock_options + + instance.get_available_options() + + expected_options_list = { + "config": "dict", + "flag": "bool", + } + self.assertEqual(instance.options_list, expected_options_list) + + def test_send_options_with_nested_dict(self): + """ + Tests that send_options correctly serializes nested dict values via Struct. + """ + mock_stub = Mock() + mock_channel = Mock() + mock_channel.stub.return_value = mock_stub + + client = DisciplineClient(channel=mock_channel) + client._disc_stub = mock_stub + + options = { + "config": {"solver": "newton", "tol": 1e-6}, + "name": "test", + } + + client.send_options(options) + + expected_proto_options = data.DisciplineOptions() + expected_proto_options.options.update(options) + mock_stub.SetOptions.assert_called_once_with(expected_proto_options) + def test_send_options(self): # mock gRPC stub and channel mock_stub = Mock() diff --git a/tests/test_discipline_server.py b/tests/test_discipline_server.py index f9b2550..2f30542 100644 --- a/tests/test_discipline_server.py +++ b/tests/test_discipline_server.py @@ -362,6 +362,55 @@ def test_process_inputs(self): self.assertEqual(flat_inputs["x"].tolist(), [1.0, 2.0, 3.0, 4.0, 5.0, 0.0]) self.assertEqual(flat_outputs["f"].tolist(), [0.1, 0.2, 0.0]) + def test_get_available_options_with_dict_type(self): + """ + Tests that GetAvailableOptions correctly maps dict options to kStruct. + """ + server = DisciplineServer() + + request_mock = Mock() + context_mock = None + + server._discipline = Discipline() + server._discipline.options_list = { + "config": "dict", + "flag": "bool", + } + + results = server.GetAvailableOptions(request_mock, context_mock) + + expected_options = ["config", "flag"] + expected_types = [data.kStruct, data.kBool] + + self.assertEqual(results.options, expected_options) + self.assertEqual(results.type, expected_types) + + def test_set_options_with_nested_dict(self): + """ + Tests that SetOptions correctly passes nested dict values through. + """ + server = DisciplineServer() + + request_mock = Mock() + context_mock = Mock() + + request_mock.options = { + "config": {"solver": "newton", "tol": 1e-6, "nested": {"a": 1}}, + "name": "test", + } + + discipline_mock = Mock() + server._discipline = discipline_mock + + server.SetOptions(request_mock, context_mock) + + server._discipline.set_options.assert_called_once_with( + { + "config": {"solver": "newton", "tol": 1e-6, "nested": {"a": 1}}, + "name": "test", + } + ) + def test_get_available_options_invalid_type_raises_error(self): """ Tests that GetAvailableOptions raises ValueError for invalid option types. diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 72b5e40..0afc37b 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -73,6 +73,26 @@ def test_get_available_options_with_str_type(self): # The method should complete without error and return options self.assertIsNotNone(result) + def test_get_available_options_with_dict_type(self): + """ + Test GetAvailableOptions with dict option type (covers kStruct mapping). + """ + server = DisciplineServer() + discipline = Mock() + + discipline.options_list = {"config": "dict"} + + server.attach_discipline(discipline) + + request = Mock() + context = Mock() + + result = server.GetAvailableOptions(request, context) + + self.assertIsNotNone(result) + self.assertEqual(list(result.options), ["config"]) + self.assertEqual(list(result.type), [data.kStruct]) + def test_get_available_options_with_invalid_type(self): """ Test GetAvailableOptions with invalid option type (covers lines 100-103). diff --git a/tests/test_integration.py b/tests/test_integration.py index 4229d46..12be39f 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -222,5 +222,81 @@ def test_quadratic_residual_gradients(self): server.stop(0) +class StructOptionDiscipline(pmdo.ExplicitDiscipline): + """ + Minimal discipline that declares a dict option and uses it in compute. + """ + + def initialize(self): + self.add_option("config", "dict") + + def set_options(self, options): + self.config = dict(options["config"]) + + def setup(self): + self.add_input("x", shape=(1,), units="") + self.add_output("f", shape=(1,), units="") + + def setup_partials(self): + self.declare_partials("f", "x") + + def compute(self, inputs, outputs): + scale = self.config.get("scale", 1.0) + offset = self.config.get("offset", 0.0) + outputs["f"] = scale * inputs["x"] + offset + + def compute_partials(self, inputs, partials): + scale = self.config.get("scale", 1.0) + partials["f", "x"] = np.array([scale]) + + +class StructOptionIntegrationTests(unittest.TestCase): + """ + Integration tests for struct (dict) options round-trip over gRPC. + """ + + def test_struct_option_round_trip(self): + """ + Tests that a discipline with a dict option can be discovered, + set with a nested dict, and used for compute over gRPC. + """ + # server + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + discipline = pmdo.ExplicitServer(discipline=StructOptionDiscipline()) + discipline.attach_to_server(server) + server.add_insecure_port("[::]:50051") + server.start() + + try: + # client + client = pmdo.ExplicitClient( + channel=grpc.insecure_channel("localhost:50051") + ) + + # discover options and verify dict type + client.get_available_options() + self.assertEqual(client.options_list["config"], "dict") + + # send nested dict option + client.send_options({"config": {"scale": 3.0, "offset": 5.0}}) + + # standard setup + client.send_stream_options() + client.run_setup() + client.get_variable_definitions() + client.get_partials_definitions() + + # compute: f = 3.0 * 2.0 + 5.0 = 11.0 + inputs = {"x": np.array([2.0])} + outputs = client.run_compute(inputs) + self.assertAlmostEqual(outputs["f"][0], 11.0) + + # partials: df/dx = 3.0 + jac = client.run_compute_partials(inputs) + self.assertAlmostEqual(jac["f", "x"][0], 3.0) + finally: + server.stop(0) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_openmdao_utils.py b/tests/test_openmdao_utils.py index d928faa..95a5acd 100644 --- a/tests/test_openmdao_utils.py +++ b/tests/test_openmdao_utils.py @@ -188,21 +188,23 @@ def test_declare_options(self): ("max_iter", "int"), ("tolerance", "float"), ("method", "str"), - ("verbose", "bool") + ("verbose", "bool"), + ("config", "dict"), ] declare_options(opt_list, options_mock) - + expected_calls = [ ("max_iter", int), ("tolerance", float), ("method", str), - ("verbose", bool) + ("verbose", bool), + ("config", dict), ] - + for name, opt_type in expected_calls: options_mock.declare.assert_any_call(name, types=opt_type) - - self.assertEqual(options_mock.declare.call_count, 4) + + self.assertEqual(options_mock.declare.call_count, 5) # Test case 3: Unknown type (should result in None) options_mock.reset_mock()