Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 7 additions & 18 deletions be/src/storage/index/ann/ann_topn_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@
#include "common/status.h"
#include "core/column/column.h"
#include "core/column/column_array.h"
#include "core/column/column_const.h"
#include "core/column/column_nullable.h"
#include "core/data_type/primitive_type.h"
#include "exprs/function/array/function_array_distance.h"
#include "exprs/varray_literal.h"
#include "exprs/vcast_expr.h"
#include "exprs/vexpr_context.h"
#include "exprs/vexpr_fwd.h"
#include "exprs/virtual_slot_ref.h"
Expand All @@ -49,28 +48,18 @@ Result<IColumn::Ptr> extract_query_vector(std::shared_ptr<VExpr> arg_expr) {
arg_expr->debug_string()));
}

// Accept either ArrayLiteral([..]) or CAST('..' AS Nullable(Array(Nullable(Float32))))
// First, check the expr node type for clarity.

bool is_array_literal = std::dynamic_pointer_cast<VArrayLiteral>(arg_expr) != nullptr;
bool is_cast_expr = std::dynamic_pointer_cast<VCastExpr>(arg_expr) != nullptr;
if (!is_array_literal && !is_cast_expr) {
return ResultError(
Status::InvalidArgument("Constant must be ArrayLiteral or CAST to array, got\n{}",
arg_expr->debug_string()));
}

// We'll validate shape by inspecting the materialized constant column below.

std::shared_ptr<ColumnPtrWrapper> column_wrapper;
auto st = arg_expr->get_const_col(nullptr, &column_wrapper);
if (!st.ok()) {
return ResultError(Status::InvalidArgument("Failed to get constant column, error: {}",
st.to_string()));
}

// Execute the constant array literal and extract its float elements into _query_array
IColumn::Ptr col_ptr = column_wrapper->column_ptr->convert_to_full_column_if_const();
// Unwrap ColumnConst without copy to get the underlying single-row column
IColumn::Ptr col_ptr = column_wrapper->column_ptr;
if (const auto* const_col = check_and_get_column<ColumnConst>(*col_ptr)) {
col_ptr = const_col->get_data_column_ptr();
}

// The expected runtime column layout for the literal is:
// Nullable(ColumnArray(Nullable(ColumnFloat32))) with exactly 1 row (one array literal)
Expand Down Expand Up @@ -126,7 +115,7 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_des
|----------------
| |
| |
SlotRef CAST(String as Nullable<ArrayFloat>) OR ArrayLiteral
SlotRef Constant Array Expression
*/
std::shared_ptr<VirtualSlotRef> vir_slot_ref =
std::dynamic_pointer_cast<VirtualSlotRef>(_order_by_expr_ctx->root());
Expand Down
2 changes: 0 additions & 2 deletions be/src/storage/index/ann/ann_topn_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@

#include "core/column/column.h"
#include "core/data_type/primitive_type.h"
#include "exprs/varray_literal.h"
#include "exprs/vcast_expr.h"
#include "exprs/vectorized_fn_call.h"
#include "exprs/vexpr.h"
#include "exprs/vexpr_context.h"
Expand Down
248 changes: 248 additions & 0 deletions be/test/storage/index/ann/extract_query_vector_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <gtest/gtest.h>

#include <memory>

#include "core/column/column_array.h"
#include "core/column/column_const.h"
#include "core/column/column_nullable.h"
#include "core/column/column_vector.h"
#include "exprs/vexpr.h"
#include "storage/index/ann/ann_topn_runtime.h"

