Madan Jampani
Committed by Gerrit Code Review

MessagingService API enchancements

Change-Id: Iabfe15d4f08d7c53bd6575c5d94d0ac9f4e1a38e
...@@ -17,8 +17,8 @@ package org.onosproject.store.cluster.messaging; ...@@ -17,8 +17,8 @@ package org.onosproject.store.cluster.messaging;
17 17
18 import java.util.concurrent.CompletableFuture; 18 import java.util.concurrent.CompletableFuture;
19 import java.util.concurrent.Executor; 19 import java.util.concurrent.Executor;
20 -import java.util.function.Consumer; 20 +import java.util.function.BiConsumer;
21 -import java.util.function.Function; 21 +import java.util.function.BiFunction;
22 22
23 /** 23 /**
24 * Interface for low level messaging primitives. 24 * Interface for low level messaging primitives.
...@@ -36,7 +36,7 @@ public interface MessagingService { ...@@ -36,7 +36,7 @@ public interface MessagingService {
36 CompletableFuture<Void> sendAsync(Endpoint ep, String type, byte[] payload); 36 CompletableFuture<Void> sendAsync(Endpoint ep, String type, byte[] payload);
37 37
38 /** 38 /**
39 - * Sends a message synchronously and waits for a response. 39 + * Sends a message asynchronously and expects a response.
40 * @param ep end point to send the message to. 40 * @param ep end point to send the message to.
41 * @param type type of message. 41 * @param type type of message.
42 * @param payload message payload. 42 * @param payload message payload.
...@@ -45,12 +45,22 @@ public interface MessagingService { ...@@ -45,12 +45,22 @@ public interface MessagingService {
45 CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload); 45 CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload);
46 46
47 /** 47 /**
48 + * Sends a message synchronously and expects a response.
49 + * @param ep end point to send the message to.
50 + * @param type type of message.
51 + * @param payload message payload.
52 + * @param executor executor over which any follow up actions after completion will be executed.
53 + * @return a response future
54 + */
55 + CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload, Executor executor);
56 +
57 + /**
48 * Registers a new message handler for message type. 58 * Registers a new message handler for message type.
49 * @param type message type. 59 * @param type message type.
50 * @param handler message handler 60 * @param handler message handler
51 * @param executor executor to use for running message handler logic. 61 * @param executor executor to use for running message handler logic.
52 */ 62 */
53 - void registerHandler(String type, Consumer<byte[]> handler, Executor executor); 63 + void registerHandler(String type, BiConsumer<Endpoint, byte[]> handler, Executor executor);
54 64
55 /** 65 /**
56 * Registers a new message handler for message type. 66 * Registers a new message handler for message type.
...@@ -58,14 +68,14 @@ public interface MessagingService { ...@@ -58,14 +68,14 @@ public interface MessagingService {
58 * @param handler message handler 68 * @param handler message handler
59 * @param executor executor to use for running message handler logic. 69 * @param executor executor to use for running message handler logic.
60 */ 70 */
61 - void registerHandler(String type, Function<byte[], byte[]> handler, Executor executor); 71 + void registerHandler(String type, BiFunction<Endpoint, byte[], byte[]> handler, Executor executor);
62 72
63 /** 73 /**
64 * Registers a new message handler for message type. 74 * Registers a new message handler for message type.
65 * @param type message type. 75 * @param type message type.
66 * @param handler message handler 76 * @param handler message handler
67 */ 77 */
68 - void registerHandler(String type, Function<byte[], CompletableFuture<byte[]>> handler); 78 + void registerHandler(String type, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler);
69 79
70 /** 80 /**
71 * Unregister current handler, if one exists for message type. 81 * Unregister current handler, if one exists for message type.
......
...@@ -49,7 +49,7 @@ import java.util.concurrent.ExecutorService; ...@@ -49,7 +49,7 @@ import java.util.concurrent.ExecutorService;
49 import java.util.concurrent.Executors; 49 import java.util.concurrent.Executors;
50 import java.util.concurrent.ScheduledExecutorService; 50 import java.util.concurrent.ScheduledExecutorService;
51 import java.util.concurrent.TimeUnit; 51 import java.util.concurrent.TimeUnit;
52 -import java.util.function.Consumer; 52 +import java.util.function.BiConsumer;
53 import java.util.stream.Collectors; 53 import java.util.stream.Collectors;
54 54
55 import static com.google.common.base.Preconditions.checkNotNull; 55 import static com.google.common.base.Preconditions.checkNotNull;
...@@ -241,9 +241,9 @@ public class DistributedClusterStore ...@@ -241,9 +241,9 @@ public class DistributedClusterStore
241 }); 241 });
242 } 242 }
243 243
244 - private class HeartbeatMessageHandler implements Consumer<byte[]> { 244 + private class HeartbeatMessageHandler implements BiConsumer<Endpoint, byte[]> {
245 @Override 245 @Override
246 - public void accept(byte[] message) { 246 + public void accept(Endpoint sender, byte[] message) {
247 HeartbeatMessage hb = SERIALIZER.decode(message); 247 HeartbeatMessage hb = SERIALIZER.decode(message);
248 failureDetector.report(hb.source().id()); 248 failureDetector.report(hb.source().id());
249 hb.knownPeers().forEach(node -> { 249 hb.knownPeers().forEach(node -> {
......
...@@ -35,10 +35,13 @@ import org.slf4j.Logger; ...@@ -35,10 +35,13 @@ import org.slf4j.Logger;
35 import org.slf4j.LoggerFactory; 35 import org.slf4j.LoggerFactory;
36 36
37 import com.google.common.base.Objects; 37 import com.google.common.base.Objects;
38 +
38 import java.util.Set; 39 import java.util.Set;
39 import java.util.concurrent.CompletableFuture; 40 import java.util.concurrent.CompletableFuture;
40 import java.util.concurrent.Executor; 41 import java.util.concurrent.Executor;
41 import java.util.concurrent.ExecutorService; 42 import java.util.concurrent.ExecutorService;
43 +import java.util.function.BiConsumer;
44 +import java.util.function.BiFunction;
42 import java.util.function.Consumer; 45 import java.util.function.Consumer;
43 import java.util.function.Function; 46 import java.util.function.Function;
44 import java.util.stream.Collectors; 47 import java.util.stream.Collectors;
...@@ -210,7 +213,7 @@ public class ClusterCommunicationManager ...@@ -210,7 +213,7 @@ public class ClusterCommunicationManager
210 executor); 213 executor);
211 } 214 }
212 215
213 - private class InternalClusterMessageHandler implements Function<byte[], byte[]> { 216 + private class InternalClusterMessageHandler implements BiFunction<Endpoint, byte[], byte[]> {
214 private ClusterMessageHandler handler; 217 private ClusterMessageHandler handler;
215 218
216 public InternalClusterMessageHandler(ClusterMessageHandler handler) { 219 public InternalClusterMessageHandler(ClusterMessageHandler handler) {
...@@ -218,14 +221,14 @@ public class ClusterCommunicationManager ...@@ -218,14 +221,14 @@ public class ClusterCommunicationManager
218 } 221 }
219 222
220 @Override 223 @Override
221 - public byte[] apply(byte[] bytes) { 224 + public byte[] apply(Endpoint sender, byte[] bytes) {
222 ClusterMessage message = ClusterMessage.fromBytes(bytes); 225 ClusterMessage message = ClusterMessage.fromBytes(bytes);
223 handler.handle(message); 226 handler.handle(message);
224 return message.response(); 227 return message.response();
225 } 228 }
226 } 229 }
227 230
228 - private class InternalMessageResponder<M, R> implements Function<byte[], CompletableFuture<byte[]>> { 231 + private class InternalMessageResponder<M, R> implements BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> {
229 private final Function<byte[], M> decoder; 232 private final Function<byte[], M> decoder;
230 private final Function<R, byte[]> encoder; 233 private final Function<R, byte[]> encoder;
231 private final Function<M, CompletableFuture<R>> handler; 234 private final Function<M, CompletableFuture<R>> handler;
...@@ -239,12 +242,12 @@ public class ClusterCommunicationManager ...@@ -239,12 +242,12 @@ public class ClusterCommunicationManager
239 } 242 }
240 243
241 @Override 244 @Override
242 - public CompletableFuture<byte[]> apply(byte[] bytes) { 245 + public CompletableFuture<byte[]> apply(Endpoint sender, byte[] bytes) {
243 return handler.apply(decoder.apply(ClusterMessage.fromBytes(bytes).payload())).thenApply(encoder); 246 return handler.apply(decoder.apply(ClusterMessage.fromBytes(bytes).payload())).thenApply(encoder);
244 } 247 }
245 } 248 }
246 249
247 - private class InternalMessageConsumer<M> implements Consumer<byte[]> { 250 + private class InternalMessageConsumer<M> implements BiConsumer<Endpoint, byte[]> {
248 private final Function<byte[], M> decoder; 251 private final Function<byte[], M> decoder;
249 private final Consumer<M> consumer; 252 private final Consumer<M> consumer;
250 253
...@@ -254,7 +257,7 @@ public class ClusterCommunicationManager ...@@ -254,7 +257,7 @@ public class ClusterCommunicationManager
254 } 257 }
255 258
256 @Override 259 @Override
257 - public void accept(byte[] bytes) { 260 + public void accept(Endpoint sender, byte[] bytes) {
258 consumer.accept(decoder.apply(ClusterMessage.fromBytes(bytes).payload())); 261 consumer.accept(decoder.apply(ClusterMessage.fromBytes(bytes).payload()));
259 } 262 }
260 } 263 }
......
...@@ -15,6 +15,12 @@ ...@@ -15,6 +15,12 @@
15 */ 15 */
16 package org.onlab.netty; 16 package org.onlab.netty;
17 17
18 +import com.google.common.cache.Cache;
19 +import com.google.common.cache.CacheBuilder;
20 +import com.google.common.cache.RemovalListener;
21 +import com.google.common.cache.RemovalNotification;
22 +import com.google.common.util.concurrent.MoreExecutors;
23 +
18 import io.netty.bootstrap.Bootstrap; 24 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap; 25 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.PooledByteBufAllocator; 26 import io.netty.buffer.PooledByteBufAllocator;
...@@ -35,6 +41,19 @@ import io.netty.channel.socket.SocketChannel; ...@@ -35,6 +41,19 @@ import io.netty.channel.socket.SocketChannel;
35 import io.netty.channel.socket.nio.NioServerSocketChannel; 41 import io.netty.channel.socket.nio.NioServerSocketChannel;
36 import io.netty.channel.socket.nio.NioSocketChannel; 42 import io.netty.channel.socket.nio.NioSocketChannel;
37 43
44 +import org.apache.commons.pool.KeyedPoolableObjectFactory;
45 +import org.apache.commons.pool.impl.GenericKeyedObjectPool;
46 +import org.onlab.util.Tools;
47 +import org.onosproject.store.cluster.messaging.Endpoint;
48 +import org.onosproject.store.cluster.messaging.MessagingService;
49 +import org.slf4j.Logger;
50 +import org.slf4j.LoggerFactory;
51 +
52 +import javax.net.ssl.KeyManagerFactory;
53 +import javax.net.ssl.SSLContext;
54 +import javax.net.ssl.SSLEngine;
55 +import javax.net.ssl.TrustManagerFactory;
56 +
38 import java.io.FileInputStream; 57 import java.io.FileInputStream;
39 import java.io.IOException; 58 import java.io.IOException;
40 import java.security.KeyStore; 59 import java.security.KeyStore;
...@@ -47,25 +66,9 @@ import java.util.concurrent.TimeUnit; ...@@ -47,25 +66,9 @@ import java.util.concurrent.TimeUnit;
47 import java.util.concurrent.TimeoutException; 66 import java.util.concurrent.TimeoutException;
48 import java.util.concurrent.atomic.AtomicBoolean; 67 import java.util.concurrent.atomic.AtomicBoolean;
49 import java.util.concurrent.atomic.AtomicLong; 68 import java.util.concurrent.atomic.AtomicLong;
69 +import java.util.function.BiConsumer;
70 +import java.util.function.BiFunction;
50 import java.util.function.Consumer; 71 import java.util.function.Consumer;
51 -import java.util.function.Function;
52 -
53 -import javax.net.ssl.KeyManagerFactory;
54 -import javax.net.ssl.SSLContext;
55 -import javax.net.ssl.SSLEngine;
56 -import javax.net.ssl.TrustManagerFactory;
57 -
58 -import org.apache.commons.pool.KeyedPoolableObjectFactory;
59 -import org.apache.commons.pool.impl.GenericKeyedObjectPool;
60 -import org.onosproject.store.cluster.messaging.Endpoint;
61 -import org.onosproject.store.cluster.messaging.MessagingService;
62 -import org.slf4j.Logger;
63 -import org.slf4j.LoggerFactory;
64 -
65 -import com.google.common.cache.Cache;
66 -import com.google.common.cache.CacheBuilder;
67 -import com.google.common.cache.RemovalListener;
68 -import com.google.common.cache.RemovalNotification;
69 72
70 /** 73 /**
71 * Implementation of MessagingService based on <a href="http://netty.io/">Netty</a> framework. 74 * Implementation of MessagingService based on <a href="http://netty.io/">Netty</a> framework.
...@@ -81,11 +84,11 @@ public class NettyMessaging implements MessagingService { ...@@ -81,11 +84,11 @@ public class NettyMessaging implements MessagingService {
81 private final AtomicBoolean started = new AtomicBoolean(false); 84 private final AtomicBoolean started = new AtomicBoolean(false);
82 private final Map<String, Consumer<InternalMessage>> handlers = new ConcurrentHashMap<>(); 85 private final Map<String, Consumer<InternalMessage>> handlers = new ConcurrentHashMap<>();
83 private final AtomicLong messageIdGenerator = new AtomicLong(0); 86 private final AtomicLong messageIdGenerator = new AtomicLong(0);
84 - private final Cache<Long, CompletableFuture<byte[]>> responseFutures = CacheBuilder.newBuilder() 87 + private final Cache<Long, Callback> callbacks = CacheBuilder.newBuilder()
85 .expireAfterWrite(10, TimeUnit.SECONDS) 88 .expireAfterWrite(10, TimeUnit.SECONDS)
86 - .removalListener(new RemovalListener<Long, CompletableFuture<byte[]>>() { 89 + .removalListener(new RemovalListener<Long, Callback>() {
87 @Override 90 @Override
88 - public void onRemoval(RemovalNotification<Long, CompletableFuture<byte[]>> entry) { 91 + public void onRemoval(RemovalNotification<Long, Callback> entry) {
89 if (entry.wasEvicted()) { 92 if (entry.wasEvicted()) {
90 entry.getValue().completeExceptionally(new TimeoutException("Timedout waiting for reply")); 93 entry.getValue().completeExceptionally(new TimeoutException("Timedout waiting for reply"));
91 } 94 }
...@@ -165,12 +168,17 @@ public class NettyMessaging implements MessagingService { ...@@ -165,12 +168,17 @@ public class NettyMessaging implements MessagingService {
165 } 168 }
166 169
167 protected CompletableFuture<Void> sendAsync(Endpoint ep, InternalMessage message) { 170 protected CompletableFuture<Void> sendAsync(Endpoint ep, InternalMessage message) {
168 - CompletableFuture<Void> future = new CompletableFuture<>();
169 - try {
170 if (ep.equals(localEp)) { 171 if (ep.equals(localEp)) {
172 + try {
171 dispatchLocally(message); 173 dispatchLocally(message);
172 - future.complete(null); 174 + } catch (IOException e) {
173 - } else { 175 + return Tools.exceptionalFuture(e);
176 + }
177 + return CompletableFuture.completedFuture(null);
178 + }
179 +
180 + CompletableFuture<Void> future = new CompletableFuture<>();
181 + try {
174 Channel channel = null; 182 Channel channel = null;
175 try { 183 try {
176 channel = channels.borrowObject(ep); 184 channel = channels.borrowObject(ep);
...@@ -184,7 +192,6 @@ public class NettyMessaging implements MessagingService { ...@@ -184,7 +192,6 @@ public class NettyMessaging implements MessagingService {
184 } finally { 192 } finally {
185 channels.returnObject(ep, channel); 193 channels.returnObject(ep, channel);
186 } 194 }
187 - }
188 } catch (Exception e) { 195 } catch (Exception e) {
189 future.completeExceptionally(e); 196 future.completeExceptionally(e);
190 } 197 }
...@@ -193,28 +200,32 @@ public class NettyMessaging implements MessagingService { ...@@ -193,28 +200,32 @@ public class NettyMessaging implements MessagingService {
193 200
194 @Override 201 @Override
195 public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload) { 202 public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload) {
203 + return sendAndReceive(ep, type, payload, MoreExecutors.directExecutor());
204 + }
205 +
206 + @Override
207 + public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload, Executor executor) {
196 CompletableFuture<byte[]> response = new CompletableFuture<>(); 208 CompletableFuture<byte[]> response = new CompletableFuture<>();
209 + Callback callback = new Callback(response, executor);
197 Long messageId = messageIdGenerator.incrementAndGet(); 210 Long messageId = messageIdGenerator.incrementAndGet();
198 - responseFutures.put(messageId, response); 211 + callbacks.put(messageId, callback);
199 InternalMessage message = new InternalMessage(messageId, localEp, type, payload); 212 InternalMessage message = new InternalMessage(messageId, localEp, type, payload);
200 - try { 213 + return sendAsync(ep, message).whenComplete((r, e) -> {
201 - sendAsync(ep, message); 214 + if (e != null) {
202 - } catch (Exception e) { 215 + callbacks.invalidate(messageId);
203 - responseFutures.invalidate(messageId);
204 - response.completeExceptionally(e);
205 } 216 }
206 - return response; 217 + }).thenCompose(v -> response);
207 } 218 }
208 219
209 @Override 220 @Override
210 - public void registerHandler(String type, Consumer<byte[]> handler, Executor executor) { 221 + public void registerHandler(String type, BiConsumer<Endpoint, byte[]> handler, Executor executor) {
211 - handlers.put(type, message -> executor.execute(() -> handler.accept(message.payload()))); 222 + handlers.put(type, message -> executor.execute(() -> handler.accept(message.sender(), message.payload())));
212 } 223 }
213 224
214 @Override 225 @Override
215 - public void registerHandler(String type, Function<byte[], byte[]> handler, Executor executor) { 226 + public void registerHandler(String type, BiFunction<Endpoint, byte[], byte[]> handler, Executor executor) {
216 handlers.put(type, message -> executor.execute(() -> { 227 handlers.put(type, message -> executor.execute(() -> {
217 - byte[] responsePayload = handler.apply(message.payload()); 228 + byte[] responsePayload = handler.apply(message.sender(), message.payload());
218 if (responsePayload != null) { 229 if (responsePayload != null) {
219 InternalMessage response = new InternalMessage(message.id(), 230 InternalMessage response = new InternalMessage(message.id(),
220 localEp, 231 localEp,
...@@ -230,9 +241,9 @@ public class NettyMessaging implements MessagingService { ...@@ -230,9 +241,9 @@ public class NettyMessaging implements MessagingService {
230 } 241 }
231 242
232 @Override 243 @Override
233 - public void registerHandler(String type, Function<byte[], CompletableFuture<byte[]>> handler) { 244 + public void registerHandler(String type, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler) {
234 handlers.put(type, message -> { 245 handlers.put(type, message -> {
235 - handler.apply(message.payload()).whenComplete((result, error) -> { 246 + handler.apply(message.sender(), message.payload()).whenComplete((result, error) -> {
236 if (error == null) { 247 if (error == null) {
237 InternalMessage response = new InternalMessage(message.id(), 248 InternalMessage response = new InternalMessage(message.id(),
238 localEp, 249 localEp,
...@@ -435,17 +446,17 @@ public class NettyMessaging implements MessagingService { ...@@ -435,17 +446,17 @@ public class NettyMessaging implements MessagingService {
435 String type = message.type(); 446 String type = message.type();
436 if (REPLY_MESSAGE_TYPE.equals(type)) { 447 if (REPLY_MESSAGE_TYPE.equals(type)) {
437 try { 448 try {
438 - CompletableFuture<byte[]> futureResponse = 449 + Callback callback =
439 - responseFutures.getIfPresent(message.id()); 450 + callbacks.getIfPresent(message.id());
440 - if (futureResponse != null) { 451 + if (callback != null) {
441 - futureResponse.complete(message.payload()); 452 + callback.complete(message.payload());
442 } else { 453 } else {
443 log.warn("Received a reply for message id:[{}]. " 454 log.warn("Received a reply for message id:[{}]. "
444 + " from {}. But was unable to locate the" 455 + " from {}. But was unable to locate the"
445 + " request handle", message.id(), message.sender()); 456 + " request handle", message.id(), message.sender());
446 } 457 }
447 } finally { 458 } finally {
448 - responseFutures.invalidate(message.id()); 459 + callbacks.invalidate(message.id());
449 } 460 }
450 return; 461 return;
451 } 462 }
...@@ -456,4 +467,22 @@ public class NettyMessaging implements MessagingService { ...@@ -456,4 +467,22 @@ public class NettyMessaging implements MessagingService {
456 log.debug("No handler registered for {}", type); 467 log.debug("No handler registered for {}", type);
457 } 468 }
458 } 469 }
470 +
471 + private final class Callback {
472 + private final CompletableFuture<byte[]> future;
473 + private final Executor executor;
474 +
475 + public Callback(CompletableFuture<byte[]> future, Executor executor) {
476 + this.future = future;
477 + this.executor = executor;
478 + }
479 +
480 + public void complete(byte[] value) {
481 + executor.execute(() -> future.complete(value));
482 + }
483 +
484 + public void completeExceptionally(Throwable error) {
485 + executor.execute(() -> future.completeExceptionally(error));
486 + }
487 + }
459 } 488 }
......
...@@ -33,8 +33,9 @@ import java.util.concurrent.Executors; ...@@ -33,8 +33,9 @@ import java.util.concurrent.Executors;
33 import java.util.concurrent.TimeoutException; 33 import java.util.concurrent.TimeoutException;
34 import java.util.concurrent.atomic.AtomicBoolean; 34 import java.util.concurrent.atomic.AtomicBoolean;
35 import java.util.concurrent.atomic.AtomicLong; 35 import java.util.concurrent.atomic.AtomicLong;
36 +import java.util.function.BiConsumer;
37 +import java.util.function.BiFunction;
36 import java.util.function.Consumer; 38 import java.util.function.Consumer;
37 -import java.util.function.Function;
38 39
39 import org.apache.commons.pool.KeyedPoolableObjectFactory; 40 import org.apache.commons.pool.KeyedPoolableObjectFactory;
40 import org.apache.commons.pool.impl.GenericKeyedObjectPool; 41 import org.apache.commons.pool.impl.GenericKeyedObjectPool;
...@@ -50,6 +51,7 @@ import com.google.common.cache.CacheBuilder; ...@@ -50,6 +51,7 @@ import com.google.common.cache.CacheBuilder;
50 import com.google.common.cache.RemovalListener; 51 import com.google.common.cache.RemovalListener;
51 import com.google.common.cache.RemovalNotification; 52 import com.google.common.cache.RemovalNotification;
52 import com.google.common.collect.Lists; 53 import com.google.common.collect.Lists;
54 +import com.google.common.util.concurrent.MoreExecutors;
53 55
54 /** 56 /**
55 * MessagingService implementation based on IOLoop. 57 * MessagingService implementation based on IOLoop.
...@@ -86,10 +88,10 @@ public class IOLoopMessaging implements MessagingService { ...@@ -86,10 +88,10 @@ public class IOLoopMessaging implements MessagingService {
86 88
87 private final ConcurrentMap<String, Consumer<DefaultMessage>> handlers = new ConcurrentHashMap<>(); 89 private final ConcurrentMap<String, Consumer<DefaultMessage>> handlers = new ConcurrentHashMap<>();
88 private final AtomicLong messageIdGenerator = new AtomicLong(0); 90 private final AtomicLong messageIdGenerator = new AtomicLong(0);
89 - private final Cache<Long, CompletableFuture<byte[]>> responseFutures = CacheBuilder.newBuilder() 91 + private final Cache<Long, Callback> responseFutures = CacheBuilder.newBuilder()
90 - .removalListener(new RemovalListener<Long, CompletableFuture<byte[]>>() { 92 + .removalListener(new RemovalListener<Long, Callback>() {
91 @Override 93 @Override
92 - public void onRemoval(RemovalNotification<Long, CompletableFuture<byte[]>> entry) { 94 + public void onRemoval(RemovalNotification<Long, Callback> entry) {
93 if (entry.wasEvicted()) { 95 if (entry.wasEvicted()) {
94 entry.getValue().completeExceptionally(new TimeoutException("Timedout waiting for reply")); 96 entry.getValue().completeExceptionally(new TimeoutException("Timedout waiting for reply"));
95 } 97 }
...@@ -176,29 +178,37 @@ public class IOLoopMessaging implements MessagingService { ...@@ -176,29 +178,37 @@ public class IOLoopMessaging implements MessagingService {
176 public CompletableFuture<byte[]> sendAndReceive( 178 public CompletableFuture<byte[]> sendAndReceive(
177 Endpoint ep, 179 Endpoint ep,
178 String type, 180 String type,
179 - byte[] payload) { 181 + byte[] payload,
182 + Executor executor) {
180 CompletableFuture<byte[]> response = new CompletableFuture<>(); 183 CompletableFuture<byte[]> response = new CompletableFuture<>();
184 + Callback callback = new Callback(response, executor);
181 Long messageId = messageIdGenerator.incrementAndGet(); 185 Long messageId = messageIdGenerator.incrementAndGet();
182 - responseFutures.put(messageId, response); 186 + responseFutures.put(messageId, callback);
183 DefaultMessage message = new DefaultMessage(messageId, localEp, type, payload); 187 DefaultMessage message = new DefaultMessage(messageId, localEp, type, payload);
184 - try { 188 + return sendAsync(ep, message).whenComplete((r, e) -> {
185 - sendAsync(ep, message); 189 + if (e != null) {
186 - } catch (Exception e) {
187 responseFutures.invalidate(messageId); 190 responseFutures.invalidate(messageId);
188 - response.completeExceptionally(e);
189 } 191 }
190 - return response; 192 + }).thenCompose(v -> response);
193 + }
194 +
195 + @Override
196 + public CompletableFuture<byte[]> sendAndReceive(
197 + Endpoint ep,
198 + String type,
199 + byte[] payload) {
200 + return sendAndReceive(ep, type, payload, MoreExecutors.directExecutor());
191 } 201 }
192 202
193 @Override 203 @Override
194 - public void registerHandler(String type, Consumer<byte[]> handler, Executor executor) { 204 + public void registerHandler(String type, BiConsumer<Endpoint, byte[]> handler, Executor executor) {
195 - handlers.put(type, message -> executor.execute(() -> handler.accept(message.payload()))); 205 + handlers.put(type, message -> executor.execute(() -> handler.accept(message.sender(), message.payload())));
196 } 206 }
197 207
198 @Override 208 @Override
199 - public void registerHandler(String type, Function<byte[], byte[]> handler, Executor executor) { 209 + public void registerHandler(String type, BiFunction<Endpoint, byte[], byte[]> handler, Executor executor) {
200 handlers.put(type, message -> executor.execute(() -> { 210 handlers.put(type, message -> executor.execute(() -> {
201 - byte[] responsePayload = handler.apply(message.payload()); 211 + byte[] responsePayload = handler.apply(message.sender(), message.payload());
202 if (responsePayload != null) { 212 if (responsePayload != null) {
203 DefaultMessage response = new DefaultMessage(message.id(), 213 DefaultMessage response = new DefaultMessage(message.id(),
204 localEp, 214 localEp,
...@@ -212,9 +222,9 @@ public class IOLoopMessaging implements MessagingService { ...@@ -212,9 +222,9 @@ public class IOLoopMessaging implements MessagingService {
212 } 222 }
213 223
214 @Override 224 @Override
215 - public void registerHandler(String type, Function<byte[], CompletableFuture<byte[]>> handler) { 225 + public void registerHandler(String type, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler) {
216 handlers.put(type, message -> { 226 handlers.put(type, message -> {
217 - handler.apply(message.payload()).whenComplete((result, error) -> { 227 + handler.apply(message.sender(), message.payload()).whenComplete((result, error) -> {
218 if (error == null) { 228 if (error == null) {
219 DefaultMessage response = new DefaultMessage(message.id(), 229 DefaultMessage response = new DefaultMessage(message.id(),
220 localEp, 230 localEp,
...@@ -239,10 +249,10 @@ public class IOLoopMessaging implements MessagingService { ...@@ -239,10 +249,10 @@ public class IOLoopMessaging implements MessagingService {
239 String type = message.type(); 249 String type = message.type();
240 if (REPLY_MESSAGE_TYPE.equals(type)) { 250 if (REPLY_MESSAGE_TYPE.equals(type)) {
241 try { 251 try {
242 - CompletableFuture<byte[]> futureResponse = 252 + Callback callback =
243 responseFutures.getIfPresent(message.id()); 253 responseFutures.getIfPresent(message.id());
244 - if (futureResponse != null) { 254 + if (callback != null) {
245 - futureResponse.complete(message.payload()); 255 + callback.complete(message.payload());
246 } else { 256 } else {
247 log.warn("Received a reply for message id:[{}]. " 257 log.warn("Received a reply for message id:[{}]. "
248 + " from {}. But was unable to locate the" 258 + " from {}. But was unable to locate the"
...@@ -331,4 +341,23 @@ public class IOLoopMessaging implements MessagingService { ...@@ -331,4 +341,23 @@ public class IOLoopMessaging implements MessagingService {
331 return stream.isClosed(); 341 return stream.isClosed();
332 } 342 }
333 } 343 }
344 +
345 +
346 + private final class Callback {
347 + private final CompletableFuture<byte[]> future;
348 + private final Executor executor;
349 +
350 + public Callback(CompletableFuture<byte[]> future, Executor executor) {
351 + this.future = future;
352 + this.executor = executor;
353 + }
354 +
355 + public void complete(byte[] value) {
356 + executor.execute(() -> future.complete(value));
357 + }
358 +
359 + public void completeExceptionally(Throwable error) {
360 + executor.execute(() -> future.completeExceptionally(error));
361 + }
362 + }
334 } 363 }
......