diff --git a/be/src/storage/index/ann/ann_topn_runtime.cpp b/be/src/storage/index/ann/ann_topn_runtime.cpp index a4de1f3f7deac4..63586121fa885a 100644 --- a/be/src/storage/index/ann/ann_topn_runtime.cpp +++ b/be/src/storage/index/ann/ann_topn_runtime.cpp @@ -27,11 +27,10 @@ #include "common/status.h" #include "core/column/column.h" #include "core/column/column_array.h" +#include "core/column/column_const.h" #include "core/column/column_nullable.h" #include "core/data_type/primitive_type.h" #include "exprs/function/array/function_array_distance.h" -#include "exprs/varray_literal.h" -#include "exprs/vcast_expr.h" #include "exprs/vexpr_context.h" #include "exprs/vexpr_fwd.h" #include "exprs/virtual_slot_ref.h" @@ -49,19 +48,6 @@ Result extract_query_vector(std::shared_ptr arg_expr) { arg_expr->debug_string())); } - // Accept either ArrayLiteral([..]) or CAST('..' AS Nullable(Array(Nullable(Float32)))) - // First, check the expr node type for clarity. - - bool is_array_literal = std::dynamic_pointer_cast(arg_expr) != nullptr; - bool is_cast_expr = std::dynamic_pointer_cast(arg_expr) != nullptr; - if (!is_array_literal && !is_cast_expr) { - return ResultError( - Status::InvalidArgument("Constant must be ArrayLiteral or CAST to array, got\n{}", - arg_expr->debug_string())); - } - - // We'll validate shape by inspecting the materialized constant column below. - std::shared_ptr column_wrapper; auto st = arg_expr->get_const_col(nullptr, &column_wrapper); if (!st.ok()) { @@ -69,8 +55,11 @@ Result extract_query_vector(std::shared_ptr arg_expr) { st.to_string())); } - // Execute the constant array literal and extract its float elements into _query_array - IColumn::Ptr col_ptr = column_wrapper->column_ptr->convert_to_full_column_if_const(); + // Unwrap ColumnConst without copy to get the underlying single-row column + IColumn::Ptr col_ptr = column_wrapper->column_ptr; + if (const auto* const_col = check_and_get_column(*col_ptr)) { + col_ptr = const_col->get_data_column_ptr(); + } // The expected runtime column layout for the literal is: // Nullable(ColumnArray(Nullable(ColumnFloat32))) with exactly 1 row (one array literal) @@ -126,7 +115,7 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_des |---------------- | | | | - SlotRef CAST(String as Nullable) OR ArrayLiteral + SlotRef Constant Array Expression */ std::shared_ptr vir_slot_ref = std::dynamic_pointer_cast(_order_by_expr_ctx->root()); diff --git a/be/src/storage/index/ann/ann_topn_runtime.h b/be/src/storage/index/ann/ann_topn_runtime.h index 5beca98f599e53..95b13b9462fa09 100644 --- a/be/src/storage/index/ann/ann_topn_runtime.h +++ b/be/src/storage/index/ann/ann_topn_runtime.h @@ -37,8 +37,6 @@ #include "core/column/column.h" #include "core/data_type/primitive_type.h" -#include "exprs/varray_literal.h" -#include "exprs/vcast_expr.h" #include "exprs/vectorized_fn_call.h" #include "exprs/vexpr.h" #include "exprs/vexpr_context.h" diff --git a/be/test/storage/index/ann/extract_query_vector_test.cpp b/be/test/storage/index/ann/extract_query_vector_test.cpp new file mode 100644 index 00000000000000..13e059825b3a2b --- /dev/null +++ b/be/test/storage/index/ann/extract_query_vector_test.cpp @@ -0,0 +1,248 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include + +#include "core/column/column_array.h" +#include "core/column/column_const.h" +#include "core/column/column_nullable.h" +#include "core/column/column_vector.h" +#include "exprs/vexpr.h" +#include "storage/index/ann/ann_topn_runtime.h" + +namespace doris::segment_v2 { + +// A minimal mock VExpr that returns a pre-set constant column. +class MockConstVExpr : public VExpr { +public: + static TExprNode make_tnode() { + TExprNode node; + node.node_type = TExprNodeType::FLOAT_LITERAL; + node.type = TTypeDesc(); + TTypeNode type_node; + type_node.type = TTypeNodeType::SCALAR; + TScalarType scalar_type; + scalar_type.__set_type(TPrimitiveType::FLOAT); + type_node.__set_scalar_type(scalar_type); + node.type.types.push_back(type_node); + return node; + } + + MockConstVExpr() : VExpr(make_tnode()) {} + + Status get_const_col(VExprContext* /*context*/, + std::shared_ptr* output) override { + *output = std::make_shared(_col); + return Status::OK(); + } + + bool is_constant() const override { return _is_constant; } + + Status execute_column(VExprContext* context, const Block* block, const Selector* selector, + size_t count, ColumnPtr& result_column) const { + return Status::OK(); + } + + Status execute_column_impl(VExprContext* context, const Block* block, const Selector* selector, + size_t count, ColumnPtr& result_column) const override { + return Status::OK(); + } + + Status prepare(RuntimeState* state, const RowDescriptor& desc, VExprContext* context) override { + return Status::OK(); + } + + Status open(RuntimeState* state, VExprContext* context, + FunctionContext::FunctionStateScope scope) override { + return Status::OK(); + } + + void close(VExprContext* context, FunctionContext::FunctionStateScope scope) override {} + + const std::string& expr_name() const override { + static std::string name = "MockConstVExpr"; + return name; + } + + std::string debug_string() const override { return "MockConstVExpr"; } + + void set_column(IColumn::Ptr col) { _col = std::move(col); } + void set_is_constant(bool v) { _is_constant = v; } + +private: + IColumn::Ptr _col; + bool _is_constant = true; +}; + +// Helper: build Nullable(ColumnArray(Nullable(ColumnFloat32))) with 1 row of given floats +static IColumn::Ptr make_nullable_array_column(const std::vector& values) { + auto float_col = ColumnFloat32::create(); + for (float v : values) { + float_col->insert_value(v); + } + auto null_map_inner = ColumnUInt8::create(values.size(), 0); + auto nullable_inner = ColumnNullable::create(std::move(float_col), std::move(null_map_inner)); + + auto offsets = ColumnArray::ColumnOffsets::create(); + offsets->insert_value(static_cast(values.size())); + auto array_col = ColumnArray::create(std::move(nullable_inner), std::move(offsets)); + + auto null_map_outer = ColumnUInt8::create(1, 0); + return ColumnNullable::create(std::move(array_col), std::move(null_map_outer)); +} + +class ExtractQueryVectorTest : public ::testing::Test {}; + +// ColumnConst wrapping a Nullable(Array(Nullable(Float32))) — the array_repeat case +TEST_F(ExtractQueryVectorTest, ColumnConstWrappedArray) { + auto inner = make_nullable_array_column({1.0f, 2.0f, 3.0f}); + auto const_col = ColumnConst::create(std::move(inner), 1); + + auto mock = std::make_shared(); + mock->set_column(std::move(const_col)); + + auto result = extract_query_vector(mock); + ASSERT_TRUE(result.has_value()) << result.error().to_string(); + EXPECT_EQ(result.value()->size(), 3u); +} + +// Direct Nullable(Array(Nullable(Float32))) without ColumnConst wrapper +TEST_F(ExtractQueryVectorTest, DirectNullableArray) { + auto col = make_nullable_array_column({4.0f, 5.0f}); + + auto mock = std::make_shared(); + mock->set_column(std::move(col)); + + auto result = extract_query_vector(mock); + ASSERT_TRUE(result.has_value()) << result.error().to_string(); + EXPECT_EQ(result.value()->size(), 2u); +} + +// Non-nullable ColumnArray(Nullable(Float32)) directly +TEST_F(ExtractQueryVectorTest, NonNullableArray) { + auto float_col = ColumnFloat32::create(); + float_col->insert_value(1.0f); + float_col->insert_value(2.0f); + auto null_map = ColumnUInt8::create(2, 0); + auto nullable_inner = ColumnNullable::create(std::move(float_col), std::move(null_map)); + auto offsets = ColumnArray::ColumnOffsets::create(); + offsets->insert_value(2); + auto array_col = ColumnArray::create(std::move(nullable_inner), std::move(offsets)); + + auto mock = std::make_shared(); + mock->set_column(std::move(array_col)); + + auto result = extract_query_vector(mock); + ASSERT_TRUE(result.has_value()) << result.error().to_string(); + EXPECT_EQ(result.value()->size(), 2u); +} + +// ColumnConst wrapping non-nullable array (another possible shape) +TEST_F(ExtractQueryVectorTest, ColumnConstNonNullableArray) { + auto float_col = ColumnFloat32::create(); + float_col->insert_value(7.0f); + auto null_map = ColumnUInt8::create(1, 0); + auto nullable_inner = ColumnNullable::create(std::move(float_col), std::move(null_map)); + auto offsets = ColumnArray::ColumnOffsets::create(); + offsets->insert_value(1); + auto array_col = ColumnArray::create(std::move(nullable_inner), std::move(offsets)); + auto const_col = ColumnConst::create(std::move(array_col), 1); + + auto mock = std::make_shared(); + mock->set_column(std::move(const_col)); + + auto result = extract_query_vector(mock); + ASSERT_TRUE(result.has_value()) << result.error().to_string(); + EXPECT_EQ(result.value()->size(), 1u); +} + +// Verify extracted float values match input +TEST_F(ExtractQueryVectorTest, ValuesMatchInput) { + std::vector input = {1.5f, 2.5f, 3.5f, 4.5f}; + auto col = make_nullable_array_column(input); + auto const_col = ColumnConst::create(std::move(col), 1); + + auto mock = std::make_shared(); + mock->set_column(std::move(const_col)); + + auto result = extract_query_vector(mock); + ASSERT_TRUE(result.has_value()); + auto* float_col = assert_cast(result.value().get()); + ASSERT_EQ(float_col->size(), 4u); + for (size_t i = 0; i < input.size(); ++i) { + EXPECT_FLOAT_EQ(float_col->get_data()[i], input[i]); + } +} + +// Error: non-constant expression +TEST_F(ExtractQueryVectorTest, NonConstantExprFails) { + auto mock = std::make_shared(); + mock->set_is_constant(false); + + auto result = extract_query_vector(mock); + ASSERT_FALSE(result.has_value()); + EXPECT_TRUE(result.error().to_string().find("must be constant") != std::string::npos); +} + +// Error: NULL array +TEST_F(ExtractQueryVectorTest, NullArrayFails) { + auto float_col = ColumnFloat32::create(); + auto null_map_inner = ColumnUInt8::create(0, 0); + auto nullable_inner = ColumnNullable::create(std::move(float_col), std::move(null_map_inner)); + auto offsets = ColumnArray::ColumnOffsets::create(); + offsets->insert_value(0); + auto array_col = ColumnArray::create(std::move(nullable_inner), std::move(offsets)); + // Outer nullable with null flag set to 1 + auto null_map_outer = ColumnUInt8::create(1, 1); + auto nullable_outer = ColumnNullable::create(std::move(array_col), std::move(null_map_outer)); + + auto mock = std::make_shared(); + mock->set_column(std::move(nullable_outer)); + + auto result = extract_query_vector(mock); + ASSERT_FALSE(result.has_value()); + EXPECT_TRUE(result.error().to_string().find("cannot be NULL") != std::string::npos); +} + +// Error: empty array +TEST_F(ExtractQueryVectorTest, EmptyArrayFails) { + auto col = make_nullable_array_column({}); + + auto mock = std::make_shared(); + mock->set_column(std::move(col)); + + auto result = extract_query_vector(mock); + ASSERT_FALSE(result.has_value()); + EXPECT_TRUE(result.error().to_string().find("cannot be empty") != std::string::npos); +} + +// Error: not an array column at all +TEST_F(ExtractQueryVectorTest, NonArrayColumnFails) { + auto float_col = ColumnFloat32::create(); + float_col->insert_value(1.0f); + + auto mock = std::make_shared(); + mock->set_column(std::move(float_col)); + + auto result = extract_query_vector(mock); + ASSERT_FALSE(result.has_value()); + EXPECT_TRUE(result.error().to_string().find("Array literal") != std::string::npos); +} + +} // namespace doris::segment_v2 diff --git a/regression-test/data/ann_index_p0/ann_const_expr_vector.out b/regression-test/data/ann_index_p0/ann_const_expr_vector.out new file mode 100644 index 00000000000000..3aec3b406d2b15 --- /dev/null +++ b/regression-test/data/ann_index_p0/ann_const_expr_vector.out @@ -0,0 +1,21 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !array_repeat -- +1 0.0 +2 1.98 +3 9.98 + +-- !array_with_constant -- +1 0.0 +2 1.98 +3 9.98 + +-- !direct_literal -- +1 0.0 +2 1.98 +3 9.98 + +-- !cast_string -- +1 0.0 +2 1.98 +3 9.98 + diff --git a/regression-test/suites/ann_index_p0/ann_const_expr_vector.groovy b/regression-test/suites/ann_index_p0/ann_const_expr_vector.groovy new file mode 100644 index 00000000000000..380ac4cc40e052 --- /dev/null +++ b/regression-test/suites/ann_index_p0/ann_const_expr_vector.groovy @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Test that ANN queries work with various constant expression forms for the +// query vector, not just direct array literals or CAST expressions. +// Covers: array_repeat, array_with_constant, nested function calls, etc. +suite("ann_const_expr_vector") { + sql "unset variable all;" + sql "set enable_common_expr_pushdown=true;" + + def tableName = "ann_const_expr_tbl" + + sql "drop table if exists ${tableName}" + sql """ + CREATE TABLE ${tableName} ( + id INT NOT NULL, + embedding ARRAY NOT NULL, + INDEX idx_emb (`embedding`) USING ANN PROPERTIES( + "index_type"="hnsw", + "metric_type"="l2_distance", + "dim"="4" + ) + ) ENGINE=OLAP + DUPLICATE KEY(id) + DISTRIBUTED BY HASH(id) BUCKETS AUTO + PROPERTIES ("replication_num" = "1"); + """ + + sql """ + INSERT INTO ${tableName} VALUES + (1, [0.01, 0.01, 0.01, 0.01]), + (2, [1.0, 1.0, 1.0, 1.0]), + (3, [5.0, 5.0, 5.0, 5.0]), + (4, [10.0, 10.0, 10.0, 10.0]); + """ + + + // ========================================================================= + // Test 1: array_repeat - generates a constant array via function call + // This is the key test case: array_repeat produces a ColumnConst wrapping + // an array, which the old code rejected. + // ========================================================================= + qt_array_repeat """ + SELECT id, l2_distance_approximate(embedding, array_repeat(CAST(0.01 AS FLOAT), 4)) AS score + FROM ${tableName} + ORDER BY score ASC + LIMIT 3; + """ + + // ========================================================================= + // Test 2: array_with_constant - another way to produce constant arrays + // ========================================================================= + qt_array_with_constant """ + SELECT id, l2_distance_approximate(embedding, array_with_constant(4, CAST(0.01 AS FLOAT))) AS score + FROM ${tableName} + ORDER BY score ASC + LIMIT 3; + """ + + // ========================================================================= + // Test 3: Direct array literal (baseline, should always work) + // ========================================================================= + qt_direct_literal """ + SELECT id, l2_distance_approximate(embedding, [0.01, 0.01, 0.01, 0.01]) AS score + FROM ${tableName} + ORDER BY score ASC + LIMIT 3; + """ + + // ========================================================================= + // Test 4: CAST string to array (existing functionality) + // ========================================================================= + qt_cast_string """ + SELECT id, l2_distance_approximate(embedding, cast('[0.01,0.01,0.01,0.01]' as array)) AS score + FROM ${tableName} + ORDER BY score ASC + LIMIT 3; + """ + + // ========================================================================= + // Test 5: Error case - array_repeat with wrong dimension + // ========================================================================= + test { + sql """ + SELECT id FROM ${tableName} + ORDER BY l2_distance_approximate(embedding, array_repeat(CAST(0.01 AS FLOAT), 3)) + LIMIT 1; + """ + exception "[INVALID_ARGUMENT]" + } + + // ========================================================================= + // Test 6: Error case - empty array via array_repeat + // ========================================================================= + test { + sql """ + SELECT id FROM ${tableName} + ORDER BY l2_distance_approximate(embedding, array_repeat(CAST(0.01 AS FLOAT), 0)) + LIMIT 1; + """ + exception "Ann topn query vector cannot be empty" + } +} diff --git a/regression-test/suites/ann_index_p0/cast_string_as_array.groovy b/regression-test/suites/ann_index_p0/cast_string_as_array.groovy index 9d0ea331ef2ee7..01a1c4e5f0726a 100644 --- a/regression-test/suites/ann_index_p0/cast_string_as_array.groovy +++ b/regression-test/suites/ann_index_p0/cast_string_as_array.groovy @@ -67,7 +67,7 @@ suite("cast_string_as_array") { // runtime of ANN topn. So here we will get null directly... test { sql "select id from ann_cast_rhs_l2 order by l2_distance_approximate(embedding, cast(NULL as array)) limit 1;" - exception "Constant must be ArrayLiteral or CAST to array" + exception "Ann query vector cannot be NULL" }