diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java index 9fc5fcb436c89..5075212bf4171 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java @@ -49,6 +49,8 @@ import org.apache.iotdb.commons.auth.entity.PrivilegeType; import org.apache.iotdb.commons.client.request.AsyncRequestContext; import org.apache.iotdb.commons.cluster.NodeStatus; +import org.apache.iotdb.commons.concurrent.Await; +import org.apache.iotdb.commons.concurrent.AwaitTimeoutException; import org.apache.iotdb.commons.concurrent.IoTThreadFactory; import org.apache.iotdb.commons.concurrent.ThreadName; import org.apache.iotdb.commons.concurrent.threadpool.WrappedThreadPoolExecutor; @@ -193,6 +195,7 @@ import org.apache.iotdb.db.schemaengine.template.TemplateInternalRPCUpdateType; import org.apache.iotdb.db.schemaengine.template.TemplateInternalRPCUtil; import org.apache.iotdb.db.service.DataNode; +import org.apache.iotdb.db.service.DataNode.DataNodeContext; import org.apache.iotdb.db.service.RegionMigrateService; import org.apache.iotdb.db.service.externalservice.ExternalServiceManagementService; import org.apache.iotdb.db.service.metrics.FileMetrics; @@ -416,6 +419,8 @@ public class DataNodeInternalRPCServiceImpl implements IDataNodeRPCService.Iface private final CommonConfig commonConfig = CommonDescriptor.getInstance().getConfig(); + private final DataNodeContext dataNodeContext; + private final ExecutorService schemaExecutor = new WrappedThreadPoolExecutor( 0, @@ -430,10 +435,33 @@ public class DataNodeInternalRPCServiceImpl implements IDataNodeRPCService.Iface private static final String SYSTEM = "system"; - public 
DataNodeInternalRPCServiceImpl() { + public DataNodeInternalRPCServiceImpl(DataNodeContext dataNodeContext) { super(); partitionFetcher = ClusterPartitionFetcher.getInstance(); schemaFetcher = ClusterSchemaFetcher.getInstance(); + this.dataNodeContext = dataNodeContext; + } + + private long consensusWaitTimeoutSeconds = 30; + + private TSStatus waitForConsensusStarted() { + if (dataNodeContext.isAllConsensusStarted()) { + return null; + } + try { + Await.await() + .atMost(consensusWaitTimeoutSeconds, TimeUnit.SECONDS) + .pollInterval(100, TimeUnit.MILLISECONDS) + .until(dataNodeContext::isAllConsensusStarted); + return null; + } catch (AwaitTimeoutException e) { + LOGGER.warn( + "Consensus has not been started after {} seconds, rejecting region request", + consensusWaitTimeoutSeconds); + return RpcUtils.getStatus( + TSStatusCode.CONSENSUS_NOT_INITIALIZED, + "Consensus has not been started after " + consensusWaitTimeoutSeconds + " seconds"); + } } @Override @@ -624,11 +652,19 @@ private TLoadResp createTLoadResp(final TSStatus resultStatus) { @Override public TSStatus createSchemaRegion(final TCreateSchemaRegionReq req) { + TSStatus consensusStatus = waitForConsensusStarted(); + if (consensusStatus != null) { + return consensusStatus; + } return regionManager.createSchemaRegion(req.getRegionReplicaSet(), req.getStorageGroup()); } @Override public TSStatus createDataRegion(TCreateDataRegionReq req) { + TSStatus consensusStatus = waitForConsensusStarted(); + if (consensusStatus != null) { + return consensusStatus; + } return regionManager.createDataRegion(req.getRegionReplicaSet(), req.getStorageGroup()); } @@ -2616,6 +2652,10 @@ public TSStatus updateTemplate(final TUpdateTemplateReq req) { @Override public TSStatus deleteRegion(TConsensusGroupId tconsensusGroupId) { + TSStatus consensusStatus = waitForConsensusStarted(); + if (consensusStatus != null) { + return consensusStatus; + } ConsensusGroupId consensusGroupId = 
ConsensusGroupId.Factory.createFromTConsensusGroupId(tconsensusGroupId); if (consensusGroupId instanceof DataRegionId) { @@ -2644,6 +2684,12 @@ public TRegionLeaderChangeResp changeRegionLeader(TRegionLeaderChangeReq req) { LOGGER.info("[ChangeRegionLeader] {}", req); TRegionLeaderChangeResp resp = new TRegionLeaderChangeResp(); + TSStatus consensusStatus = waitForConsensusStarted(); + if (consensusStatus != null) { + resp.setStatus(consensusStatus); + return resp; + } + TSStatus successStatus = new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); TConsensusGroupId tgId = req.getRegionId(); ConsensusGroupId regionId = ConsensusGroupId.Factory.createFromTConsensusGroupId(tgId); @@ -2713,6 +2759,10 @@ private boolean isLeader(ConsensusGroupId regionId) { @Override public TSStatus createNewRegionPeer(TCreatePeerReq req) { + TSStatus consensusStatus = waitForConsensusStarted(); + if (consensusStatus != null) { + return consensusStatus; + } ConsensusGroupId regionId = ConsensusGroupId.Factory.createFromTConsensusGroupId(req.getRegionId()); List peers = @@ -2733,6 +2783,10 @@ public TSStatus createNewRegionPeer(TCreatePeerReq req) { @Override public TSStatus addRegionPeer(TMaintainPeerReq req) { + TSStatus consensusStatus = waitForConsensusStarted(); + if (consensusStatus != null) { + return consensusStatus; + } TConsensusGroupId regionId = req.getRegionId(); String selectedDataNodeIP = req.getDestNode().getInternalEndPoint().getIp(); boolean submitSucceed = RegionMigrateService.getInstance().submitAddRegionPeerTask(req); @@ -2751,6 +2805,10 @@ public TSStatus addRegionPeer(TMaintainPeerReq req) { @Override public TSStatus removeRegionPeer(TMaintainPeerReq req) { + TSStatus consensusStatus = waitForConsensusStarted(); + if (consensusStatus != null) { + return consensusStatus; + } TConsensusGroupId regionId = req.getRegionId(); String selectedDataNodeIP = req.getDestNode().getInternalEndPoint().getIp(); boolean submitSucceed = 
RegionMigrateService.getInstance().submitRemoveRegionPeerTask(req); @@ -2769,6 +2827,10 @@ public TSStatus removeRegionPeer(TMaintainPeerReq req) { @Override public TSStatus deleteOldRegionPeer(TMaintainPeerReq req) { + TSStatus consensusStatus = waitForConsensusStarted(); + if (consensusStatus != null) { + return consensusStatus; + } TConsensusGroupId regionId = req.getRegionId(); String selectedDataNodeIP = req.getDestNode().getInternalEndPoint().getIp(); boolean submitSucceed = RegionMigrateService.getInstance().submitDeleteOldRegionPeerTask(req); @@ -2788,6 +2850,10 @@ public TSStatus deleteOldRegionPeer(TMaintainPeerReq req) { // TODO: return which DataNode fail @Override public TSStatus resetPeerList(TResetPeerListReq req) throws TException { + TSStatus consensusStatus = waitForConsensusStarted(); + if (consensusStatus != null) { + return consensusStatus; + } return RegionMigrateService.getInstance().resetPeerList(req); } @@ -2798,6 +2864,10 @@ public TRegionMigrateResult getRegionMaintainResult(long taskId) throws TExcepti @Override public TSStatus notifyRegionMigration(TNotifyRegionMigrationReq req) throws TException { + TSStatus consensusStatus = waitForConsensusStarted(); + if (consensusStatus != null) { + return consensusStatus; + } RegionMigrateService.getInstance().notifyRegionMigration(req); return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); } @@ -3439,4 +3509,8 @@ private List serializeDatabaseScopedTableList( return result; } + + public void setConsensusWaitTimeoutSeconds(long consensusWaitTimeoutSeconds) { + this.consensusWaitTimeoutSeconds = consensusWaitTimeoutSeconds; + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java index 7004d09872db5..677d677311f32 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java @@ -183,14 +183,16 @@ public class DataNode extends ServerCommandLine implements DataNodeMBean { private static final String REGISTER_INTERRUPTION = "Unexpected interruption when waiting to register to the cluster"; - private boolean schemaRegionConsensusStarted = false; - private boolean dataRegionConsensusStarted = false; + private volatile boolean schemaRegionConsensusStarted = false; + private volatile boolean dataRegionConsensusStarted = false; private static Thread watcherThread; + private DataNodeContext context; public DataNode() { super("DataNode"); // We do not init anything here, so that we can re-initialize the instance in IT. DataNodeHolder.INSTANCE = this; + context = new DataNodeContext(); } public static void reinitializeStatics() { @@ -934,7 +936,9 @@ private void setUpRPCService() throws StartupException { protected void registerInternalRPCService() throws StartupException { // Start InternalRPCService to indicate that the current DataNode can accept cluster scheduling - registerManager.register(DataNodeInternalRPCService.getInstance()); + DataNodeInternalRPCService instance = DataNodeInternalRPCService.getInstance(); + instance.setDataNodeContext(context); + registerManager.register(instance); } // make it easier for users to extend ClientRPCServiceImpl to export more rpc services @@ -1373,4 +1377,10 @@ private DataNodeHolder() { // Empty constructor } } + + public class DataNodeContext { + public boolean isAllConsensusStarted() { + return dataRegionConsensusStarted && schemaRegionConsensusStarted; + } + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNodeInternalRPCService.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNodeInternalRPCService.java index 5de1041a9a007..f3bf8e507c284 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNodeInternalRPCService.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNodeInternalRPCService.java @@ -31,6 +31,7 @@ import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.db.protocol.thrift.handler.InternalServiceThriftHandler; import org.apache.iotdb.db.protocol.thrift.impl.DataNodeInternalRPCServiceImpl; +import org.apache.iotdb.db.service.DataNode.DataNodeContext; import org.apache.iotdb.db.service.metrics.DataNodeInternalRPCServiceMetrics; import org.apache.iotdb.mpp.rpc.thrift.IDataNodeRPCService.Processor; import org.apache.iotdb.rpc.DeepCopyRpcTransportFactory; @@ -44,6 +45,7 @@ public class DataNodeInternalRPCService extends ThriftService private static final CommonConfig commonConfig = CommonDescriptor.getInstance().getConfig(); private final AtomicReference impl = new AtomicReference<>(); + private DataNodeContext dataNodeContext; private DataNodeInternalRPCService() {} @@ -54,9 +56,9 @@ public ServiceType getID() { @Override public void initTProcessor() { - impl.compareAndSet(null, new DataNodeInternalRPCServiceImpl()); + DataNodeInternalRPCServiceImpl service = getImpl(); initSyncedServiceImpl(null); - processor = new Processor<>(impl.get()); + processor = new Processor<>(service); } @Override @@ -109,7 +111,7 @@ public int getBindPort() { } public DataNodeInternalRPCServiceImpl getImpl() { - impl.compareAndSet(null, new DataNodeInternalRPCServiceImpl()); + impl.compareAndSet(null, new DataNodeInternalRPCServiceImpl(dataNodeContext)); return impl.get(); } @@ -122,4 +124,8 @@ private DataNodeInternalRPCServiceHolder() {} public static DataNodeInternalRPCService getInstance() { return DataNodeInternalRPCServiceHolder.INSTANCE; } + + public void setDataNodeContext(DataNodeContext dataNodeContext) { + this.dataNodeContext = dataNodeContext; + } } diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/protocol/thrift/impl/ConsensusWaitTest.java 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/protocol/thrift/impl/ConsensusWaitTest.java new file mode 100644 index 0000000000000..65fd2b4785896 --- /dev/null +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/protocol/thrift/impl/ConsensusWaitTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.protocol.thrift.impl; + +import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId; +import org.apache.iotdb.common.rpc.thrift.TConsensusGroupType; +import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet; +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.service.DataNode.DataNodeContext; +import org.apache.iotdb.mpp.rpc.thrift.TCreateDataRegionReq; +import org.apache.iotdb.mpp.rpc.thrift.TCreateSchemaRegionReq; +import org.apache.iotdb.mpp.rpc.thrift.TRegionLeaderChangeReq; +import org.apache.iotdb.mpp.rpc.thrift.TRegionLeaderChangeResp; +import org.apache.iotdb.rpc.TSStatusCode; + +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.mockito.Mockito; + +import java.util.Collections; + +import static org.mockito.Mockito.when; + +public class ConsensusWaitTest { + + @BeforeClass + public static void setUp() { + IoTDBDescriptor.getInstance().getConfig().setDataNodeId(0); + } + + private DataNodeInternalRPCServiceImpl createServiceWithConsensusState(boolean started) { + DataNodeContext context = Mockito.mock(DataNodeContext.class); + when(context.isAllConsensusStarted()).thenReturn(started); + DataNodeInternalRPCServiceImpl service = new DataNodeInternalRPCServiceImpl(context); + service.setConsensusWaitTimeoutSeconds(1); + return service; + } + + private TCreateSchemaRegionReq createSchemaRegionReq() { + TCreateSchemaRegionReq req = new TCreateSchemaRegionReq(); + req.setStorageGroup("root.test"); + TRegionReplicaSet replicaSet = new TRegionReplicaSet(); + replicaSet.setRegionId(new TConsensusGroupId(TConsensusGroupType.SchemaRegion, 0)); + TDataNodeLocation location = new TDataNodeLocation(); + location.setDataNodeId(0); + location.setClientRpcEndPoint(new TEndPoint("0.0.0.0", 6667)); + 
location.setInternalEndPoint(new TEndPoint("0.0.0.0", 10730)); + location.setMPPDataExchangeEndPoint(new TEndPoint("0.0.0.0", 10740)); + location.setDataRegionConsensusEndPoint(new TEndPoint("0.0.0.0", 10760)); + location.setSchemaRegionConsensusEndPoint(new TEndPoint("0.0.0.0", 10750)); + replicaSet.setDataNodeLocations(Collections.singletonList(location)); + req.setRegionReplicaSet(replicaSet); + return req; + } + + private TCreateDataRegionReq createDataRegionReq() { + TCreateDataRegionReq req = new TCreateDataRegionReq(); + req.setStorageGroup("root.test"); + TRegionReplicaSet replicaSet = new TRegionReplicaSet(); + replicaSet.setRegionId(new TConsensusGroupId(TConsensusGroupType.DataRegion, 0)); + TDataNodeLocation location = new TDataNodeLocation(); + location.setDataNodeId(0); + location.setClientRpcEndPoint(new TEndPoint("0.0.0.0", 6667)); + location.setInternalEndPoint(new TEndPoint("0.0.0.0", 10730)); + location.setMPPDataExchangeEndPoint(new TEndPoint("0.0.0.0", 10740)); + location.setDataRegionConsensusEndPoint(new TEndPoint("0.0.0.0", 10760)); + location.setSchemaRegionConsensusEndPoint(new TEndPoint("0.0.0.0", 10750)); + replicaSet.setDataNodeLocations(Collections.singletonList(location)); + req.setRegionReplicaSet(replicaSet); + return req; + } + + @Test + public void testCreateSchemaRegionRejectsWhenConsensusNotStarted() { + DataNodeInternalRPCServiceImpl service = createServiceWithConsensusState(false); + TSStatus status = service.createSchemaRegion(createSchemaRegionReq()); + Assert.assertEquals(TSStatusCode.CONSENSUS_NOT_INITIALIZED.getStatusCode(), status.getCode()); + } + + @Test + public void testCreateDataRegionRejectsWhenConsensusNotStarted() { + DataNodeInternalRPCServiceImpl service = createServiceWithConsensusState(false); + TSStatus status = service.createDataRegion(createDataRegionReq()); + Assert.assertEquals(TSStatusCode.CONSENSUS_NOT_INITIALIZED.getStatusCode(), status.getCode()); + } + + @Test + public void 
testDeleteRegionRejectsWhenConsensusNotStarted() { + DataNodeInternalRPCServiceImpl service = createServiceWithConsensusState(false); + TConsensusGroupId groupId = new TConsensusGroupId(TConsensusGroupType.DataRegion, 0); + TSStatus status = service.deleteRegion(groupId); + Assert.assertEquals(TSStatusCode.CONSENSUS_NOT_INITIALIZED.getStatusCode(), status.getCode()); + } + + @Test + public void testChangeRegionLeaderRejectsWhenConsensusNotStarted() { + DataNodeInternalRPCServiceImpl service = createServiceWithConsensusState(false); + TRegionLeaderChangeReq req = new TRegionLeaderChangeReq(); + req.setRegionId(new TConsensusGroupId(TConsensusGroupType.DataRegion, 0)); + TDataNodeLocation newLeader = new TDataNodeLocation(); + newLeader.setDataNodeId(0); + newLeader.setInternalEndPoint(new TEndPoint("0.0.0.0", 10730)); + newLeader.setDataRegionConsensusEndPoint(new TEndPoint("0.0.0.0", 10760)); + newLeader.setSchemaRegionConsensusEndPoint(new TEndPoint("0.0.0.0", 10750)); + req.setNewLeaderNode(newLeader); + TRegionLeaderChangeResp resp = service.changeRegionLeader(req); + Assert.assertEquals( + TSStatusCode.CONSENSUS_NOT_INITIALIZED.getStatusCode(), resp.getStatus().getCode()); + } +} diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/service/DataNodeInternalRPCServiceImplTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/service/DataNodeInternalRPCServiceImplTest.java index 22a2d6cdc461f..63dda4537f947 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/service/DataNodeInternalRPCServiceImplTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/service/DataNodeInternalRPCServiceImplTest.java @@ -50,6 +50,7 @@ import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metadata.write.CreateMultiTimeSeriesNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metadata.write.CreateTimeSeriesNode; import org.apache.iotdb.db.schemaengine.SchemaEngine; +import 
org.apache.iotdb.db.service.DataNode.DataNodeContext; import org.apache.iotdb.db.storageengine.dataregion.DataRegion; import org.apache.iotdb.db.storageengine.dataregion.tsfile.TsFileResource; import org.apache.iotdb.db.utils.EnvironmentUtils; @@ -68,6 +69,7 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; +import org.mockito.Mockito; import java.io.File; import java.io.IOException; @@ -80,6 +82,8 @@ import java.util.Objects; import java.util.Optional; +import static org.mockito.Mockito.when; + public class DataNodeInternalRPCServiceImplTest { private static final IoTDBConfig conf = IoTDBDescriptor.getInstance().getConfig(); @@ -134,7 +138,9 @@ public void setUp() throws Exception { .createLocalPeer( ConsensusGroupId.Factory.createFromTConsensusGroupId(regionReplicaSet.getRegionId()), genSchemaRegionPeerList(regionReplicaSet)); - dataNodeInternalRPCServiceImpl = new DataNodeInternalRPCServiceImpl(); + DataNodeContext context = Mockito.mock(DataNodeContext.class); + when(context.isAllConsensusStarted()).thenReturn(true); + dataNodeInternalRPCServiceImpl = new DataNodeInternalRPCServiceImpl(context); } @After diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/Await.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/Await.java new file mode 100644 index 0000000000000..f08ccd17495b7 --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/Await.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.commons.concurrent; + +/** + * Lightweight polling utility for production code. Provides a fluent API similar to Awaitility for + * waiting until a condition becomes true. + * + *
<pre>{@code
+ * // Wait with timeout
+ * Await.await()
+ *     .atMost(5, TimeUnit.SECONDS)
+ *     .pollInterval(100, TimeUnit.MILLISECONDS)
+ *     .until(() -> isReady());
+ *
+ * // Wait forever (use with caution)
+ * Await.await()
+ *     .forever()
+ *     .pollInterval(1, TimeUnit.SECONDS)
+ *     .until(() -> isReady());
+ *
+ * // Ignore exceptions during polling
+ * Await.await()
+ *     .atMost(30, TimeUnit.SECONDS)
+ *     .ignoreExceptions()
+ *     .until(() -> tryConnect());
+ * }</pre>
+ */ +public final class Await { + + private Await() {} + + public static ConditionAwaiter await() { + return new ConditionAwaiter(); + } +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/AwaitTimeoutException.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/AwaitTimeoutException.java new file mode 100644 index 0000000000000..b0d5c98bfe73f --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/AwaitTimeoutException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.commons.concurrent; + +public class AwaitTimeoutException extends RuntimeException { + + public AwaitTimeoutException(String message) { + super(message); + } + + public AwaitTimeoutException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ConditionAwaiter.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ConditionAwaiter.java new file mode 100644 index 0000000000000..f88db57f612ab --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ConditionAwaiter.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.commons.concurrent; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.TimeUnit; + +public class ConditionAwaiter { + + private static final long DEFAULT_POLL_INTERVAL_MS = 100; + private static final long DEFAULT_TIMEOUT_MS = 10_000; + + private long timeoutMs = DEFAULT_TIMEOUT_MS; + private long pollIntervalMs = DEFAULT_POLL_INTERVAL_MS; + private long pollDelayMs = 0; + private boolean ignoreAllExceptions = false; + private boolean forever = false; + private final List> ignoredExceptions = new ArrayList<>(); + + ConditionAwaiter() {} + + public ConditionAwaiter atMost(long time, TimeUnit unit) { + this.timeoutMs = unit.toMillis(time); + return this; + } + + public ConditionAwaiter pollInterval(long time, TimeUnit unit) { + this.pollIntervalMs = unit.toMillis(time); + return this; + } + + public ConditionAwaiter pollDelay(long time, TimeUnit unit) { + this.pollDelayMs = unit.toMillis(time); + return this; + } + + public ConditionAwaiter ignoreExceptions() { + this.ignoreAllExceptions = true; + return this; + } + + public ConditionAwaiter ignoreException(Class exceptionType) { + this.ignoredExceptions.add(exceptionType); + return this; + } + + public ConditionAwaiter forever() { + this.forever = true; + return this; + } + + public void until(Callable conditionEvaluator) { + long startTime = System.currentTimeMillis(); + + if (pollDelayMs > 0) { + sleep(pollDelayMs); + } + + Exception lastException = null; + while (true) { + try { + Boolean result = conditionEvaluator.call(); + if (Boolean.TRUE.equals(result)) { + return; + } + lastException = null; + } catch (Exception e) { + if (shouldIgnore(e)) { + lastException = e; + } else if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + throw new AwaitTimeoutException("Interrupted while awaiting condition", e); + } else { + throw new AwaitTimeoutException("Exception while evaluating 
condition", e); + } + } + + if (!forever && System.currentTimeMillis() - startTime >= timeoutMs) { + String message = String.format("Condition was not met within %d ms", timeoutMs); + if (lastException != null) { + throw new AwaitTimeoutException(message, lastException); + } + throw new AwaitTimeoutException(message); + } + + sleep(pollIntervalMs); + } + } + + public void untilAsserted(Runnable assertion) { + final AssertionErrorHolder holder = new AssertionErrorHolder(); + try { + until( + () -> { + try { + assertion.run(); + return true; + } catch (AssertionError e) { + holder.error = e; + return false; + } + }); + } catch (AwaitTimeoutException e) { + if (holder.error != null) { + throw new AwaitTimeoutException(e.getMessage(), holder.error); + } + throw e; + } + } + + private static final class AssertionErrorHolder { + AssertionError error; + } + + private boolean shouldIgnore(Exception e) { + if (ignoreAllExceptions) { + return true; + } + for (Class ignoredType : ignoredExceptions) { + if (ignoredType.isInstance(e)) { + return true; + } + } + return false; + } + + private static void sleep(long millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AwaitTimeoutException("Interrupted while awaiting condition", e); + } + } +} diff --git a/iotdb-core/node-commons/src/test/java/org/apache/iotdb/commons/AwaitTest.java b/iotdb-core/node-commons/src/test/java/org/apache/iotdb/commons/AwaitTest.java new file mode 100644 index 0000000000000..ae094c1a806c3 --- /dev/null +++ b/iotdb-core/node-commons/src/test/java/org/apache/iotdb/commons/AwaitTest.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.commons; + +import org.apache.iotdb.commons.concurrent.Await; +import org.apache.iotdb.commons.concurrent.AwaitTimeoutException; + +import org.junit.Test; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class AwaitTest { + + @Test + public void testConditionAlreadyTrue() { + Await.await().atMost(1, TimeUnit.SECONDS).until(() -> true); + } + + @Test + public void testConditionBecomesTrue() { + AtomicBoolean flag = new AtomicBoolean(false); + new Thread( + () -> { + try { + Thread.sleep(200); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + flag.set(true); + }) + .start(); + + Await.await() + .atMost(5, TimeUnit.SECONDS) + .pollInterval(50, TimeUnit.MILLISECONDS) + .until(flag::get); + + assertTrue(flag.get()); + } + + @Test(expected = AwaitTimeoutException.class) + public void testTimeout() { + Await.await() + .atMost(300, TimeUnit.MILLISECONDS) + .pollInterval(50, TimeUnit.MILLISECONDS) + .until(() -> false); + } + + @Test + public void testPollDelay() { + long start = System.currentTimeMillis(); + + Await.await() + .atMost(5, TimeUnit.SECONDS) + .pollDelay(200, TimeUnit.MILLISECONDS) + 
.until(() -> true); + + long elapsed = System.currentTimeMillis() - start; + assertTrue("Expected at least 200ms delay, got " + elapsed, elapsed >= 180); + } + + @Test + public void testIgnoreAllExceptions() { + AtomicInteger counter = new AtomicInteger(0); + + Await.await() + .atMost(5, TimeUnit.SECONDS) + .pollInterval(50, TimeUnit.MILLISECONDS) + .ignoreExceptions() + .until( + () -> { + int val = counter.incrementAndGet(); + if (val < 3) { + throw new RuntimeException("not ready yet"); + } + return true; + }); + + assertTrue(counter.get() >= 3); + } + + @Test + public void testIgnoreSpecificException() { + AtomicInteger counter = new AtomicInteger(0); + + Await.await() + .atMost(5, TimeUnit.SECONDS) + .pollInterval(50, TimeUnit.MILLISECONDS) + .ignoreException(IllegalStateException.class) + .until( + () -> { + int val = counter.incrementAndGet(); + if (val < 3) { + throw new IllegalStateException("not ready"); + } + return true; + }); + + assertTrue(counter.get() >= 3); + } + + @Test + public void testNonIgnoredExceptionPropagates() { + try { + Await.await() + .atMost(5, TimeUnit.SECONDS) + .ignoreException(IllegalStateException.class) + .until( + () -> { + throw new IllegalArgumentException("unexpected"); + }); + fail("Should have thrown"); + } catch (AwaitTimeoutException e) { + assertTrue(e.getCause() instanceof IllegalArgumentException); + } + } + + @Test + public void testUntilAsserted() { + AtomicInteger value = new AtomicInteger(0); + new Thread( + () -> { + try { + Thread.sleep(200); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + value.set(42); + }) + .start(); + + Await.await() + .atMost(5, TimeUnit.SECONDS) + .pollInterval(50, TimeUnit.MILLISECONDS) + .untilAsserted(() -> assertEquals(42, value.get())); + } + + @Test + public void testForever() { + AtomicInteger counter = new AtomicInteger(0); + + Await.await() + .forever() + .pollInterval(10, TimeUnit.MILLISECONDS) + .until(() -> counter.incrementAndGet() >= 5); + + 
assertTrue(counter.get() >= 5); + } + + @Test + public void testTimeoutMessageIncludesLastException() { + try { + Await.await() + .atMost(200, TimeUnit.MILLISECONDS) + .pollInterval(50, TimeUnit.MILLISECONDS) + .ignoreExceptions() + .until( + () -> { + throw new RuntimeException("still failing"); + }); + fail("Should have thrown"); + } catch (AwaitTimeoutException e) { + assertTrue(e.getCause() instanceof RuntimeException); + assertEquals("still failing", e.getCause().getMessage()); + } + } +}