Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,24 @@ public static Function create(Method method) {
}

/*
* Finds a method in a given class by name.
* Finds a method in a given class by name. In case of overloaded methods with the same name,
* this prioritizes the overload with the maximum number of parameters. This ensures Calcite
* can resolve optional/default trailing parameters correctly when binding UDF overloads.
*
* @param clazz class to search method in
* @param name name of the method to find
* @return the first method with matching name or null when no method found
* @return the matching method with the highest parameter count or null when no method found
*/
static @Nullable Method findMethod(Class<?> clazz, String name) {
Method bestMethod = null;
for (Method method : clazz.getMethods()) {
if (method.getName().equals(name) && !method.isBridge()) {
return method;
if (bestMethod == null
|| method.getParameterTypes().length > bestMethod.getParameterTypes().length) {
bestMethod = method;
}
}
}
return null;
return bestMethod;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Top;
Expand Down Expand Up @@ -134,12 +135,12 @@ public BeamSortRel(
}

if (fetch == null) {
throw new UnsupportedOperationException("ORDER BY without a LIMIT is not supported!");
count = -1;
} else {
RexLiteral fetchLiteral = (RexLiteral) fetch;
count = ((BigDecimal) fetchLiteral.getValue()).intValue();
}

RexLiteral fetchLiteral = (RexLiteral) fetch;
count = ((BigDecimal) fetchLiteral.getValue()).intValue();

if (offset != null) {
RexLiteral offsetLiteral = (RexLiteral) offset;
startIndex = ((BigDecimal) offsetLiteral.getValue()).intValue();
Expand Down Expand Up @@ -209,6 +210,20 @@ public PCollection<Row> expand(PCollectionList<Row> pinput) {
GlobalWindows.class.getSimpleName(), windowingStrategy));
}

// When no limit is specified (count == -1), we must sort the entire dataset.
// To achieve this globally, we key all rows by a single dummy key, group them together
// using GroupByKey to ensure they are processed together, and then sort them in-memory
// via SortInMemoryFn. Note: This can be memory-intensive for large datasets.
if (count == -1) {
BeamSqlRowComparator comparator =
new BeamSqlRowComparator(fieldIndices, orientation, nullsFirst);
return upstream
.apply("WithDummyKey", WithKeys.of("DummyKey"))
.apply("GroupByKey", GroupByKey.create())
.apply("SortInMemory", ParDo.of(new SortInMemoryFn(comparator)))
.setRowSchema(CalciteUtils.toSchema(getRowType()));
}

ReversedBeamSqlRowComparator comparator =
new ReversedBeamSqlRowComparator(fieldIndices, orientation, nullsFirst);

