Skip to content

Commit 7d71a80

Browse files
Merge pull request rabbitmq#292 from rabbitmq/rabbitmq-java-client-290
Add option to ensure RPC reply is for the current request
2 parents dace543 + 0963ee3 commit 7d71a80

File tree

7 files changed

+210
-17
lines changed

7 files changed

+210
-17
lines changed

src/main/java/com/rabbitmq/client/ConnectionFactory.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ public class ConnectionFactory implements Cloneable {
127127
*/
128128
private int channelRpcTimeout = DEFAULT_CHANNEL_RPC_TIMEOUT;
129129

130+
/**
131+
* Whether or not channels check the reply type of an RPC call.
132+
* Default is false.
133+
* @since 4.2.0
134+
*/
135+
private boolean channelShouldCheckRpcResponseType = false;
136+
130137
/** @return the default host to use for connections */
131138
public String getHost() {
132139
return host;
@@ -958,6 +965,7 @@ public ConnectionParams params(ExecutorService consumerWorkServiceExecutor) {
958965
result.setShutdownExecutor(shutdownExecutor);
959966
result.setHeartbeatExecutor(heartbeatExecutor);
960967
result.setChannelRpcTimeout(channelRpcTimeout);
968+
result.setChannelShouldCheckRpcResponseType(channelShouldCheckRpcResponseType);
961969
return result;
962970
}
963971

@@ -1126,4 +1134,19 @@ public void setChannelRpcTimeout(int channelRpcTimeout) {
11261134
public int getChannelRpcTimeout() {
11271135
return channelRpcTimeout;
11281136
}
1137+
1138+
/**
1139+
* When set to true, channels will check the response type (e.g. queue.declare
1140+
* expects a queue.declare-ok response) of RPC calls
1141+
* and ignore those that do not match.
1142+
* Default is false.
1143+
* @param channelShouldCheckRpcResponseType
1144+
*/
1145+
public void setChannelShouldCheckRpcResponseType(boolean channelShouldCheckRpcResponseType) {
1146+
this.channelShouldCheckRpcResponseType = channelShouldCheckRpcResponseType;
1147+
}
1148+
1149+
public boolean isChannelShouldCheckRpcResponseType() {
1150+
return channelShouldCheckRpcResponseType;
1151+
}
11291152
}

src/main/java/com/rabbitmq/client/impl/AMQChannel.java

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020
import java.util.concurrent.TimeoutException;
2121

2222
import com.rabbitmq.client.*;
23+
import com.rabbitmq.client.AMQP.Basic;
24+
import com.rabbitmq.client.AMQP.Confirm;
25+
import com.rabbitmq.client.AMQP.Exchange;
26+
import com.rabbitmq.client.AMQP.Queue;
27+
import com.rabbitmq.client.AMQP.Tx;
2328
import com.rabbitmq.client.Method;
2429
import com.rabbitmq.utility.BlockingValueOrException;
2530
import org.slf4j.Logger;
@@ -66,6 +71,8 @@ public abstract class AMQChannel extends ShutdownNotifierComponent {
6671
/** Timeout for RPC calls */
6772
protected final int _rpcTimeout;
6873

74+
private final boolean _checkRpcResponseType;
75+
6976
/**
7077
* Construct a channel on the given connection, with the given channel number.
7178
* @param connection the underlying connection for this channel
@@ -78,6 +85,7 @@ public AMQChannel(AMQConnection connection, int channelNumber) {
7885
throw new IllegalArgumentException("Continuation timeout on RPC calls cannot be less than 0");
7986
}
8087
this._rpcTimeout = connection.getChannelRpcTimeout();
88+
this._checkRpcResponseType = connection.willCheckRpcResponseType();
8189
}
8290

8391
/**
@@ -153,8 +161,19 @@ public void handleCompleteInboundCommand(AMQCommand command) throws IOException
153161
// waiting RPC continuation.
154162
if (!processAsync(command)) {
155163
// The filter decided not to handle/consume the command,
156-
// so it must be some reply to an earlier RPC.
157-
RpcContinuation nextOutstandingRpc = nextOutstandingRpc();
164+
// so it must be a response to an earlier RPC.
165+
if (_checkRpcResponseType) {
166+
synchronized (_channelMutex) {
167+
// check if this reply command is intended for the current waiting request before calling nextOutstandingRpc()
168+
if (!_activeRpc.canHandleReply(command)) {
169+
// this reply command is not intended for the current waiting request
170+
// most likely a previous request timed out and this command is the reply for that.
171+
// Throw this reply command away so we don't stop the current request from waiting for its reply
172+
return;
173+
}
174+
}
175+
}
176+
final RpcContinuation nextOutstandingRpc = nextOutstandingRpc();
158177
// the outstanding RPC can be null when calling Channel#asyncRpc
159178
if(nextOutstandingRpc != null) {
160179
nextOutstandingRpc.handleCommand(command);
@@ -229,7 +248,7 @@ public AMQCommand rpc(Method m, int timeout)
229248
private AMQCommand privateRpc(Method m)
230249
throws IOException, ShutdownSignalException
231250
{
232-
SimpleBlockingRpcContinuation k = new SimpleBlockingRpcContinuation();
251+
SimpleBlockingRpcContinuation k = new SimpleBlockingRpcContinuation(m);
233252
rpc(m, k);
234253
// At this point, the request method has been sent, and we
235254
// should wait for the reply to arrive.
@@ -266,7 +285,7 @@ protected ChannelContinuationTimeoutException wrapTimeoutException(final Method
266285

267286
private AMQCommand privateRpc(Method m, int timeout)
268287
throws IOException, ShutdownSignalException, TimeoutException {
269-
SimpleBlockingRpcContinuation k = new SimpleBlockingRpcContinuation();
288+
SimpleBlockingRpcContinuation k = new SimpleBlockingRpcContinuation(m);
270289
rpc(m, k);
271290

272291
try {
@@ -384,13 +403,25 @@ public AMQConnection getConnection() {
384403

385404
public interface RpcContinuation {
386405
void handleCommand(AMQCommand command);
406+
/** @return true if the reply command can be handled for this request */
407+
boolean canHandleReply(AMQCommand command);
387408
void handleShutdownSignal(ShutdownSignalException signal);
388409
}
389410

390411
public static abstract class BlockingRpcContinuation<T> implements RpcContinuation {
391412
public final BlockingValueOrException<T, ShutdownSignalException> _blocker =
392413
new BlockingValueOrException<T, ShutdownSignalException>();
393414

415+
protected final Method request;
416+
417+
public BlockingRpcContinuation() {
418+
request = null;
419+
}
420+
421+
public BlockingRpcContinuation(final Method request) {
422+
this.request = request;
423+
}
424+
394425
@Override
395426
public void handleCommand(AMQCommand command) {
396427
_blocker.setValue(transformReply(command));
@@ -412,12 +443,79 @@ public T getReply(int timeout)
412443
return _blocker.uninterruptibleGetValue(timeout);
413444
}
414445

446+
@Override
447+
public boolean canHandleReply(AMQCommand command) {
448+
// make a best effort attempt to ensure the reply was intended for this rpc request
449+
// Ideally each rpc request would tag an id on it that could be returned and referenced on its reply.
450+
// But because that would be a very large undertaking to add passively this logic at least protects against ClassCastExceptions
451+
if (request != null) {
452+
final Method reply = command.getMethod();
453+
if (request instanceof Basic.Qos) {
454+
return reply instanceof Basic.QosOk;
455+
} else if (request instanceof Basic.Get) {
456+
return reply instanceof Basic.GetOk || reply instanceof Basic.GetEmpty;
457+
} else if (request instanceof Basic.Consume) {
458+
if (!(reply instanceof Basic.ConsumeOk))
459+
return false;
460+
// can also check the consumer tags match here. handle case where request consumer tag is empty and server-generated.
461+
final String consumerTag = ((Basic.Consume)request).getConsumerTag();
462+
return consumerTag == null || consumerTag.equals("") || consumerTag.equals(((Basic.ConsumeOk)reply).getConsumerTag());
463+
} else if (request instanceof Basic.Cancel) {
464+
if (!(reply instanceof Basic.CancelOk))
465+
return false;
466+
// can also check the consumer tags match here
467+
return ((Basic.Cancel)request).getConsumerTag().equals(((Basic.CancelOk)reply).getConsumerTag());
468+
} else if (request instanceof Basic.Recover) {
469+
return reply instanceof Basic.RecoverOk;
470+
} else if (request instanceof Exchange.Declare) {
471+
return reply instanceof Exchange.DeclareOk;
472+
} else if (request instanceof Exchange.Delete) {
473+
return reply instanceof Exchange.DeleteOk;
474+
} else if (request instanceof Exchange.Bind) {
475+
return reply instanceof Exchange.BindOk;
476+
} else if (request instanceof Exchange.Unbind) {
477+
return reply instanceof Exchange.UnbindOk;
478+
} else if (request instanceof Queue.Declare) {
479+
// we cannot check the queue name, as the server can strip some characters
480+
// see QueueLifecycle test and https://github.com/rabbitmq/rabbitmq-server/issues/710
481+
return reply instanceof Queue.DeclareOk;
482+
} else if (request instanceof Queue.Delete) {
483+
return reply instanceof Queue.DeleteOk;
484+
} else if (request instanceof Queue.Bind) {
485+
return reply instanceof Queue.BindOk;
486+
} else if (request instanceof Queue.Unbind) {
487+
return reply instanceof Queue.UnbindOk;
488+
} else if (request instanceof Queue.Purge) {
489+
return reply instanceof Queue.PurgeOk;
490+
} else if (request instanceof Tx.Select) {
491+
return reply instanceof Tx.SelectOk;
492+
} else if (request instanceof Tx.Commit) {
493+
return reply instanceof Tx.CommitOk;
494+
} else if (request instanceof Tx.Rollback) {
495+
return reply instanceof Tx.RollbackOk;
496+
} else if (request instanceof Confirm.Select) {
497+
return reply instanceof Confirm.SelectOk;
498+
}
499+
}
500+
// for passivity default to true
501+
return true;
502+
}
503+
415504
public abstract T transformReply(AMQCommand command);
416505
}
417506

418507
public static class SimpleBlockingRpcContinuation
419508
extends BlockingRpcContinuation<AMQCommand>
420509
{
510+
511+
public SimpleBlockingRpcContinuation() {
512+
super();
513+
}
514+
515+
public SimpleBlockingRpcContinuation(final Method method) {
516+
super(method);
517+
}
518+
421519
@Override
422520
public AMQCommand transformReply(AMQCommand command) {
423521
return command;

src/main/java/com/rabbitmq/client/impl/AMQConnection.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ public static Map<String, Object> defaultClientProperties() {
134134
private final Collection<BlockedListener> blockedListeners = new CopyOnWriteArrayList<BlockedListener>();
135135
protected final MetricsCollector metricsCollector;
136136
private final int channelRpcTimeout;
137+
private final boolean channelShouldCheckRpcResponseType;
137138

138139
/* State modified after start - all volatile */
139140

@@ -229,6 +230,7 @@ public AMQConnection(ConnectionParams params, FrameHandler frameHandler, Metrics
229230
throw new IllegalArgumentException("Continuation timeout on RPC calls cannot be less than 0");
230231
}
231232
this.channelRpcTimeout = params.getChannelRpcTimeout();
233+
this.channelShouldCheckRpcResponseType = params.channelShouldCheckRpcResponseType();
232234

233235
this._channel0 = new AMQChannel(this, 0) {
234236
@Override public boolean processAsync(Command c) throws IOException {
@@ -1056,4 +1058,8 @@ public void setId(String id) {
10561058
public int getChannelRpcTimeout() {
10571059
return channelRpcTimeout;
10581060
}
1061+
1062+
public boolean willCheckRpcResponseType() {
1063+
return channelShouldCheckRpcResponseType;
1064+
}
10591065
}

src/main/java/com/rabbitmq/client/impl/ChannelN.java

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,15 @@ public String basicConsume(String queue, final boolean autoAck, String consumerT
12291229
final Consumer callback)
12301230
throws IOException
12311231
{
1232-
BlockingRpcContinuation<String> k = new BlockingRpcContinuation<String>() {
1232+
final Method m = new Basic.Consume.Builder()
1233+
.queue(queue)
1234+
.consumerTag(consumerTag)
1235+
.noLocal(noLocal)
1236+
.noAck(autoAck)
1237+
.exclusive(exclusive)
1238+
.arguments(arguments)
1239+
.build();
1240+
BlockingRpcContinuation<String> k = new BlockingRpcContinuation<String>(m) {
12331241
@Override
12341242
public String transformReply(AMQCommand replyCommand) {
12351243
String actualConsumerTag = ((Basic.ConsumeOk) replyCommand.getMethod()).getConsumerTag();
@@ -1243,14 +1251,7 @@ public String transformReply(AMQCommand replyCommand) {
12431251
}
12441252
};
12451253

1246-
final Method m = new Basic.Consume.Builder()
1247-
.queue(queue)
1248-
.consumerTag(consumerTag)
1249-
.noLocal(noLocal)
1250-
.noAck(autoAck)
1251-
.exclusive(exclusive)
1252-
.arguments(arguments)
1253-
.build();
1254+
12541255
rpc(m, k);
12551256

12561257
try {
@@ -1276,7 +1277,9 @@ public void basicCancel(final String consumerTag)
12761277
final Consumer originalConsumer = _consumers.get(consumerTag);
12771278
if (originalConsumer == null)
12781279
throw new IOException("Unknown consumerTag");
1279-
BlockingRpcContinuation<Consumer> k = new BlockingRpcContinuation<Consumer>() {
1280+
1281+
final Method m = new Basic.Cancel(consumerTag, false);
1282+
BlockingRpcContinuation<Consumer> k = new BlockingRpcContinuation<Consumer>(m) {
12801283
@Override
12811284
public Consumer transformReply(AMQCommand replyCommand) {
12821285
if (!(replyCommand.getMethod() instanceof Basic.CancelOk))
@@ -1287,7 +1290,7 @@ public Consumer transformReply(AMQCommand replyCommand) {
12871290
}
12881291
};
12891292

1290-
final Method m = new Basic.Cancel(consumerTag, false);
1293+
12911294
rpc(m, k);
12921295

12931296
try {

src/main/java/com/rabbitmq/client/impl/ConnectionParams.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ public class ConnectionParams {
4040
private long networkRecoveryInterval;
4141
private boolean topologyRecovery;
4242
private int channelRpcTimeout;
43+
private boolean channelShouldCheckRpcResponseType;
4344

4445
private ExceptionHandler exceptionHandler;
4546
private ThreadFactory threadFactory;
@@ -114,6 +115,10 @@ public int getChannelRpcTimeout() {
114115
return channelRpcTimeout;
115116
}
116117

118+
public boolean channelShouldCheckRpcResponseType() {
119+
return channelShouldCheckRpcResponseType;
120+
}
121+
117122
public void setUsername(String username) {
118123
this.username = username;
119124
}
@@ -189,4 +194,8 @@ public void setHeartbeatExecutor(ScheduledExecutorService heartbeatExecutor) {
189194
public void setChannelRpcTimeout(int channelRpcTimeout) {
190195
this.channelRpcTimeout = channelRpcTimeout;
191196
}
197+
198+
public void setChannelShouldCheckRpcResponseType(boolean channelShouldCheckRpcResponseType) {
199+
this.channelShouldCheckRpcResponseType = channelShouldCheckRpcResponseType;
200+
}
192201
}

src/main/java/com/rabbitmq/client/impl/nio/HeaderWriteRequest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919

2020
import java.io.DataOutputStream;
2121
import java.io.IOException;
22-
import java.nio.ByteBuffer;
23-
import java.nio.channels.WritableByteChannel;
2422

2523
/**
2624
*

src/test/java/com/rabbitmq/client/test/AMQChannelTest.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,62 @@ public Void call() throws Exception {
106106
assertThat(rpcResponse.getMethod(), is(response));
107107
}
108108

109+
@Test
110+
public void testRpcTimeoutReplyComesDuringNexRpc() throws Exception {
111+
int rpcTimeout = 100;
112+
AMQConnection connection = mock(AMQConnection.class);
113+
when(connection.getChannelRpcTimeout()).thenReturn(rpcTimeout);
114+
when(connection.willCheckRpcResponseType()).thenReturn(Boolean.TRUE);
115+
116+
final DummyAmqChannel channel = new DummyAmqChannel(connection, 1);
117+
Method method = new AMQImpl.Queue.Declare.Builder()
118+
.queue("123")
119+
.durable(false)
120+
.exclusive(true)
121+
.autoDelete(true)
122+
.arguments(null)
123+
.build();
124+
125+
try {
126+
channel.rpc(method);
127+
fail("Should time out and throw an exception");
128+
} catch(final ChannelContinuationTimeoutException e) {
129+
// OK
130+
assertThat((DummyAmqChannel) e.getChannel(), is(channel));
131+
assertThat(e.getChannelNumber(), is(channel.getChannelNumber()));
132+
assertThat(e.getMethod(), is(method));
133+
assertNull("outstanding RPC should have been cleaned", channel.nextOutstandingRpc());
134+
}
135+
136+
// now do a basic.consume request and have the queue.declareok returned instead
137+
method = new AMQImpl.Basic.Consume.Builder()
138+
.queue("123")
139+
.consumerTag("")
140+
.arguments(null)
141+
.build();
142+
143+
final Method response1 = new AMQImpl.Queue.DeclareOk.Builder()
144+
.queue("123")
145+
.consumerCount(0)
146+
.messageCount(0).build();
147+
148+
final Method response2 = new AMQImpl.Basic.ConsumeOk.Builder()
149+
.consumerTag("456").build();
150+
151+
scheduler.schedule(new Callable<Void>() {
152+
@Override
153+
public Void call() throws Exception {
154+
channel.handleCompleteInboundCommand(new AMQCommand(response1));
155+
Thread.sleep(10);
156+
channel.handleCompleteInboundCommand(new AMQCommand(response2));
157+
return null;
158+
}
159+
}, (long) (rpcTimeout / 2.0), TimeUnit.MILLISECONDS);
160+
161+
AMQCommand rpcResponse = channel.rpc(method);
162+
assertThat(rpcResponse.getMethod(), is(response2));
163+
}
164+
109165
static class DummyAmqChannel extends AMQChannel {
110166

111167
public DummyAmqChannel(AMQConnection connection, int channelNumber) {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy