diff --git a/.gitattributes b/.gitattributes index eae260e931..82cbc9c0a8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,3 @@ python/features.bzl export-subst tools/publish/*.txt linguist-generated=true +requirements_lock.txt linguist-generated=true diff --git a/.gitignore b/.gitignore index fb1b17e466..bd64959c44 100644 --- a/.gitignore +++ b/.gitignore @@ -48,8 +48,10 @@ user.bazelrc # CLion .clwb -# Python cache +# Python artifacts **/__pycache__/ +*.egg +*.egg-info # MODULE.bazel.lock is ignored for now as per recommendation from upstream. # See https://github.com/bazelbuild/bazel/issues/20369 diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fbe0e8960..ac51fadf6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -74,6 +74,8 @@ END_UNRELEASED_TEMPLATE {#v0-0-0-added} ### Added +* (pip,python) Added `pyproject_toml` attribute to `pip.default()` and `python.defaults()` + to read Python version from pyproject.toml `requires-python` field (must be `==X.Y.Z` format). * (toolchain) Added {obj}`python.override.toolchain_target_settings` to allow adding `config_setting` labels to all registered toolchains. * (windows) Full venv support for Windows is available. Set diff --git a/python/private/BUILD.bazel b/python/private/BUILD.bazel index db96957724..0a6cbab947 100644 --- a/python/private/BUILD.bazel +++ b/python/private/BUILD.bazel @@ -274,6 +274,8 @@ bzl_library( deps = [ ":full_version_bzl", ":platform_info_bzl", + ":pyproject_repo_bzl", + ":pyproject_utils_bzl", ":python_register_toolchains_bzl", ":pythons_hub_bzl", ":repo_utils_bzl", @@ -751,6 +753,16 @@ bzl_library( ], ) +bzl_library( + name = "pyproject_repo_bzl", + srcs = ["pyproject_repo.bzl"], +) + +bzl_library( + name = "pyproject_utils_bzl", + srcs = ["pyproject_utils.bzl"], +) + # Needed to define bzl_library targets for docgen. (We don't define the # bzl_library target here because it'd give our users a transitive dependency # on Skylib.) diff --git a/python/private/pypi/BUILD.bazel b/python/private/pypi/BUILD.bazel index e7d19ea636..1f3f646dea 100644 --- a/python/private/pypi/BUILD.bazel +++ b/python/private/pypi/BUILD.bazel @@ -131,6 +131,7 @@ bzl_library( ":whl_library_bzl", "//python/private:auth_bzl", "//python/private:normalize_name_bzl", + "//python/private:pyproject_utils_bzl", "//python/private:repo_utils_bzl", "@bazel_features//:features", "@pythons_hub//:interpreters_bzl", diff --git a/python/private/pypi/extension.bzl b/python/private/pypi/extension.bzl index e6052782fa..22e05b8fde 100644 --- a/python/private/pypi/extension.bzl +++ b/python/private/pypi/extension.bzl @@ -19,6 +19,7 @@ load("@pythons_hub//:versions.bzl", "MINOR_MAPPING") load("@rules_python_internal//:rules_python_config.bzl", rp_config = "config") load("//python/private:auth.bzl", "AUTH_ATTRS") load("//python/private:normalize_name.bzl", "normalize_name") +load("//python/private:pyproject_utils.bzl", "read_pyproject_version") load("//python/private:repo_utils.bzl", "repo_utils") load(":hub_builder.bzl", "hub_builder") load(":hub_repository.bzl", "hub_repository", "whl_config_settings_to_json") @@ -205,12 +206,23 @@ def build_config( """ defaults = { "platforms": default_platforms(), + "python_version": None, } for mod in module_ctx.modules: if not (mod.is_root or mod.name == "rules_python"): continue for tag in mod.tags.default: + pyproject_toml = tag.pyproject_toml + if pyproject_toml: + pyproject_version = read_pyproject_version( + module_ctx, + pyproject_toml, + logger = None, + ) + if pyproject_version: + defaults["python_version"] = pyproject_version + platform = tag.platform if platform: specific_config = defaults["platforms"].setdefault(platform, {}) @@ -246,6 +258,7 @@ def build_config( auth_patterns = defaults.get("auth_patterns", {}), index_url = defaults.get("index_url", "https://pypi.org/simple").rstrip("/"), netrc = defaults.get("netrc", None), + python_version = defaults.get("python_version", None), platforms = { name: _plat(**values) for name, values in defaults["platforms"].items() @@ -345,6 +358,10 @@ You cannot use both the additive_build_content and additive_build_content_file a for mod in module_ctx.modules: for pip_attr in mod.tags.parse: + python_version = pip_attr.python_version or config.python_version + if not python_version: + _fail("pip.parse() requires either python_version attribute or pip.default(pyproject_toml=...) to be set") + hub_name = pip_attr.hub_name if hub_name not in pip_hub_map: builder = hub_builder( @@ -381,6 +398,7 @@ You cannot use both the additive_build_content and additive_build_content_file a builder.pip_parse( module_ctx, pip_attr = pip_attr, + python_version = python_version, ) # Keeps track of all the hub's whl repos across the different versions. @@ -536,7 +554,7 @@ Either this or {attr}`env` `platform_machine` key should be specified. """, ), "config_settings": attr.label_list( - mandatory = True, + mandatory = False, doc = """\ The list of labels to `config_setting` targets that need to be matched for the platform to be selected. @@ -618,6 +636,21 @@ If you are defining custom platforms in your project and don't want things to cl [isolation] feature. [isolation]: https://bazel.build/rules/lib/globals/module#use_extension.isolate +""", + ), + "pyproject_toml": attr.label( + mandatory = False, + doc = """\ +Label pointing to pyproject.toml file to read the default Python version from. +When specified, reads the `requires-python` field from pyproject.toml and uses +it as the default python_version for all `pip.parse()` calls that don't +explicitly specify one. + +The version must be specified as `==X.Y.Z` (exact version with full semver). +This is designed to work with dependency management tools like Renovate. + +:::{versionadded} VERSION_NEXT_FEATURE +::: """, ), "whl_abi_tags": attr.string_list( @@ -778,7 +811,7 @@ find in case extra indexes are specified. default = True, ), "python_version": attr.string( - mandatory = True, + mandatory = False, doc = """ The Python version the dependencies are targetting, in Major.Minor format (e.g., "3.11") or patch level granularity (e.g. "3.11.1"). @@ -786,6 +819,10 @@ The Python version the dependencies are targetting, in Major.Minor format If an interpreter isn't explicitly provided (using `python_interpreter` or `python_interpreter_target`), then the version specified here must have a corresponding `python.toolchain()` configured. + +:::{seealso} +The {obj}`pyproject_toml` attribute for getting the version from a project file. +::: """, ), "simpleapi_skip": attr.string_list( diff --git a/python/private/pypi/hub_builder.bzl b/python/private/pypi/hub_builder.bzl index 85a31cfc3c..889bc88368 100644 --- a/python/private/pypi/hub_builder.bzl +++ b/python/private/pypi/hub_builder.bzl @@ -147,8 +147,8 @@ def _build(self): whl_libraries = self._whl_libraries, ) -def _pip_parse(self, module_ctx, pip_attr): - python_version = pip_attr.python_version +def _pip_parse(self, module_ctx, pip_attr, python_version = None): + python_version = python_version or pip_attr.python_version if python_version in self._platforms: fail(( "Duplicate pip python version '{version}' for hub " + @@ -194,7 +194,8 @@ def _pip_parse(self, module_ctx, pip_attr): self, module_ctx, pip_attr = pip_attr, - enable_pipstar_extract = bool(self._config.enable_pipstar_extract or self._get_index_urls.get(pip_attr.python_version)), + python_version = python_version, + enable_pipstar_extract = bool(self._config.enable_pipstar_extract or self._get_index_urls.get(python_version)), ) ### end of PUBLIC methods @@ -387,11 +388,11 @@ def _set_get_index_urls(self, pip_attr): ) return True -def _detect_interpreter(self, pip_attr): +def _detect_interpreter(self, pip_attr, python_version): python_interpreter_target = pip_attr.python_interpreter_target if python_interpreter_target == None and not pip_attr.python_interpreter: python_name = "python_{}_host".format( - pip_attr.python_version.replace(".", "_"), + python_version.replace(".", "_"), ) if python_name not in self._available_interpreters: fail(( @@ -401,7 +402,7 @@ def _detect_interpreter(self, pip_attr): "Expected to find {python_name} among registered versions:\n {labels}" ).format( hub_name = self.name, - version = pip_attr.python_version, + version = python_version, python_name = python_name, labels = " \n".join(self._available_interpreters), )) @@ -465,13 +466,13 @@ def _platforms(module_ctx, *, python_version, config, target_platforms): ) return platforms -def _evaluate_markers(self, pip_attr): +def _evaluate_markers(self, python_version): if self._evaluate_markers_fn: return self._evaluate_markers_fn return lambda _, requirements: evaluate_markers_star( requirements = requirements, - platforms = self._platforms[pip_attr.python_version], + platforms = self._platforms[python_version], ) def _create_whl_repos( @@ -479,6 +480,7 @@ def _create_whl_repos( module_ctx, *, pip_attr, + python_version, enable_pipstar_extract = False): """create all of the whl repositories @@ -486,10 +488,11 @@ def _create_whl_repos( self: the builder. module_ctx: {type}`module_ctx`. pip_attr: {type}`struct` - the struct that comes from the tag class iteration. + python_version: {type}`str` - the resolved python version for this pip.parse call. enable_pipstar_extract: {type}`bool` - enable the pipstar extraction or not. """ logger = self._logger - platforms = self._platforms[pip_attr.python_version] + platforms = self._platforms[python_version] requirements_by_platform = parse_requirements( module_ctx, requirements_by_platform = requirements_files_by_platform( @@ -501,15 +504,15 @@ def _create_whl_repos( extra_pip_args = pip_attr.extra_pip_args, platforms = sorted(platforms), # here we only need keys python_version = full_version( - version = pip_attr.python_version, + version = python_version, minor_mapping = self._minor_mapping, ), logger = logger, ), platforms = platforms, extra_pip_args = pip_attr.extra_pip_args, - get_index_urls = self._get_index_urls.get(pip_attr.python_version), - evaluate_markers = _evaluate_markers(self, pip_attr), + get_index_urls = self._get_index_urls.get(python_version), + evaluate_markers = _evaluate_markers(self, python_version), logger = logger, ) @@ -530,7 +533,7 @@ def _create_whl_repos( pip_attr = pip_attr, ) - interpreter = _detect_interpreter(self, pip_attr) + interpreter = _detect_interpreter(self, pip_attr, python_version) for whl in requirements_by_platform: whl_library_args = common_args | _whl_library_args( @@ -545,16 +548,16 @@ def _create_whl_repos( whl_library_args = whl_library_args, download_only = pip_attr.download_only, netrc = self._config.netrc or pip_attr.netrc, - use_downloader = src.url and _use_downloader(self, pip_attr.python_version, whl.name), + use_downloader = src.url and _use_downloader(self, python_version, whl.name), auth_patterns = self._config.auth_patterns or pip_attr.auth_patterns, - python_version = _major_minor_version(pip_attr.python_version), + python_version = _major_minor_version(python_version), is_multiple_versions = whl.is_multiple_versions, interpreter = interpreter, enable_pipstar_extract = enable_pipstar_extract, ) _add_whl_library( self, - python_version = pip_attr.python_version, + python_version = python_version, whl = whl, repo = repo, ) diff --git a/python/private/pyproject_repo.bzl b/python/private/pyproject_repo.bzl new file mode 100644 index 0000000000..f54e0aa80e --- /dev/null +++ b/python/private/pyproject_repo.bzl @@ -0,0 +1,80 @@ +"""Repository rule to expose Python version from pyproject.toml.""" + +_TOML2JSON = Label("//tools/private/toml2json:toml2json.py") + +def _parse_requires_python(requires_python): + """Parse and validate the requires-python field.""" + if not requires_python.startswith("=="): + fail("requires-python must use '==' for exact version, got: {}".format(requires_python)) + + bare_version = requires_python[2:].strip() + parts = bare_version.split(".") + if len(parts) != 3: + fail("requires-python must be in X.Y.Z format, got: {}".format(bare_version)) + for part in parts: + if not part.isdigit(): + fail("requires-python must be in X.Y.Z format, got: {}".format(bare_version)) + + return bare_version + +def _pyproject_version_repo_impl(rctx): + """Create a repository that exports PYTHON_VERSION from pyproject.toml.""" + pyproject_path = rctx.path(rctx.attr.pyproject_toml) + rctx.read(pyproject_path, watch = "yes") + + toml2json = rctx.path(_TOML2JSON) + result = rctx.execute([ + "python3", + str(toml2json), + str(pyproject_path), + ]) + + if result.return_code != 0: + fail("Failed to parse pyproject.toml: " + result.stderr) + + data = json.decode(result.stdout) + requires_python = data.get("project", {}).get("requires-python") + if not requires_python: + fail("pyproject.toml must contain [project] requires-python field") + + version = _parse_requires_python(requires_python) + + rctx.file("version.bzl", """\ +\"\"\"Python version from pyproject.toml. + +This file is automatically generated. Do not edit. +\"\"\" + +PYTHON_VERSION = "{version}" +""".format(version = version)) + + rctx.file("BUILD.bazel", """\ +# Automatically generated from pyproject.toml +exports_files(["version.bzl"]) +""") + +pyproject_version_repo = repository_rule( + implementation = _pyproject_version_repo_impl, + attrs = { + "pyproject_toml": attr.label( + mandatory = True, + doc = "Label pointing to pyproject.toml file.", + ), + }, + doc = """Repository rule that reads Python version from pyproject.toml. + +This rule creates a repository with a `version.bzl` file that exports +`PYTHON_VERSION` constant. + +Example: +```python + load("@python_version_from_pyproject//:version.bzl", "PYTHON_VERSION") + + compile_pip_requirements( + name = "requirements", + python_version = PYTHON_VERSION, + requirements_txt = "requirements.txt", + ) +``` +""", +) diff --git a/python/private/pyproject_utils.bzl b/python/private/pyproject_utils.bzl new file mode 100644 index 0000000000..6b79870c47 --- /dev/null +++ b/python/private/pyproject_utils.bzl @@ -0,0 +1,66 @@ +"""Utilities for reading Python version from pyproject.toml.""" + +_TOML2JSON = Label("//tools/private/toml2json:toml2json.py") + +def _parse_requires_python(requires_python): + """Parse and validate the requires-python field. + + Args: + requires_python: The raw requires-python string from pyproject.toml. + + Returns: + The bare version string (e.g. "3.13.9"). + """ + if not requires_python.startswith("=="): + fail("requires-python must use '==' for exact version, got: {}".format(requires_python)) + + bare_version = requires_python[2:].strip() + + # Validate X.Y.Z format + parts = bare_version.split(".") + if len(parts) != 3: + fail("requires-python must be in X.Y.Z format, got: {}".format(bare_version)) + for part in parts: + if not part.isdigit(): + fail("requires-python must be in X.Y.Z format, got: {}".format(bare_version)) + + return bare_version + +def read_pyproject_version(module_ctx, pyproject_label, logger = None): + """Reads Python version from pyproject.toml if requested. + + Args: + module_ctx: The module_ctx object from the module extension. + pyproject_label: Label pointing to the pyproject.toml file, or None. + logger: Optional logger instance for informational messages. + + Returns: + The Python version string (e.g. "3.13.9") or None if pyproject_label is None. + """ + if not pyproject_label: + return None + + pyproject_path = module_ctx.path(pyproject_label) + module_ctx.read(pyproject_path, watch = "yes") + + toml2json = module_ctx.path(_TOML2JSON) + result = module_ctx.execute([ + "python3", + str(toml2json), + str(pyproject_path), + ]) + + if result.return_code != 0: + fail("Failed to parse pyproject.toml: " + result.stderr) + + data = json.decode(result.stdout) + requires_python = data.get("project", {}).get("requires-python") + if not requires_python: + fail("pyproject.toml must contain [project] requires-python field") + + version = _parse_requires_python(requires_python) + + if logger: + logger.info(lambda: "Read Python version {} from {}".format(version, pyproject_label)) + + return version diff --git a/python/private/python.bzl b/python/private/python.bzl index 2ea757892e..458026a661 100644 --- a/python/private/python.bzl +++ b/python/private/python.bzl @@ -19,6 +19,8 @@ load("//python:versions.bzl", "DEFAULT_RELEASE_BASE_URL", "PLATFORMS", "TOOL_VER load(":auth.bzl", "AUTH_ATTRS") load(":full_version.bzl", "full_version") load(":platform_info.bzl", "platform_info") +load(":pyproject_repo.bzl", "pyproject_version_repo") +load(":pyproject_utils.bzl", "read_pyproject_version") load(":python_register_toolchains.bzl", "python_register_toolchains") load(":pythons_hub.bzl", "hub_repo") load(":repo_utils.bzl", "repo_utils") @@ -87,6 +89,7 @@ def parse_modules(*, module_ctx, logger, _fail = fail): mod = mod, seen_versions = seen_versions, config = config, + default_python_version = default_python_version, ) for toolchain_attr in toolchain_attr_structs: @@ -216,6 +219,20 @@ def _python_impl(module_ctx): logger = repo_utils.logger(module_ctx, "python") py = parse_modules(module_ctx = module_ctx, logger = logger) + # Create pyproject version repo if pyproject.toml is used + created_pyproject_repo = False + for mod in module_ctx.modules: + if mod.is_root: + for tag in mod.tags.defaults: + if tag.pyproject_toml: + pyproject_version_repo( + name = "python_version_from_pyproject", + pyproject_toml = tag.pyproject_toml, + ) + created_pyproject_repo = True + break + break + # Host compatible runtime repos # dict[str version, struct] where struct has: # * full_python_version: str @@ -459,7 +476,16 @@ def _python_impl(module_ctx): ) if bazel_features.external_deps.extension_metadata_has_reproducible: - return module_ctx.extension_metadata(reproducible = True) + # Build the list of direct dependencies + root_direct_deps = ["pythons_hub", "python_versions"] + if created_pyproject_repo: + root_direct_deps.append("python_version_from_pyproject") + + return module_ctx.extension_metadata( + root_module_direct_deps = root_direct_deps, + root_module_direct_dev_deps = [], + reproducible = True, + ) else: return None @@ -861,8 +887,15 @@ def _compute_default_python_version(mctx): defaults_attr_structs = _create_defaults_attr_structs(mod = mod) default_python_version_env = None default_python_version_file = None + pyproject_toml_label = None for defaults_attr in defaults_attr_structs: + pyproject_toml_label = _one_or_the_same( + pyproject_toml_label, + defaults_attr.pyproject_toml, + onerror = lambda: fail("Multiple pyproject.toml files specified in defaults"), + ) + default_python_version = _one_or_the_same( default_python_version, defaults_attr.python_version, @@ -878,11 +911,21 @@ def _compute_default_python_version(mctx): defaults_attr.python_version_file, onerror = _fail_multiple_defaults_python_version_file, ) + + # Priority order: ENV > pyproject_toml > python_version_file > python_version if default_python_version_file: default_python_version = _one_or_the_same( default_python_version, mctx.read(default_python_version_file, watch = "yes").strip(), ) + if pyproject_toml_label: + pyproject_version = read_pyproject_version( + mctx, + pyproject_toml_label, + logger = None, + ) + if pyproject_version: + default_python_version = pyproject_version if default_python_version_env: default_python_version = mctx.getenv( default_python_version_env, @@ -924,11 +967,29 @@ def _create_defaults_attr_struct(*, tag): python_version = getattr(tag, "python_version", None), python_version_env = getattr(tag, "python_version_env", None), python_version_file = getattr(tag, "python_version_file", None), + pyproject_toml = getattr(tag, "pyproject_toml", None), ) -def _create_toolchain_attr_structs(*, mod, config, seen_versions): +def _create_toolchain_attr_structs(*, mod, config, seen_versions, default_python_version): arg_structs = [] + # Auto-register a toolchain for the default version if not already + # registered via an explicit python.toolchain() call. + # This works for any default source: pyproject_toml, python_version_file, + # python_version_env, or python_version. + has_explicit_toolchain = default_python_version and any([ + tag.python_version == default_python_version + for tag in mod.tags.toolchain + ]) + if (default_python_version and + default_python_version not in seen_versions and + mod.is_root and not has_explicit_toolchain): + arg_structs.append(_create_toolchain_attrs_struct( + python_version = default_python_version, + toolchain_tag_count = 1, + )) + seen_versions[default_python_version] = True + for tag in mod.tags.toolchain: arg_structs.append(_create_toolchain_attrs_struct( tag = tag, @@ -968,6 +1029,17 @@ def _create_toolchain_attrs_struct( _defaults = tag_class( doc = """Tag class to specify the default Python version.""", attrs = { + "pyproject_toml": attr.label( + mandatory = False, + doc = """\ +Label pointing to pyproject.toml file to read the default Python version from. +When specified, reads the `requires-python` field from pyproject.toml. +The version must be specified as `==X.Y.Z` (exact version with full semver). + +:::{versionadded} VERSION_NEXT_FEATURE +::: +""", + ), "python_version": attr.string( mandatory = False, doc = """\ diff --git a/tests/pypi/extension/extension_tests.bzl b/tests/pypi/extension/extension_tests.bzl index 5a40714b64..eea38e2b7f 100644 --- a/tests/pypi/extension/extension_tests.bzl +++ b/tests/pypi/extension/extension_tests.bzl @@ -50,6 +50,7 @@ def _default( netrc = None, os_name = None, platform = None, + pyproject_toml = None, whl_platform_tags = None, whl_abi_tags = None): return struct( @@ -62,6 +63,7 @@ def _default( netrc = netrc, os_name = os_name, platform = platform, + pyproject_toml = pyproject_toml, whl_abi_tags = whl_abi_tags or [], whl_platform_tags = whl_platform_tags or [], ) diff --git a/tests/pypi/extension/pip_parse.bzl b/tests/pypi/extension/pip_parse.bzl index 2d55d5cd1f..95cf666056 100644 --- a/tests/pypi/extension/pip_parse.bzl +++ b/tests/pypi/extension/pip_parse.bzl @@ -65,6 +65,5 @@ def pip_parse( parallel_download = False, experimental_index_url_overrides = {}, simpleapi_skip = simpleapi_skip, - _evaluate_markers_srcs = [], **kwargs ) diff --git a/tests/tools/private/toml2json/BUILD.bazel b/tests/tools/private/toml2json/BUILD.bazel new file mode 100644 index 0000000000..e8830f0030 --- /dev/null +++ b/tests/tools/private/toml2json/BUILD.bazel @@ -0,0 +1,10 @@ +load("@rules_python//python:defs.bzl", "py_test") + +py_test( + name = "toml2json_test", + srcs = ["toml2json_test.py"], + main = "toml2json_test.py", + deps = [ + "//tools/private/toml2json", + ], +) diff --git a/tests/tools/private/toml2json/toml2json_test.py b/tests/tools/private/toml2json/toml2json_test.py new file mode 100644 index 0000000000..6c63f7e499 --- /dev/null +++ b/tests/tools/private/toml2json/toml2json_test.py @@ -0,0 +1,62 @@ +import io +import json +import os +import sys +import tempfile +import unittest +from unittest.mock import patch + +from tools.private.toml2json import toml2json + +class Toml2JsonTest(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.addCleanup(self.temp_dir.cleanup) + + def _create_temp_toml_file(self, content): + fd, path = tempfile.mkstemp(suffix=".toml", dir=self.temp_dir.name) + with os.fdopen(fd, "wb") as f: + f.write(content) + return path + + def test_basic_conversion(self): + toml_content = b""" +[owner] +name = "Tom Preston-Werner" +dob = 1979-05-27T07:32:00-08:00 +""" + expected_json = { + "owner": { + "name": "Tom Preston-Werner", + "dob": "1979-05-27T07:32:00-08:00" + } + } + + toml_file_path = self._create_temp_toml_file(toml_content) + + with patch('sys.stdout', new=io.StringIO()) as mock_stdout: + with patch('sys.argv', ['toml2json.py', toml_file_path]): + toml2json.main() + actual_json = json.loads(mock_stdout.getvalue()) + self.assertEqual(actual_json, expected_json) + + def test_invalid_toml(self): + toml_content = b""" +[owner +name = "Tom Preston-Werner" +""" + + toml_file_path = self._create_temp_toml_file(toml_content) + + with patch('sys.stderr', new=io.StringIO()) as mock_stderr: + with patch('sys.stdout', new=io.StringIO()): # We don't expect stdout for errors + with patch('sys.exit') as mock_exit: + with patch('sys.argv', ['toml2json.py', toml_file_path]): + toml2json.main() + mock_exit.assert_called_with(1) + self.assertIn("Error decoding TOML", mock_stderr.getvalue()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/private/toml2json/BUILD.bazel b/tools/private/toml2json/BUILD.bazel new file mode 100644 index 0000000000..428f512923 --- /dev/null +++ b/tools/private/toml2json/BUILD.bazel @@ -0,0 +1,9 @@ +load("@rules_python//python:defs.bzl", "py_binary") + +exports_files(["toml2json.py"]) + +py_binary( + name = "toml2json", + srcs = ["toml2json.py"], + visibility = ["//visibility:public"], +) diff --git a/tools/private/toml2json/toml2json.py b/tools/private/toml2json/toml2json.py new file mode 100644 index 0000000000..84b7f9c30f --- /dev/null +++ b/tools/private/toml2json/toml2json.py @@ -0,0 +1,42 @@ +import json +import sys +import datetime + +try: + import tomllib +except ImportError: + try: + import tomli as tomllib + except ImportError: + print("Error: need tomllib (python >=3.11) or tomli installed on host python", file=sys.stderr) + sys.exit(1) + + +def json_serializer(obj): + if isinstance(obj, datetime.datetime): + return obj.isoformat() + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + +def main(): + if len(sys.argv) < 2: + print("Usage: toml2json ", file=sys.stderr) + sys.exit(1) + + toml_file_path = sys.argv[1] + + try: + with open(toml_file_path, "rb") as f: + data = tomllib.load(f) + json.dump(data, sys.stdout, indent=2, default=json_serializer) + print() + except FileNotFoundError: + print(f"Error: File not found: {toml_file_path}", file=sys.stderr) + sys.exit(1) + except tomllib.TOMLDecodeError as e: + print(f"Error decoding TOML: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main()