Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,14 @@ public class BeamRuleSets {
ImmutableList.of(BeamEnumerableConverterRule.INSTANCE);

/**
 * Returns the rule sets handed to the planner: a single {@link RuleSet} built from every rule
 * returned by {@link #getAllRules()}.
 */
public static Collection<RuleSet> getRuleSets() {
  RuleSet combined = RuleSets.ofList(getAllRules());
  return ImmutableList.of(combined);
}

return ImmutableList.of(
RuleSets.ofList(
ImmutableList.<RelOptRule>builder()
.addAll(BEAM_CONVERTERS)
.addAll(BEAM_TO_ENUMERABLE)
.addAll(LOGICAL_OPTIMIZATIONS)
.build()));
/**
 * Returns every rule as one flat, immutable list, concatenating {@code BEAM_CONVERTERS}, {@code
 * BEAM_TO_ENUMERABLE} and {@code LOGICAL_OPTIMIZATIONS} in that order.
 */
public static List<RelOptRule> getAllRules() {
  ImmutableList.Builder<RelOptRule> rules = ImmutableList.builder();
  rules.addAll(BEAM_CONVERTERS);
  rules.addAll(BEAM_TO_ENUMERABLE);
  rules.addAll(LOGICAL_OPTIMIZATIONS);
  return rules.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ public NodeStats getNodeStats(RelNode rel, RelMetadataQuery mq) {
return this.getBeamNodeStats((BeamRelNode) rel, bmq);
}

if (rel instanceof org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.core.Values) {
org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.core.Values values =
(org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.core.Values) rel;
return NodeStats.create(values.getTuples().size());
}

// We can later define custom methods for all different RelNodes to prevent hitting this point.
// Similar to RelMdRowCount in calcite.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,18 @@ public CalcFn(
List<String> jarPaths,
FieldAccessDescriptor fieldAccess,
boolean collectErrors) {
this.processElementBlock = processElementBlock;
this.processElementBlock =
processElementBlock.replace(
"(byte[]) org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.runtime.SqlFunctions.concat",
"org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.runtime.SqlFunctions.concat");
this.outputSchema = outputSchema;
this.verifyRowValues = verifyRowValues;
this.jarPaths = jarPaths;
this.fieldAccess = fieldAccess;
this.collectErrors = collectErrors;

// validate generated code
compile(processElementBlock, jarPaths);
compile(this.processElementBlock, jarPaths);
}

private static ScriptEvaluator compile(String processElementBlock, List<String> jarPaths) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,28 @@ static List<Row> toRowList(PipelineOptions options, BeamRelNode node) {
if (node instanceof BeamIOSinkRel) {
throw new UnsupportedOperationException("Does not support BeamIOSinkRel in toRowList.");
} else if (isLimitQuery(node)) {
throw new UnsupportedOperationException("Does not support queries with LIMIT in toRowList.");
return limitRowList(options, node);
}
return collectRows(options, node).stream().collect(Collectors.toList());
}

/**
 * Runs a LIMIT query and returns the collected rows, trimmed to at most the query's limit.
 *
 * <p>The result queue is registered in {@code Collector.globalValues} under the options id for
 * the duration of the run, and is always unregistered afterwards, even if the run throws.
 *
 * @param options pipeline options; their options id keys the shared result queue
 * @param node the root relational node of a query for which {@code isLimitQuery} holds
 * @return the rows gathered by the run, no more than the limit count
 */
private static List<Row> limitRowList(PipelineOptions options, BeamRelNode node) {
  long id = options.getOptionsId();
  ConcurrentLinkedQueue<Row> values = new ConcurrentLinkedQueue<>();
  int limitCount = getLimitCount(node);

  Collector.globalValues.put(id, values);
  try {
    limitRun(options, node, new Collector(), values, limitCount);
  } finally {
    // Unregister unconditionally so a failed run does not leave the queue in the static map.
    Collector.globalValues.remove(id);
  }

  // remove extra retrieved values
  while (values.size() > limitCount) {
    values.remove();
  }

  return values.stream().collect(Collectors.toList());
}

static Enumerable<Object> toEnumerable(PipelineOptions options, BeamRelNode node) {
if (node instanceof BeamIOSinkRel) {
return count(options, node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
*/
package org.apache.beam.sdk.extensions.sql.impl.rel;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;

import java.io.Serializable;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamSetOperatorsTransforms;
import org.apache.beam.sdk.schemas.transforms.CoGroup;
Expand Down Expand Up @@ -59,14 +57,27 @@ public BeamSetOperatorRelBase(BeamRelNode beamRelNode, OpType opType, boolean al

/**
 * Applies this set operator across all inputs, folding left to right: the first two inputs are
 * combined with {@code expandPair}, and each further input is combined with the running result.
 *
 * @param inputs the input collections; at least two are required
 * @return the collection produced by combining every input
 * @throws IllegalArgumentException if fewer than two inputs are supplied
 */
@Override
public PCollection<Row> expand(PCollectionList<Row> inputs) {
  if (inputs.size() < 2) {
    throw new IllegalArgumentException(
        "Too few arguments to " + beamRelNode.getClass().getSimpleName());
  }

  // UNION ALL intentionally does not short-circuit to Flatten: that optimization was reverted
  // because it fails when inputs have slightly different schemas (e.g. NULL vs VARCHAR).
  PCollection<Row> result = inputs.get(0);
  for (int i = 1; i < inputs.size(); i++) {
    result = expandPair(result, inputs.get(i));
  }
  return result;
}

private PCollection<Row> expandPair(PCollection<Row> leftRows, PCollection<Row> rightRows) {
WindowFn leftWindow = leftRows.getWindowingStrategy().getWindowFn();
WindowFn rightWindow = rightRows.getWindowingStrategy().getWindowFn();
if (!leftWindow.isCompatible(rightWindow)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.Row;
Expand Down Expand Up @@ -84,15 +86,41 @@ public PCollection<Row> expand(PCollectionList<Row> pinput) {
BeamValuesRel.class.getSimpleName(),
pinput);

Schema schema = CalciteUtils.toSchema(getRowType());
Schema inferredSchema = CalciteUtils.toSchema(getRowType());
Schema.Builder schemaBuilder = Schema.builder();
for (int i = 0; i < inferredSchema.getFieldCount(); i++) {
Schema.Field field = inferredSchema.getField(i);
boolean hasNull = false;
for (ImmutableList<RexLiteral> tuple : tuples) {
if (tuple.get(i).getValue() == null) {
hasNull = true;
break;
}
}
if (hasNull && !field.getType().getNullable()) {
schemaBuilder.addField(field.getName(), field.getType().withNullable(true));
} else {
schemaBuilder.addField(field);
}
}
Schema schema = schemaBuilder.build();
List<Row> rows = tuples.stream().map(tuple -> tupleToRow(schema, tuple)).collect(toList());
return pinput.getPipeline().begin().apply(Create.of(rows).withRowSchema(schema));
return pinput
.getPipeline()
.begin()
.apply(Impulse.create())
.apply(ParDo.of(new EmitRowsFn(rows)))
.setRowSchema(schema);
}
}

/**
 * Converts one VALUES tuple into a {@link Row} with the given schema.
 *
 * <p>Each literal's value is passed through {@code autoCastField} together with the schema field
 * at the same position, and the results are collected in positional order.
 *
 * @param schema the schema of the row to build; field {@code i} describes tuple element {@code i}
 * @param tuple the literals of a single VALUES row
 * @return the row built from the tuple's values
 */
private Row tupleToRow(Schema schema, ImmutableList<RexLiteral> tuple) {
  return IntStream.range(0, tuple.size())
      .mapToObj(i -> autoCastField(schema.getField(i), tuple.get(i).getValue()))
      .collect(toRow(schema));
}

Expand All @@ -106,4 +134,19 @@ public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQ
NodeStats estimates = BeamSqlRelUtils.getNodeStats(this, mq);
return BeamCostModel.FACTORY.makeCost(estimates.getRowCount(), estimates.getRate());
}

/**
 * A {@link DoFn} that ignores the content of its {@code byte[]} input and, for every element it
 * processes, outputs each row of the list it was constructed with, in list order.
 */
private static class EmitRowsFn extends DoFn<byte[], Row> {
  private final List<Row> rows;

  /** Stores the given list of rows (by reference) for later emission. */
  public EmitRowsFn(List<Row> rows) {
    this.rows = rows;
  }

  /** Outputs every stored row once for the element being processed. */
  @ProcessElement
  public void processElement(ProcessContext context) {
    rows.forEach(row -> context.output(row));
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ public class BeamBuiltinAggregations {
BUILTIN_AGGREGATOR_FACTORIES =
ImmutableMap.<String, Function<Schema.FieldType, CombineFn<?, ?, ?>>>builder()
.put("ANY_VALUE", typeName -> Sample.anyValueCombineFn())
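// SINGLE_VALUE is emitted by Calcite to enforce the cardinality of a scalar
// subquery (a subquery used as a scalar must yield exactly one row). The single
// input value is returned as-is; unlike COUNT/SUM it must not drop nulls, so a
// scalar subquery evaluating to NULL surfaces NULL.
// NOTE(review): this mapping reuses the same combiner as ANY_VALUE, which does not
// itself appear to reject multi-row input -- confirm the "exactly one row" check is
// enforced elsewhere or is intentionally relaxed here.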
.put("SINGLE_VALUE", typeName -> Sample.anyValueCombineFn())
// Drop null elements for these aggregations.
.put("COUNT", typeName -> new DropNullFnWithDefault(Count.combineFn()))
.put("MAX", typeName -> new DropNullFn(BeamBuiltinAggregations.createMax(typeName)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1133,4 +1133,36 @@ public void testCountIfFunction() throws Exception {
PAssert.that(result).containsInAnyOrder(rowResult);
pipeline.run().waitUntilFinish();
}

/**
 * Verifies that a correlated scalar subquery in the SELECT list returns, for each outer row, the
 * matching value from the subquery side.
 */
@Test
public void testCorrelatedScalarSubqueryWithSingleValue() {
  Schema schema = Schema.builder().addInt32Field("id").addStringField("name").build();
  PCollection<Row> input =
      pipeline.apply(
          "input",
          Create.of(
                  Row.withSchema(schema).addValues(1, "a").build(),
                  Row.withSchema(schema).addValues(2, "b").build())
              .withRowSchema(schema));

  // Correlated subquery. Calcite should decorrelate this into a LEFT JOIN
  // and use SINGLE_VALUE aggregation on the subquery side.
  String sql =
      "SELECT t1.id, (SELECT t2.name FROM PCOLLECTION t2 WHERE t2.id = t1.id) FROM PCOLLECTION t1";

  PCollection<Row> result = input.apply("sql_subquery", SqlTransform.query(sql));

  // The subquery column is nullable because a scalar subquery may match no row.
  Schema expectedSchema =
      Schema.builder()
          .addInt32Field("id")
          .addNullableField("name", Schema.FieldType.STRING)
          .build();

  PAssert.that(result)
      .containsInAnyOrder(
          Row.withSchema(expectedSchema).addValues(1, "a").build(),
          Row.withSchema(expectedSchema).addValues(2, "b").build());

  // Wait for completion, matching the other tests in this class.
  pipeline.run().waitUntilFinish();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@

import java.math.BigDecimal;
import java.util.List;
import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.extensions.sql.meta.SchemaBaseBeamTable;
import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestBoundedTable;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.schemas.Schema;
Expand Down Expand Up @@ -125,6 +127,21 @@ public void testToListRow_collectMultiple() {
assertEquals(Row.withSchema(schema).addValues(0L, 1L).build(), rowList.get(0));
}

/**
 * Verifies that {@code toRowList} supports LIMIT queries: a {@code LIMIT 2} over a three-row
 * table yields exactly two distinct rows taken from that table.
 */
@Test
public void testToRowList_limit() {
  Schema schema = Schema.builder().addInt64Field("id").build();
  java.util.Map<String, org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable> tables =
      new java.util.HashMap<>();
  tables.put("TEST", TestBoundedTable.of(Schema.FieldType.INT64, "id").addRows(1L, 2L, 3L));
  BeamSqlEnv env = BeamSqlEnv.readOnly("test", tables);
  BeamRelNode node = env.parseQuery("SELECT id FROM TEST LIMIT 2");

  List<Row> rowList = BeamEnumerableConverter.toRowList(options, node);
  assertEquals(2, rowList.size());

  // LIMIT without ORDER BY does not fix which rows are returned, so accept any two distinct
  // rows of the table rather than requiring ids 1 and 2 specifically.
  java.util.Set<Row> tableRows = new java.util.HashSet<>();
  for (long id = 1L; id <= 3L; id++) {
    tableRows.add(Row.withSchema(schema).addValue(id).build());
  }
  assertTrue(tableRows.containsAll(rowList));
  assertEquals(2, new java.util.HashSet<>(rowList).size());
}

private static class FakeTable extends SchemaBaseBeamTable {
public FakeTable() {
super(null);
Expand Down
Loading