namespace doris::segment_v2 {

// A minimal mock VExpr that returns a pre-set constant column.
class MockConstVExpr : public VExpr {
public:
static TExprNode make_tnode() {
TExprNode node;
node.node_type = TExprNodeType::FLOAT_LITERAL;
node.type = TTypeDesc();
TTypeNode type_node;
type_node.type = TTypeNodeType::SCALAR;
TScalarType scalar_type;
scalar_type.__set_type(TPrimitiveType::FLOAT);
type_node.__set_scalar_type(scalar_type);
node.type.types.push_back(type_node);
return node;
}

MockConstVExpr() : VExpr(make_tnode()) {}

Status get_const_col(VExprContext* /*context*/,
std::shared_ptr<ColumnPtrWrapper>* output) override {
*output = std::make_shared<ColumnPtrWrapper>(_col);
return Status::OK();
}

bool is_constant() const override { return _is_constant; }

Status execute_column(VExprContext* context, const Block* block, const Selector* selector,
size_t count, ColumnPtr& result_column) const {
return Status::OK();
}

Status execute_column_impl(VExprContext* context, const Block* block, const Selector* selector,
size_t count, ColumnPtr& result_column) const override {
return Status::OK();
}

Status prepare(RuntimeState* state, const RowDescriptor& desc, VExprContext* context) override {
return Status::OK();
}

Status open(RuntimeState* state, VExprContext* context,
FunctionContext::FunctionStateScope scope) override {
return Status::OK();
}

void close(VExprContext* context, FunctionContext::FunctionStateScope scope) override {}

const std::string& expr_name() const override {
static std::string name = "MockConstVExpr";
return name;
}

std::string debug_string() const override { return "MockConstVExpr"; }

void set_column(IColumn::Ptr col) { _col = std::move(col); }
void set_is_constant(bool v) { _is_constant = v; }

private:
IColumn::Ptr _col;
bool _is_constant = true;
};

// Helper: build Nullable(ColumnArray(Nullable(ColumnFloat32))) with 1 row of given floats
static IColumn::Ptr make_nullable_array_column(const std::vector<float>& values) {
auto float_col = ColumnFloat32::create();
for (float v : values) {
float_col->insert_value(v);
}
auto null_map_inner = ColumnUInt8::create(values.size(), 0);
auto nullable_inner = ColumnNullable::create(std::move(float_col), std::move(null_map_inner));

auto offsets = ColumnArray::ColumnOffsets::create();
offsets->insert_value(static_cast<IColumn::Offset>(values.size()));
auto array_col = ColumnArray::create(std::move(nullable_inner), std::move(offsets));

auto null_map_outer = ColumnUInt8::create(1, 0);
return ColumnNullable::create(std::move(array_col), std::move(null_map_outer));
}

class ExtractQueryVectorTest : public ::testing::Test {};

// ColumnConst wrapping a Nullable(Array(Nullable(Float32))) — the array_repeat case
TEST_F(ExtractQueryVectorTest, ColumnConstWrappedArray) {
auto inner = make_nullable_array_column({1.0f, 2.0f, 3.0f});
auto const_col = ColumnConst::create(std::move(inner), 1);

auto mock = std::make_shared<MockConstVExpr>();
mock->set_column(std::move(const_col));

auto result = extract_query_vector(mock);
ASSERT_TRUE(result.has_value()) << result.error().to_string();
EXPECT_EQ(result.value()->size(), 3u);
}

// Direct Nullable(Array(Nullable(Float32))) without ColumnConst wrapper
TEST_F(ExtractQueryVectorTest, DirectNullableArray) {
auto col = make_nullable_array_column({4.0f, 5.0f});

auto mock = std::make_shared<MockConstVExpr>();
mock->set_column(std::move(col));

auto result = extract_query_vector(mock);
ASSERT_TRUE(result.has_value()) << result.error().to_string();
EXPECT_EQ(result.value()->size(), 2u);
}

// Non-nullable ColumnArray(Nullable(Float32)) directly
TEST_F(ExtractQueryVectorTest, NonNullableArray) {
auto float_col = ColumnFloat32::create();
float_col->insert_value(1.0f);
float_col->insert_value(2.0f);
auto null_map = ColumnUInt8::create(2, 0);
auto nullable_inner = ColumnNullable::create(std::move(float_col), std::move(null_map));
auto offsets = ColumnArray::ColumnOffsets::create();
offsets->insert_value(2);
auto array_col = ColumnArray::create(std::move(nullable_inner), std::move(offsets));

auto mock = std::make_shared<MockConstVExpr>();
mock->set_column(std::move(array_col));

auto result = extract_query_vector(mock);
ASSERT_TRUE(result.has_value()) << result.error().to_string();
EXPECT_EQ(result.value()->size(), 2u);
}

// ColumnConst wrapping non-nullable array (another possible shape)
TEST_F(ExtractQueryVectorTest, ColumnConstNonNullableArray) {
auto float_col = ColumnFloat32::create();
float_col->insert_value(7.0f);
auto null_map = ColumnUInt8::create(1, 0);
auto nullable_inner = ColumnNullable::create(std::move(float_col), std::move(null_map));
auto offsets = ColumnArray::ColumnOffsets::create();
offsets->insert_value(1);
auto array_col = ColumnArray::create(std::move(nullable_inner), std::move(offsets));
auto const_col = ColumnConst::create(std::move(array_col), 1);

auto mock = std::make_shared<MockConstVExpr>();
mock->set_column(std::move(const_col));

auto result = extract_query_vector(mock);
ASSERT_TRUE(result.has_value()) << result.error().to_string();
EXPECT_EQ(result.value()->size(), 1u);
}

// Verify extracted float values match input
TEST_F(ExtractQueryVectorTest, ValuesMatchInput) {
std::vector<float> input = {1.5f, 2.5f, 3.5f, 4.5f};
auto col = make_nullable_array_column(input);
auto const_col = ColumnConst::create(std::move(col), 1);

auto mock = std::make_shared<MockConstVExpr>();
mock->set_column(std::move(const_col));

auto result = extract_query_vector(mock);
ASSERT_TRUE(result.has_value());
auto* float_col = assert_cast<const ColumnFloat32*>(result.value().get());
ASSERT_EQ(float_col->size(), 4u);
for (size_t i = 0; i < input.size(); ++i) {
EXPECT_FLOAT_EQ(float_col->get_data()[i], input[i]);
}
}

// Error: non-constant expression
TEST_F(ExtractQueryVectorTest, NonConstantExprFails) {
auto mock = std::make_shared<MockConstVExpr>();
mock->set_is_constant(false);

auto result = extract_query_vector(mock);
ASSERT_FALSE(result.has_value());
EXPECT_TRUE(result.error().to_string().find("must be constant") != std::string::npos);
}

// Error: NULL array
TEST_F(ExtractQueryVectorTest, NullArrayFails) {
auto float_col = ColumnFloat32::create();
auto null_map_inner = ColumnUInt8::create(0, 0);
auto nullable_inner = ColumnNullable::create(std::move(float_col), std::move(null_map_inner));
auto offsets = ColumnArray::ColumnOffsets::create();
offsets->insert_value(0);
auto array_col = ColumnArray::create(std::move(nullable_inner), std::move(offsets));
// Outer nullable with null flag set to 1
auto null_map_outer = ColumnUInt8::create(1, 1);
auto nullable_outer = ColumnNullable::create(std::move(array_col), std::move(null_map_outer));

auto mock = std::make_shared<MockConstVExpr>();
mock->set_column(std::move(nullable_outer));

auto result = extract_query_vector(mock);
ASSERT_FALSE(result.has_value());
EXPECT_TRUE(result.error().to_string().find("cannot be NULL") != std::string::npos);
}

// Error: empty array
TEST_F(ExtractQueryVectorTest, EmptyArrayFails) {
auto col = make_nullable_array_column({});

auto mock = std::make_shared<MockConstVExpr>();
mock->set_column(std::move(col));

auto result = extract_query_vector(mock);
ASSERT_FALSE(result.has_value());
EXPECT_TRUE(result.error().to_string().find("cannot be empty") != std::string::npos);
}

// Error: not an array column at all
TEST_F(ExtractQueryVectorTest, NonArrayColumnFails) {
auto float_col = ColumnFloat32::create();
float_col->insert_value(1.0f);

auto mock = std::make_shared<MockConstVExpr>();
mock->set_column(std::move(float_col));

auto result = extract_query_vector(mock);
ASSERT_FALSE(result.has_value());
EXPECT_TRUE(result.error().to_string().find("Array literal") != std::string::npos);
}

} // namespace doris::segment_v2
21 changes: 21 additions & 0 deletions regression-test/data/ann_index_p0/ann_const_expr_vector.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !array_repeat --
1 0.0
2 1.98
3 9.98

-- !array_with_constant --
1 0.0
2 1.98
3 9.98

-- !direct_literal --
1 0.0
2 1.98
3 9.98

-- !cast_string --
1 0.0
2 1.98
3 9.98

Loading
Loading