Expand Down Expand Up @@ -340,9 +355,16 @@ public int compare(Row row1, Row row2) {
if (isValue1Null && isValue2Null) {
continue;
} else if (isValue1Null && !isValue2Null) {
fieldRet = -1 * (nullsFirst.get(i) ? -1 : 1);
// NULL placement is absolute: NULLS FIRST means the null row always sorts before the
// non-null row regardless of ASC/DESC, and NULLS LAST means it always sorts after. Do
// NOT apply the ASC/DESC orientation flip here — that flip only governs value-vs-value
// ordering (handled below). Mixing the two reversed NULLS LAST under ascending sorts
// (and NULLS FIRST under descending sorts), placing nulls on the wrong end.
fieldRet = nullsFirst.get(i) ? -1 : 1;
return fieldRet;
} else if (!isValue1Null && isValue2Null) {
fieldRet = 1 * (nullsFirst.get(i) ? -1 : 1);
fieldRet = nullsFirst.get(i) ? 1 : -1;
return fieldRet;
} else {
switch (sqlTypeName) {
case TINYINT:
Expand All @@ -351,9 +373,17 @@ public int compare(Row row1, Row row2) {
case BIGINT:
case FLOAT:
case DOUBLE:
case DECIMAL:
case BOOLEAN:
case VARCHAR:
case CHAR:
case DATE:
case TIME:
case TIMESTAMP:
// All of the above map to Java types that implement Comparable (Boolean, BigDecimal,
// LocalTime, etc.), so a uniform Comparable comparison yields the correct SQL
// ordering
// (false < true for BOOLEAN). The base value is extracted via Comparable.class.
Comparable v1 = row1.getBaseValue(fieldIndex, Comparable.class);
Comparable v2 = row2.getBaseValue(fieldIndex, Comparable.class);
fieldRet = v1.compareTo(v2);
Expand All @@ -374,6 +404,27 @@ public int compare(Row row1, Row row2) {
}
}

private static class SortInMemoryFn extends DoFn<KV<String, Iterable<Row>>, Row> {
private final BeamSqlRowComparator comparator;

public SortInMemoryFn(BeamSqlRowComparator comparator) {
this.comparator = comparator;
}

@ProcessElement
public void processElement(ProcessContext ctx) {
Iterable<Row> input = ctx.element().getValue();
List<Row> list = new ArrayList<>();
for (Row r : input) {
list.add(r);
}
list.sort(comparator);
for (Row r : list) {
ctx.output(r);
}
}
}

private static class ReversedBeamSqlRowComparator implements Comparator<Row>, Serializable {
private final BeamSqlRowComparator comparator;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.logicaltypes.NanosDuration;
import org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType;
import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
import org.apache.beam.sdk.util.Preconditions;
Expand Down Expand Up @@ -137,6 +138,7 @@ public static boolean isStringType(FieldType fieldType) {
.put(TIME_WITH_LOCAL_TZ, SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE)
.put(TIMESTAMP, SqlTypeName.TIMESTAMP)
.put(TIMESTAMP_WITH_LOCAL_TZ, SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE)
.put(FieldType.logicalType(new NanosDuration()), SqlTypeName.INTERVAL_DAY_SECOND)
.build();

private static final ImmutableMap<SqlTypeName, FieldType> CALCITE_TO_BEAM_TYPE_MAPPING =
Expand All @@ -161,6 +163,8 @@ public static boolean isStringType(FieldType fieldType) {
.put(SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE, TIME_WITH_LOCAL_TZ)
.put(SqlTypeName.TIMESTAMP, TIMESTAMP)
.put(SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE, TIMESTAMP_WITH_LOCAL_TZ)
.put(SqlTypeName.NULL, VARCHAR)
.put(SqlTypeName.INTERVAL_DAY_SECOND, FieldType.logicalType(new NanosDuration()))
.build();

// Since there are multiple Calcite type that correspond to a single Beam type, this is the
Expand Down Expand Up @@ -396,6 +400,55 @@ public static RelDataType sqlTypeWithAutoCast(RelDataTypeFactory typeFactory, Ty
+ ". This is currently unsupported, use List instead "
+ "of Array.");
}
if (type instanceof Class) {
Class<?> clazz = (Class<?>) type;
if (clazz == String.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.VARCHAR), true);
} else if (clazz == Integer.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.INTEGER), true);
} else if (clazz == int.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.INTEGER), false);
} else if (clazz == Long.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.BIGINT), true);
} else if (clazz == long.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.BIGINT), false);
} else if (clazz == Double.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.DOUBLE), true);
} else if (clazz == double.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.DOUBLE), false);
} else if (clazz == Float.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.FLOAT), true);
} else if (clazz == float.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.FLOAT), false);
} else if (clazz == Short.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.SMALLINT), true);
} else if (clazz == short.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.SMALLINT), false);
} else if (clazz == Byte.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.TINYINT), true);
} else if (clazz == byte.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.TINYINT), false);
} else if (clazz == Boolean.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.BOOLEAN), true);
} else if (clazz == boolean.class) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
}
}
return typeFactory.createJavaType((Class) type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.schema.AggregateFunction;
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.schema.FunctionParameter;
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.Rule;
Expand Down Expand Up @@ -77,7 +78,9 @@ public void subclassGetUdafImpl() {
LazyAggregateCombineFn<?, ?, ?> combiner = new LazyAggregateCombineFn<>(aggregateFn);
AggregateFunction aggregateFunction = combiner.getUdafImpl();
RelDataTypeFactory typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
RelDataType expectedType = typeFactory.createJavaType(Long.class);
RelDataType expectedType =
typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.BIGINT), true);

List<FunctionParameter> params = aggregateFunction.getParameters();
assertThat(params, hasSize(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,28 @@ public void testOrderBy_exception() {
compilePipeline(sql, pipeline);
}

@Test
public void testOrderBy_withoutLimit() {
String sql =
"INSERT INTO SUB_ORDER_RAM(order_id, site_id, price) SELECT "
+ " order_id, site_id, price "
+ "FROM ORDER_DETAILS "
+ "ORDER BY order_id asc, site_id desc";

PCollection<Row> rows = compilePipeline(sql, pipeline);
PAssert.that(rows)
.containsInAnyOrder(
TestUtils.RowsBuilder.of(
Schema.FieldType.INT64, "order_id",
Schema.FieldType.INT32, "site_id",
Schema.FieldType.DOUBLE, "price")
.addRows(
1L, 2, 1.0, 1L, 1, 2.0, 2L, 4, 3.0, 2L, 1, 4.0, 5L, 5, 5.0, 6L, 6, 6.0, 7L, 7,
7.0, 8L, 8888, 8.0, 8L, 999, 9.0, 10L, 100, 10.0)
.getRows());
pipeline.run().waitUntilFinish();
}

@Test
public void testNodeStatsEstimation() {
String sql =
Expand Down
Loading