From 7158788dc68c1541968fcd99207179ce35bdbd31 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Tue, 9 Jun 2026 15:21:26 +0000 Subject: [PATCH] SQL Physical Operator Fixes and Enhancements --- .../sql/impl/planner/BeamRuleSets.java | 15 +++--- .../sql/impl/planner/RelMdNodeStats.java | 6 +++ .../extensions/sql/impl/rel/BeamCalcRel.java | 7 ++- .../sql/impl/rel/BeamEnumerableConverter.java | 19 ++++++- .../sql/impl/rel/BeamSetOperatorRelBase.java | 29 +++++++---- .../sql/impl/rel/BeamValuesRel.java | 51 +++++++++++++++++-- .../transform/BeamBuiltinAggregations.java | 5 ++ .../sql/BeamSqlDslAggregationTest.java | 32 ++++++++++++ .../impl/rel/BeamEnumerableConverterTest.java | 17 +++++++ 9 files changed, 158 insertions(+), 23 deletions(-) diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/BeamRuleSets.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/BeamRuleSets.java index 8382fcb6e382..8e4de66683e4 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/BeamRuleSets.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/BeamRuleSets.java @@ -151,13 +151,14 @@ public class BeamRuleSets { ImmutableList.of(BeamEnumerableConverterRule.INSTANCE); public static Collection getRuleSets() { + return ImmutableList.of(RuleSets.ofList(getAllRules())); + } - return ImmutableList.of( - RuleSets.ofList( - ImmutableList.builder() - .addAll(BEAM_CONVERTERS) - .addAll(BEAM_TO_ENUMERABLE) - .addAll(LOGICAL_OPTIMIZATIONS) - .build())); + public static List getAllRules() { + return ImmutableList.builder() + .addAll(BEAM_CONVERTERS) + .addAll(BEAM_TO_ENUMERABLE) + .addAll(LOGICAL_OPTIMIZATIONS) + .build(); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/RelMdNodeStats.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/RelMdNodeStats.java index 40db8b074efa..264151fecbf4 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/RelMdNodeStats.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/RelMdNodeStats.java @@ -57,6 +57,12 @@ public NodeStats getNodeStats(RelNode rel, RelMetadataQuery mq) { return this.getBeamNodeStats((BeamRelNode) rel, bmq); } + if (rel instanceof org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.core.Values) { + org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.core.Values values = + (org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.core.Values) rel; + return NodeStats.create(values.getTuples().size()); + } + // We can later define custom methods for all different RelNodes to prevent hitting this point. // Similar to RelMdRowCount in calcite. diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java index fad96abb29a5..a3a909a92403 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java @@ -269,7 +269,10 @@ public CalcFn( List jarPaths, FieldAccessDescriptor fieldAccess, boolean collectErrors) { - this.processElementBlock = processElementBlock; + this.processElementBlock = + processElementBlock.replace( + "(byte[]) org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.runtime.SqlFunctions.concat", + "org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.runtime.SqlFunctions.concat"); this.outputSchema = outputSchema; this.verifyRowValues = verifyRowValues; this.jarPaths = jarPaths; @@ -277,7 +280,7 @@ public CalcFn( this.collectErrors = collectErrors; // validate generated code - compile(processElementBlock, jarPaths); + compile(this.processElementBlock, jarPaths); } private static ScriptEvaluator compile(String processElementBlock, List jarPaths) { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverter.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverter.java index 69eb0a0a8d5c..7f867423e839 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverter.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverter.java @@ -159,11 +159,28 @@ static List toRowList(PipelineOptions options, BeamRelNode node) { if (node instanceof BeamIOSinkRel) { throw new UnsupportedOperationException("Does not support BeamIOSinkRel in toRowList."); } else if (isLimitQuery(node)) { - throw new UnsupportedOperationException("Does not support queries with LIMIT in toRowList."); + return limitRowList(options, node); } return collectRows(options, node).stream().collect(Collectors.toList()); } + private static List limitRowList(PipelineOptions options, BeamRelNode node) { + long id = options.getOptionsId(); + ConcurrentLinkedQueue values = new ConcurrentLinkedQueue<>(); + int limitCount = getLimitCount(node); + + Collector.globalValues.put(id, values); + limitRun(options, node, new Collector(), values, limitCount); + Collector.globalValues.remove(id); + + // remove extra retrieved values + while (values.size() > limitCount) { + values.remove(); + } + + return values.stream().collect(Collectors.toList()); + } + static Enumerable toEnumerable(PipelineOptions options, BeamRelNode node) { if (node instanceof BeamIOSinkRel) { return count(options, node); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBase.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBase.java index cad845c2b3a9..956e48188b50 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBase.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBase.java @@ -17,8 +17,6 @@ */ package org.apache.beam.sdk.extensions.sql.impl.rel; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; - import java.io.Serializable; import org.apache.beam.sdk.extensions.sql.impl.transform.BeamSetOperatorsTransforms; import org.apache.beam.sdk.schemas.transforms.CoGroup; @@ -59,14 +57,27 @@ public BeamSetOperatorRelBase(BeamRelNode beamRelNode, OpType opType, boolean al @Override public PCollection expand(PCollectionList inputs) { - checkArgument( - inputs.size() == 2, - "Wrong number of arguments to %s: %s", - beamRelNode.getClass().getSimpleName(), - inputs); - PCollection leftRows = inputs.get(0); - PCollection rightRows = inputs.get(1); + // Reverted Flatten optimization as it fails when inputs have slightly different schemas (e.g. + // NULL vs VARCHAR) + // if (opType == OpType.UNION && all) { + // return inputs.apply("UnionAllFlatten", Flatten.pCollections()); + // } + + if (inputs.size() == 2) { + return expandPair(inputs.get(0), inputs.get(1)); + } else if (inputs.size() > 2) { + PCollection result = inputs.get(0); + for (int i = 1; i < inputs.size(); i++) { + result = expandPair(result, inputs.get(i)); + } + return result; + } else { + throw new IllegalArgumentException( + "Too few arguments to " + beamRelNode.getClass().getSimpleName()); + } + } + private PCollection expandPair(PCollection leftRows, PCollection rightRows) { WindowFn leftWindow = leftRows.getWindowingStrategy().getWindowFn(); WindowFn rightWindow = rightRows.getWindowingStrategy().getWindowFn(); if (!leftWindow.isCompatible(rightWindow)) { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java index 6163dbe770ae..81d0a417c2fb 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java @@ -30,8 +30,10 @@ import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Impulse; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.Row; @@ -84,15 +86,41 @@ public PCollection expand(PCollectionList pinput) { BeamValuesRel.class.getSimpleName(), pinput); - Schema schema = CalciteUtils.toSchema(getRowType()); + Schema inferredSchema = CalciteUtils.toSchema(getRowType()); + Schema.Builder schemaBuilder = Schema.builder(); + for (int i = 0; i < inferredSchema.getFieldCount(); i++) { + Schema.Field field = inferredSchema.getField(i); + boolean hasNull = false; + for (ImmutableList tuple : tuples) { + if (tuple.get(i).getValue() == null) { + hasNull = true; + break; + } + } + if (hasNull && !field.getType().getNullable()) { + schemaBuilder.addField(field.getName(), field.getType().withNullable(true)); + } else { + schemaBuilder.addField(field); + } + } + Schema schema = schemaBuilder.build(); List rows = tuples.stream().map(tuple -> tupleToRow(schema, tuple)).collect(toList()); - return pinput.getPipeline().begin().apply(Create.of(rows).withRowSchema(schema)); + return pinput + .getPipeline() + .begin() + .apply(Impulse.create()) + .apply(ParDo.of(new EmitRowsFn(rows))) + .setRowSchema(schema); } } private Row tupleToRow(Schema schema, ImmutableList tuple) { return IntStream.range(0, tuple.size()) - .mapToObj(i -> autoCastField(schema.getField(i), tuple.get(i).getValue())) + .mapToObj( + i -> { + Object val = tuple.get(i).getValue(); + return autoCastField(schema.getField(i), val); + }) .collect(toRow(schema)); } @@ -106,4 +134,19 @@ public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQ NodeStats estimates = BeamSqlRelUtils.getNodeStats(this, mq); return BeamCostModel.FACTORY.makeCost(estimates.getRowCount(), estimates.getRate()); } + + private static class EmitRowsFn extends DoFn { + private final List rows; + + public EmitRowsFn(List rows) { + this.rows = rows; + } + + @ProcessElement + public void processElement(ProcessContext c) { + for (Row row : rows) { + c.output(row); + } + } + } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java index 3fc299bd5a33..66993abe4057 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java @@ -61,6 +61,11 @@ public class BeamBuiltinAggregations { BUILTIN_AGGREGATOR_FACTORIES = ImmutableMap.>>builder() .put("ANY_VALUE", typeName -> Sample.anyValueCombineFn()) + // SINGLE_VALUE is emitted by Calcite to enforce the cardinality of a scalar + // subquery (a subquery used as a scalar must yield exactly one row). The single + // input value is returned as-is; unlike COUNT/SUM it must not drop nulls, so a + // scalar subquery evaluating to NULL surfaces NULL. + .put("SINGLE_VALUE", typeName -> Sample.anyValueCombineFn()) // Drop null elements for these aggregations. .put("COUNT", typeName -> new DropNullFnWithDefault(Count.combineFn())) .put("MAX", typeName -> new DropNullFn(BeamBuiltinAggregations.createMax(typeName))) diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java index 37243fe7d5f6..ab74c61a771a 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java @@ -1133,4 +1133,36 @@ public void testCountIfFunction() throws Exception { PAssert.that(result).containsInAnyOrder(rowResult); pipeline.run().waitUntilFinish(); } + + @Test + public void testCorrelatedScalarSubqueryWithSingleValue() { + Schema schema = Schema.builder().addInt32Field("id").addStringField("name").build(); + PCollection input = + pipeline.apply( + "input", + Create.of( + Row.withSchema(schema).addValues(1, "a").build(), + Row.withSchema(schema).addValues(2, "b").build()) + .withRowSchema(schema)); + + // Correlated subquery. Calcite should decorrelate this into a LEFT JOIN + // and use SINGLE_VALUE aggregation on the subquery side. + String sql = + "SELECT t1.id, (SELECT t2.name FROM PCOLLECTION t2 WHERE t2.id = t1.id) FROM PCOLLECTION t1"; + + PCollection result = input.apply("sql_subquery", SqlTransform.query(sql)); + + Schema expectedSchema = + Schema.builder() + .addInt32Field("id") + .addNullableField("name", Schema.FieldType.STRING) + .build(); + + PAssert.that(result) + .containsInAnyOrder( + Row.withSchema(expectedSchema).addValues(1, "a").build(), + Row.withSchema(expectedSchema).addValues(2, "b").build()); + + pipeline.run(); + } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverterTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverterTest.java index 697643d4715c..2b486e20d978 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverterTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverterTest.java @@ -23,8 +23,10 @@ import java.math.BigDecimal; import java.util.List; +import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.extensions.sql.meta.SchemaBaseBeamTable; +import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestBoundedTable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.schemas.Schema; @@ -125,6 +127,21 @@ public void testToListRow_collectMultiple() { assertEquals(Row.withSchema(schema).addValues(0L, 1L).build(), rowList.get(0)); } + @Test + public void testToRowList_limit() { + Schema schema = Schema.builder().addInt64Field("id").build(); + java.util.Map tables = + new java.util.HashMap<>(); + tables.put("TEST", TestBoundedTable.of(Schema.FieldType.INT64, "id").addRows(1L, 2L, 3L)); + BeamSqlEnv env = BeamSqlEnv.readOnly("test", tables); + BeamRelNode node = env.parseQuery("SELECT id FROM TEST LIMIT 2"); + + List rowList = BeamEnumerableConverter.toRowList(options, node); + assertEquals(2, rowList.size()); + assertTrue(rowList.contains(Row.withSchema(schema).addValue(1L).build())); + assertTrue(rowList.contains(Row.withSchema(schema).addValue(2L).build())); + } + private static class FakeTable extends SchemaBaseBeamTable { public FakeTable() { super(null);