SAMZA-1361; OperatorImplGraph is using wrong keys to store/retrieve OperatorImpl...
[samza.git] / samza-core / src / test / java / org / apache / samza / operators / TestJoinOperator.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19 package org.apache.samza.operators;
20
21 import com.google.common.collect.ImmutableSet;
22 import org.apache.samza.Partition;
23 import org.apache.samza.application.StreamApplication;
24 import org.apache.samza.config.Config;
25 import org.apache.samza.metrics.MetricsRegistryMap;
26 import org.apache.samza.operators.functions.JoinFunction;
27 import org.apache.samza.runtime.ApplicationRunner;
28 import org.apache.samza.system.IncomingMessageEnvelope;
29 import org.apache.samza.system.OutgoingMessageEnvelope;
30 import org.apache.samza.system.StreamSpec;
31 import org.apache.samza.system.SystemStream;
32 import org.apache.samza.system.SystemStreamPartition;
33 import org.apache.samza.task.MessageCollector;
34 import org.apache.samza.task.StreamOperatorTask;
35 import org.apache.samza.task.TaskContext;
36 import org.apache.samza.task.TaskCoordinator;
37 import org.apache.samza.testUtils.TestClock;
38 import org.apache.samza.util.Clock;
39 import org.apache.samza.util.SystemClock;
40 import org.junit.Test;
41
42 import java.time.Duration;
43 import java.util.ArrayList;
44 import java.util.List;
45 import java.util.Set;
46
47 import static org.junit.Assert.assertEquals;
48 import static org.junit.Assert.assertTrue;
49 import static org.mockito.Mockito.mock;
50 import static org.mockito.Mockito.when;
51
52 public class TestJoinOperator {
53 private static final Duration JOIN_TTL = Duration.ofMinutes(10);
54
55 private final TaskCoordinator taskCoordinator = mock(TaskCoordinator.class);
56 private final Set<Integer> numbers = ImmutableSet.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
57
58 @Test
59 public void join() throws Exception {
60 StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
61 List<Integer> output = new ArrayList<>();
62 MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
63
64 // push messages to first stream
65 numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
66 // push messages to second stream with same keys
67 numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator));
68
69 int outputSum = output.stream().reduce(0, (s, m) -> s + m);
70 assertEquals(110, outputSum);
71 }
72
73 @Test
74 public void joinFnInitAndClose() throws Exception {
75 TestJoinFunction joinFn = new TestJoinFunction();
76 StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(joinFn));
77 assertEquals(1, joinFn.getNumInitCalls());
78 MessageCollector messageCollector = mock(MessageCollector.class);
79
80 // push messages to first stream
81 numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
82
83 // close should not be called till now
84 assertEquals(0, joinFn.getNumCloseCalls());
85 sot.close();
86
87 // close should be called from sot.close()
88 assertEquals(1, joinFn.getNumCloseCalls());
89 }
90
91 @Test
92 public void joinReverse() throws Exception {
93 StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
94 List<Integer> output = new ArrayList<>();
95 MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
96
97 // push messages to second stream
98 numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator));
99 // push messages to first stream with same keys
100 numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
101
102 int outputSum = output.stream().reduce(0, (s, m) -> s + m);
103 assertEquals(110, outputSum);
104 }
105
106 @Test
107 public void joinNoMatch() throws Exception {
108 StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
109 List<Integer> output = new ArrayList<>();
110 MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
111
112 // push messages to first stream
113 numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
114 // push messages to second stream with different keys
115 numbers.forEach(n -> sot.process(new SecondStreamIME(n + 100, n), messageCollector, taskCoordinator));
116
117 assertTrue(output.isEmpty());
118 }
119
120 @Test
121 public void joinNoMatchReverse() throws Exception {
122 StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
123 List<Integer> output = new ArrayList<>();
124 MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
125
126 // push messages to second stream
127 numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator));
128 // push messages to first stream with different keys
129 numbers.forEach(n -> sot.process(new FirstStreamIME(n + 100, n), messageCollector, taskCoordinator));
130
131 assertTrue(output.isEmpty());
132 }
133
134 @Test
135 public void joinRetainsLatestMessageForKey() throws Exception {
136 StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
137 List<Integer> output = new ArrayList<>();
138 MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
139
140 // push messages to first stream
141 numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
142 // push messages to first stream again with same keys but different values
143 numbers.forEach(n -> sot.process(new FirstStreamIME(n, 2 * n), messageCollector, taskCoordinator));
144 // push messages to second stream with same key
145 numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator));
146
147 int outputSum = output.stream().reduce(0, (s, m) -> s + m);
148 assertEquals(165, outputSum); // should use latest messages in the first stream
149 }
150
151 @Test
152 public void joinRetainsLatestMessageForKeyReverse() throws Exception {
153 StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
154 List<Integer> output = new ArrayList<>();
155 MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
156
157 // push messages to second stream
158 numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator));
159 // push messages to second stream again with same keys but different values
160 numbers.forEach(n -> sot.process(new SecondStreamIME(n, 2 * n), messageCollector, taskCoordinator));
161 // push messages to first stream with same key
162 numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
163
164 int outputSum = output.stream().reduce(0, (s, m) -> s + m);
165 assertEquals(165, outputSum); // should use latest messages in the second stream
166 }
167
168 @Test
169 public void joinRetainsMatchedMessages() throws Exception {
170 StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
171 List<Integer> output = new ArrayList<>();
172 MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
173
174 // push messages to first stream
175 numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
176 // push messages to second stream with same key
177 numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator));
178
179 int outputSum = output.stream().reduce(0, (s, m) -> s + m);
180 assertEquals(110, outputSum);
181
182 output.clear();
183
184 // push messages to first stream with same keys once again.
185 numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
186 int newOutputSum = output.stream().reduce(0, (s, m) -> s + m);
187 assertEquals(110, newOutputSum); // should produce the same output as before
188 }
189
190 @Test
191 public void joinRetainsMatchedMessagesReverse() throws Exception {
192 StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
193 List<Integer> output = new ArrayList<>();
194 MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
195
196 // push messages to first stream
197 numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
198 // push messages to second stream with same key
199 numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator));
200
201 int outputSum = output.stream().reduce(0, (s, m) -> s + m);
202 assertEquals(110, outputSum);
203
204 output.clear();
205
206 // push messages to second stream with same keys once again.
207 numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator));
208 int newOutputSum = output.stream().reduce(0, (s, m) -> s + m);
209 assertEquals(110, newOutputSum); // should produce the same output as before
210 }
211
212 @Test
213 public void joinRemovesExpiredMessages() throws Exception {
214 TestClock testClock = new TestClock();
215 StreamOperatorTask sot = createStreamOperatorTask(testClock, new TestJoinStreamApplication(new TestJoinFunction()));
216 List<Integer> output = new ArrayList<>();
217 MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
218
219 // push messages to first stream
220 numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
221
222 testClock.advanceTime(JOIN_TTL.plus(Duration.ofMinutes(1))); // 1 minute after ttl
223 sot.window(messageCollector, taskCoordinator); // should expire first stream messages
224
225 // push messages to second stream with same key
226 numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator));
227
228 assertTrue(output.isEmpty());
229 }
230
231
232 @Test
233 public void joinRemovesExpiredMessagesReverse() throws Exception {
234 TestClock testClock = new TestClock();
235 StreamOperatorTask sot = createStreamOperatorTask(testClock, new TestJoinStreamApplication(new TestJoinFunction()));
236 List<Integer> output = new ArrayList<>();
237 MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
238
239 // push messages to second stream
240 numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator));
241
242 testClock.advanceTime(JOIN_TTL.plus(Duration.ofMinutes(1))); // 1 minute after ttl
243 sot.window(messageCollector, taskCoordinator); // should expire second stream messages
244
245 // push messages to first stream with same key
246 numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
247
248 assertTrue(output.isEmpty());
249 }
250
251 private StreamOperatorTask createStreamOperatorTask(Clock clock, StreamApplication app) throws Exception {
252 ApplicationRunner runner = mock(ApplicationRunner.class);
253 when(runner.getStreamSpec("instream")).thenReturn(new StreamSpec("instream", "instream", "insystem"));
254 when(runner.getStreamSpec("instream2")).thenReturn(new StreamSpec("instream2", "instream2", "insystem2"));
255
256 TaskContext taskContext = mock(TaskContext.class);
257 when(taskContext.getSystemStreamPartitions()).thenReturn(ImmutableSet
258 .of(new SystemStreamPartition("insystem", "instream", new Partition(0)),
259 new SystemStreamPartition("insystem2", "instream2", new Partition(0))));
260 when(taskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
261
262 Config config = mock(Config.class);
263
264 StreamOperatorTask sot = new StreamOperatorTask(app, runner, clock);
265 sot.init(config, taskContext);
266 return sot;
267 }
268
269 private static class TestJoinStreamApplication implements StreamApplication {
270
271 private final TestJoinFunction joinFn;
272
273 TestJoinStreamApplication(TestJoinFunction joinFn) {
274 this.joinFn = joinFn;
275 }
276
277 @Override
278 public void init(StreamGraph graph, Config config) {
279 MessageStream<FirstStreamIME> inStream =
280 graph.getInputStream("instream", FirstStreamIME::new);
281 MessageStream<SecondStreamIME> inStream2 =
282 graph.getInputStream("instream2", SecondStreamIME::new);
283
284 SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
285 inStream
286 .join(inStream2, joinFn, JOIN_TTL)
287 .sink((message, messageCollector, taskCoordinator) -> {
288 messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message));
289 });
290 }
291 }
292
293 private static class TestJoinFunction implements JoinFunction<Integer, FirstStreamIME, SecondStreamIME, Integer> {
294
295 private int numInitCalls = 0;
296 private int numCloseCalls = 0;
297
298 @Override
299 public void init(Config config, TaskContext context) {
300 numInitCalls++;
301 }
302
303 @Override
304 public Integer apply(FirstStreamIME message, SecondStreamIME otherMessage) {
305 return (Integer) message.getMessage() + (Integer) otherMessage.getMessage();
306 }
307
308 @Override
309 public Integer getFirstKey(FirstStreamIME message) {
310 return (Integer) message.getKey();
311 }
312
313 @Override
314 public Integer getSecondKey(SecondStreamIME message) {
315 return (Integer) message.getKey();
316 }
317
318 @Override
319 public void close() {
320 numCloseCalls++;
321 }
322
323 public int getNumInitCalls() {
324 return numInitCalls;
325 }
326
327 public int getNumCloseCalls() {
328 return numCloseCalls;
329 }
330 }
331
332 private static class FirstStreamIME extends IncomingMessageEnvelope {
333 FirstStreamIME(Integer key, Integer message) {
334 super(new SystemStreamPartition("insystem", "instream", new Partition(0)), "1", key, message);
335 }
336 }
337
338 private static class SecondStreamIME extends IncomingMessageEnvelope {
339 SecondStreamIME(Integer key, Integer message) {
340 super(new SystemStreamPartition("insystem2", "instream2", new Partition(0)), "1", key, message);
341 }
342 }
343 }