diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/UdfImpl.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/UdfImpl.java index 7ebd3faea782..63cd3c90419f 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/UdfImpl.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/UdfImpl.java @@ -70,17 +70,24 @@ public static Function create(Method method) { } /* - * Finds a method in a given class by name. + * Finds a method in a given class by name. In case of overloaded methods with the same name, + * this prioritizes the overload with the maximum number of parameters. This ensures Calcite + * can resolve optional/default trailing parameters correctly when binding UDF overloads. + * * @param clazz class to search method in * @param name name of the method to find - * @return the first method with matching name or null when no method found + * @return the matching method with the highest parameter count or null when no method found */ static @Nullable Method findMethod(Class clazz, String name) { + Method bestMethod = null; for (Method method : clazz.getMethods()) { if (method.getName().equals(name) && !method.isBridge()) { - return method; + if (bestMethod == null + || method.getParameterTypes().length > bestMethod.getParameterTypes().length) { + bestMethod = method; + } } } - return null; + return bestMethod; } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java index aaa4d66011a6..10fea6120360 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java @@ -40,6 +40,7 @@ import 
org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Top; @@ -134,12 +135,12 @@ public BeamSortRel( } if (fetch == null) { - throw new UnsupportedOperationException("ORDER BY without a LIMIT is not supported!"); + count = -1; + } else { + RexLiteral fetchLiteral = (RexLiteral) fetch; + count = ((BigDecimal) fetchLiteral.getValue()).intValue(); } - RexLiteral fetchLiteral = (RexLiteral) fetch; - count = ((BigDecimal) fetchLiteral.getValue()).intValue(); - if (offset != null) { RexLiteral offsetLiteral = (RexLiteral) offset; startIndex = ((BigDecimal) offsetLiteral.getValue()).intValue(); @@ -209,6 +210,20 @@ public PCollection expand(PCollectionList pinput) { GlobalWindows.class.getSimpleName(), windowingStrategy)); } + // When no limit is specified (count == -1), we must sort the entire dataset. + // To achieve this globally, we key all rows by a single dummy key, group them together + // using GroupByKey to ensure they are processed together, and then sort them in-memory + // via SortInMemoryFn. Note: This can be memory-intensive for large datasets. + if (count == -1) { + BeamSqlRowComparator comparator = + new BeamSqlRowComparator(fieldIndices, orientation, nullsFirst); + return upstream + .apply("WithDummyKey", WithKeys.of("DummyKey")) + .apply("GroupByKey", GroupByKey.create()) + .apply("SortInMemory", ParDo.of(new SortInMemoryFn(comparator))) + .setRowSchema(CalciteUtils.toSchema(getRowType())); + } + ReversedBeamSqlRowComparator comparator = new ReversedBeamSqlRowComparator(fieldIndices, orientation, nullsFirst); @@ -340,9 +355,16 @@ public int compare(Row row1, Row row2) { if (isValue1Null && isValue2Null) { continue; } else if (isValue1Null && !isValue2Null) { - fieldRet = -1 * (nullsFirst.get(i) ? 
-1 : 1); + // NULL placement is absolute: NULLS FIRST means the null row always sorts before the + // non-null row regardless of ASC/DESC, and NULLS LAST means it always sorts after. Do + // NOT apply the ASC/DESC orientation flip here — that flip only governs value-vs-value + // ordering (handled below). Mixing the two reversed NULLS LAST under ascending sorts + // (and NULLS FIRST under descending sorts), placing nulls on the wrong end. + fieldRet = nullsFirst.get(i) ? -1 : 1; + return fieldRet; } else if (!isValue1Null && isValue2Null) { - fieldRet = 1 * (nullsFirst.get(i) ? -1 : 1); + fieldRet = nullsFirst.get(i) ? 1 : -1; + return fieldRet; } else { switch (sqlTypeName) { case TINYINT: @@ -351,9 +373,17 @@ public int compare(Row row1, Row row2) { case BIGINT: case FLOAT: case DOUBLE: + case DECIMAL: + case BOOLEAN: case VARCHAR: + case CHAR: case DATE: + case TIME: case TIMESTAMP: + // All of the above map to Java types that implement Comparable (Boolean, BigDecimal, + // LocalTime, etc.), so a uniform Comparable comparison yields the correct SQL + // ordering (false < true for BOOLEAN). + // The base value is extracted via Comparable.class.
Comparable v1 = row1.getBaseValue(fieldIndex, Comparable.class); Comparable v2 = row2.getBaseValue(fieldIndex, Comparable.class); fieldRet = v1.compareTo(v2); @@ -374,6 +404,27 @@ public int compare(Row row1, Row row2) { } } + private static class SortInMemoryFn extends DoFn>, Row> { + private final BeamSqlRowComparator comparator; + + public SortInMemoryFn(BeamSqlRowComparator comparator) { + this.comparator = comparator; + } + + @ProcessElement + public void processElement(ProcessContext ctx) { + Iterable input = ctx.element().getValue(); + List list = new ArrayList<>(); + for (Row r : input) { + list.add(r); + } + list.sort(comparator); + for (Row r : list) { + ctx.output(r); + } + } + } + private static class ReversedBeamSqlRowComparator implements Comparator, Serializable { private final BeamSqlRowComparator comparator; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java index 3aaa91680999..5bc1518b2bdb 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java @@ -27,6 +27,7 @@ import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.sdk.schemas.logicaltypes.NanosDuration; import org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType; import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.util.Preconditions; @@ -137,6 +138,7 @@ public static boolean isStringType(FieldType fieldType) { .put(TIME_WITH_LOCAL_TZ, SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE) .put(TIMESTAMP, SqlTypeName.TIMESTAMP) .put(TIMESTAMP_WITH_LOCAL_TZ, SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE) + 
.put(FieldType.logicalType(new NanosDuration()), SqlTypeName.INTERVAL_DAY_SECOND) .build(); private static final ImmutableMap CALCITE_TO_BEAM_TYPE_MAPPING = @@ -161,6 +163,8 @@ public static boolean isStringType(FieldType fieldType) { .put(SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE, TIME_WITH_LOCAL_TZ) .put(SqlTypeName.TIMESTAMP, TIMESTAMP) .put(SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE, TIMESTAMP_WITH_LOCAL_TZ) + .put(SqlTypeName.NULL, VARCHAR) + .put(SqlTypeName.INTERVAL_DAY_SECOND, FieldType.logicalType(new NanosDuration())) .build(); // Since there are multiple Calcite type that correspond to a single Beam type, this is the @@ -396,6 +400,55 @@ public static RelDataType sqlTypeWithAutoCast(RelDataTypeFactory typeFactory, Ty + ". This is currently unsupported, use List instead " + "of Array."); } + if (type instanceof Class) { + Class clazz = (Class) type; + if (clazz == String.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.VARCHAR), true); + } else if (clazz == Integer.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.INTEGER), true); + } else if (clazz == int.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.INTEGER), false); + } else if (clazz == Long.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.BIGINT), true); + } else if (clazz == long.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.BIGINT), false); + } else if (clazz == Double.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.DOUBLE), true); + } else if (clazz == double.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.DOUBLE), false); + } else if (clazz == Float.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.FLOAT), true); + } 
else if (clazz == float.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.FLOAT), false); + } else if (clazz == Short.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.SMALLINT), true); + } else if (clazz == short.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.SMALLINT), false); + } else if (clazz == Byte.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.TINYINT), true); + } else if (clazz == byte.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.TINYINT), false); + } else if (clazz == Boolean.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.BOOLEAN), true); + } else if (clazz == boolean.class) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.BOOLEAN), false); + } + } return typeFactory.createJavaType((Class) type); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java index 17636c628eb8..408118ff09e9 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java @@ -34,6 +34,7 @@ import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.schema.AggregateFunction; import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.schema.FunctionParameter; +import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.sql.type.SqlTypeName; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; import org.junit.Rule; @@ -77,7 +78,9 @@ public void subclassGetUdafImpl() { LazyAggregateCombineFn combiner = new LazyAggregateCombineFn<>(aggregateFn); AggregateFunction aggregateFunction = combiner.getUdafImpl(); RelDataTypeFactory typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT); - RelDataType expectedType = typeFactory.createJavaType(Long.class); + RelDataType expectedType = + typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.BIGINT), true); List params = aggregateFunction.getParameters(); assertThat(params, hasSize(1)); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java index dbe8be441ac6..0160c99cf4ba 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java @@ -293,6 +293,28 @@ public void testOrderBy_exception() { compilePipeline(sql, pipeline); } + @Test + public void testOrderBy_withoutLimit() { + String sql = + "INSERT INTO SUB_ORDER_RAM(order_id, site_id, price) SELECT " + + " order_id, site_id, price " + + "FROM ORDER_DETAILS " + + "ORDER BY order_id asc, site_id desc"; + + PCollection rows = compilePipeline(sql, pipeline); + PAssert.that(rows) + .containsInAnyOrder( + TestUtils.RowsBuilder.of( + Schema.FieldType.INT64, "order_id", + Schema.FieldType.INT32, "site_id", + Schema.FieldType.DOUBLE, "price") + .addRows( + 1L, 2, 1.0, 1L, 1, 2.0, 2L, 4, 3.0, 2L, 1, 4.0, 5L, 5, 5.0, 6L, 6, 6.0, 7L, 7, + 7.0, 8L, 8888, 8.0, 8L, 999, 9.0, 10L, 100, 10.0) + .getRows()); + pipeline.run().waitUntilFinish(); + } + @Test public void 
testNodeStatsEstimation() { String sql =