SAMZA-1638: Recreate SystemProducer on KafkaCheckpointManager.writeCheckpoint failures.
authorShanthoosh Venkataraman <spvenkat@usc.edu>
Thu, 29 Nov 2018 19:53:39 +0000 (11:53 -0800)
committerBoris S <bshkolnik@linkedin.com>
Thu, 29 Nov 2018 19:53:39 +0000 (11:53 -0800)
Retry loop in the existing `KafkaCheckpointManager` implementation retries using the same `SystemProducer` instance on exception and does not recreate it.

When some irrecoverable exceptions occur within the `SystemProducer`, all the subsequent produce message invocations on the `SystemProducer` instance will fail. This had made the entire retry loop on `KafkaCheckpointManager` pointless.

This patch consists of the following changes:
1. This patch addresses the above problem by recreating the `SystemProducer` instance on failure and adds a unit test to verify the functionality.
2. Minor code cleanup in classes: `TestKafkaCheckpointManager` and `KafkaCheckpointManager`.

Author: Shanthoosh Venkataraman <spvenkat@usc.edu>
Author: Shanthoosh Venkataraman <svenkata@linkedin.com>

Reviewers: Dong Lin <lindong28@gmail.com>

Closes #792 from shanthoosh/kafka_checkpoint_manager_fix

samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManager.scala
samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala

index b090136..4479c2d 100644 (file)
@@ -21,7 +21,9 @@ package org.apache.samza.checkpoint.kafka
 
 import java.util.Collections
 import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
 
+import com.google.common.annotations.VisibleForTesting
 import com.google.common.base.Preconditions
 import org.apache.samza.checkpoint.{Checkpoint, CheckpointManager}
 import org.apache.samza.config.{Config, JobConfig}
@@ -54,27 +56,29 @@ class KafkaCheckpointManager(checkpointSpec: KafkaStreamSpec,
                              checkpointMsgSerde: Serde[Checkpoint] = new CheckpointSerde,
                              checkpointKeySerde: Serde[KafkaCheckpointLogKey] = new KafkaCheckpointLogKeySerde) extends CheckpointManager with Logging {
 
-  var MaxRetryDurationMs = TimeUnit.MINUTES.toMillis(15);
+  var MaxRetryDurationInMillis: Long = TimeUnit.MINUTES.toMillis(15)
 
   info(s"Creating KafkaCheckpointManager for checkpointTopic:$checkpointTopic, systemName:$checkpointSystem " +
     s"validateCheckpoints:$validateCheckpoint")
 
   val checkpointSystem: String = checkpointSpec.getSystemName
   val checkpointTopic: String = checkpointSpec.getPhysicalName
-  val checkpointSsp = new SystemStreamPartition(checkpointSystem, checkpointTopic, new Partition(0))
-  val expectedGrouperFactory = new JobConfig(config).getSystemStreamPartitionGrouperFactory
+  val checkpointSsp: SystemStreamPartition = new SystemStreamPartition(checkpointSystem, checkpointTopic, new Partition(0))
+  val expectedGrouperFactory: String = new JobConfig(config).getSystemStreamPartitionGrouperFactory
 
-  val systemProducer = systemFactory.getProducer(checkpointSystem, config, metricsRegistry)
   val systemConsumer = systemFactory.getConsumer(checkpointSystem, config, metricsRegistry)
   val systemAdmin = systemFactory.getAdmin(checkpointSystem, config)
 
-  var taskNames = Set[TaskName]()
-  var taskNamesToCheckpoints: Map[TaskName, Checkpoint] = null
+  var taskNames: Set[TaskName] = Set[TaskName]()
+  var taskNamesToCheckpoints: Map[TaskName, Checkpoint] = _
+
+  val producerRef: AtomicReference[SystemProducer] = new AtomicReference[SystemProducer](getSystemProducer())
+  val producerCreationLock: Object = new Object
 
   /**
     * Create checkpoint stream prior to start.
     */
-  override def createResources = {
+  override def createResources(): Unit = {
     Preconditions.checkNotNull(systemAdmin)
 
     systemAdmin.start()
@@ -92,18 +96,16 @@ class KafkaCheckpointManager(checkpointSpec: KafkaStreamSpec,
   /**
     * @inheritdoc
     */
-  override def start {
-    Preconditions.checkNotNull(systemProducer)
-    Preconditions.checkNotNull(systemConsumer)
-
+  override def start(): Unit = {
     // register and start a producer for the checkpoint topic
-    systemProducer.start
+    info("Starting the checkpoint SystemProducer")
+    producerRef.get().start()
 
     // register and start a consumer for the checkpoint topic
     val oldestOffset = getOldestOffset(checkpointSsp)
-    info(s"Starting checkpoint SystemConsumer from oldest offset $oldestOffset")
+    info(s"Starting the checkpoint SystemConsumer from oldest offset $oldestOffset")
     systemConsumer.register(checkpointSsp, oldestOffset)
-    systemConsumer.start
+    systemConsumer.start()
   }
 
   /**
@@ -111,7 +113,7 @@ class KafkaCheckpointManager(checkpointSpec: KafkaStreamSpec,
     */
   override def register(taskName: TaskName) {
     debug(s"Registering taskName: $taskName")
-    systemProducer.register(taskName.getTaskName)
+    producerRef.get().register(taskName.getTaskName)
     taskNames += taskName
   }
 
@@ -156,53 +158,77 @@ class KafkaCheckpointManager(checkpointSpec: KafkaStreamSpec,
     }
 
     val envelope = new OutgoingMessageEnvelope(checkpointSsp, keyBytes, msgBytes)
-    val retryBackoff: ExponentialSleepStrategy = new ExponentialSleepStrategy
 
-    val startTime = System.currentTimeMillis()
-    retryBackoff.run(
-      loop => {
-        systemProducer.send(taskName.getTaskName, envelope)
-        systemProducer.flush(taskName.getTaskName) // make sure it is written
+    // Used for exponential backoff retries on failure in sending messages through producer.
+    val startTimeInMillis: Long = System.currentTimeMillis()
+    var sleepTimeInMillis: Long = 1000
+    val maxSleepTimeInMillis: Long = 10000
+    var producerException: Exception = null
+    while ((System.currentTimeMillis() - startTimeInMillis) <= MaxRetryDurationInMillis) {
+      val currentProducer = producerRef.get()
+      try {
+        currentProducer.send(taskName.getTaskName, envelope)
+        currentProducer.flush(taskName.getTaskName) // make sure it is written
         debug(s"Wrote checkpoint: $checkpoint for task: $taskName")
-        loop.done
-      },
-
-      (exception, loop) => {
-        if ((System.currentTimeMillis() - startTime) >= MaxRetryDurationMs) {
-          error(s"Exhausted $MaxRetryDurationMs milliseconds when writing checkpoint: $checkpoint for task: $taskName.")
-          throw new SamzaException(s"Exception when writing checkpoint: $checkpoint for task: $taskName.", exception)
-        } else {
+        return
+      } catch {
+        case exception: Exception => {
+          producerException = exception
           warn(s"Retrying failed checkpoint write to key: $key, checkpoint: $checkpoint for task: $taskName", exception)
+          // TODO: Remove this producer recreation logic after SAMZA-1393.
+          val newProducer: SystemProducer = getSystemProducer()
+          producerCreationLock.synchronized {
+            if (producerRef.compareAndSet(currentProducer, newProducer)) {
+              info(s"Stopping the checkpoint SystemProducer")
+              currentProducer.stop()
+              info(s"Recreating the checkpoint SystemProducer")
+              // SystemProducer contract is that clients call register(taskName) followed by start
+              // before invoking writeCheckpoint, readCheckpoint API. Hence list of taskName are not
+              // expected to change during the producer recreation.
+              for (taskName <- taskNames) {
+                debug(s"Registering the taskName: $taskName with SystemProducer")
+                newProducer.register(taskName.getTaskName)
+              }
+              newProducer.start()
+            } else {
+              info("Producer instance was recreated by other thread. Retrying with it.")
+              newProducer.stop()
+            }
+          }
         }
       }
-    )
+      sleepTimeInMillis = Math.min(sleepTimeInMillis * 2, maxSleepTimeInMillis)
+      Thread.sleep(sleepTimeInMillis)
+    }
+    throw new SamzaException(s"Exception when writing checkpoint: $checkpoint for task: $taskName.", producerException)
   }
 
   /**
     * @inheritdoc
     */
-  override def clearCheckpoints: Unit = {
+  override def clearCheckpoints(): Unit = {
     info("Clear checkpoint stream %s in system %s" format(checkpointTopic, checkpointSystem))
     systemAdmin.clearStream(checkpointSpec)
   }
 
-  override def stop = {
+  override def stop(): Unit = {
+    info ("Stopping system admin.")
     systemAdmin.stop()
 
-    if (systemProducer != null) {
-      systemProducer.stop
-    } else {
-      error("Checkpoint SystemProducer should not be null")
-    }
+    info ("Stopping system producer.")
+    producerRef.get().stop()
+
+    info("Stopping system consumer.")
+    systemConsumer.stop()
 
-    if (systemConsumer != null) {
-      systemConsumer.stop
-    } else {
-      error("Checkpoint SystemConsumer should not be null")
-    }
     info("CheckpointManager stopped.")
   }
 
+  @VisibleForTesting
+  def getSystemProducer(): SystemProducer = {
+    systemFactory.getProducer(checkpointSystem, config, metricsRegistry)
+  }
+
   /**
     * Returns the checkpoints from the log.
     *
@@ -284,11 +310,11 @@ class KafkaCheckpointManager(checkpointSpec: KafkaStreamSpec,
       throw new SamzaException(s"Got null metadata for system:$checkpointSystem, topic:$topic")
     }
 
-    val partitionMetaData = checkpointMetadata.getSystemStreamPartitionMetadata().get(partition)
+    val partitionMetaData = checkpointMetadata.getSystemStreamPartitionMetadata.get(partition)
     if (partitionMetaData == null) {
       throw new SamzaException(s"Got a null partition metadata for system:$checkpointSystem, topic:$topic")
     }
 
-    return partitionMetaData.getOldestOffset
+    partitionMetaData.getOldestOffset
   }
 }
index 392670b..5abbea9 100644 (file)
@@ -64,13 +64,44 @@ class TestKafkaCheckpointManager extends KafkaServerTestHarness {
   }
 
   override def generateConfigs() = {
-    val props = TestUtils.createBrokerConfigs(numBrokers, zkConnect, true)
+    val props = TestUtils.createBrokerConfigs(numBrokers, zkConnect, enableControlledShutdown = true)
     // do not use relative imports
     props.map(_root_.kafka.server.KafkaConfig.fromProps)
   }
 
+  def testWriteCheckpointShouldRecreateSystemProducerOnFailure(): Unit = {
+    val checkpointTopic = "checkpoint-topic-2"
+    val mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])
+
+    class MockSystemFactory extends KafkaSystemFactory {
+      override def getProducer(systemName: String, config: Config, registry: MetricsRegistry): SystemProducer = {
+        mockKafkaProducer
+      }
+    }
+
+    Mockito.doThrow(new RuntimeException()).when(mockKafkaProducer).flush(taskName.getTaskName)
+
+    val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
+    val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
+    val checkPointManager = Mockito.spy(new KafkaCheckpointManager(spec, new MockSystemFactory, false, config, new NoOpMetricsRegistry))
+    val newKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])
+    checkPointManager.MaxRetryDurationInMillis = 1
+
+    Mockito.doReturn(newKafkaProducer).when(checkPointManager).getSystemProducer()
+
+    checkPointManager.register(taskName)
+    checkPointManager.start
+    checkPointManager.writeCheckpoint(taskName, new Checkpoint(ImmutableMap.of()))
+
+    // Verifications after the test
+
+    Mockito.verify(mockKafkaProducer).stop()
+    Mockito.verify(newKafkaProducer).register(taskName.getTaskName)
+    Mockito.verify(newKafkaProducer).start()
+  }
+
   @Test
-  def testCheckpointShouldBeNullIfCheckpointTopicDoesNotExistShouldBeCreatedOnWriteAndShouldBeReadableAfterWrite {
+  def testCheckpointShouldBeNullIfCheckpointTopicDoesNotExistShouldBeCreatedOnWriteAndShouldBeReadableAfterWrite(): Unit = {
     val checkpointTopic = "checkpoint-topic-1"
     val kcm1 = createKafkaCheckpointManager(checkpointTopic)
     kcm1.register(taskName)
@@ -101,7 +132,7 @@ class TestKafkaCheckpointManager extends KafkaServerTestHarness {
   }
 
   @Test(expected = classOf[SamzaException])
-  def testWriteCheckpointShouldRetryFiniteTimesOnFailure: Unit = {
+  def testWriteCheckpointShouldRetryFiniteTimesOnFailure(): Unit = {
     val checkpointTopic = "checkpoint-topic-2"
     val mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])
 
@@ -116,7 +147,7 @@ class TestKafkaCheckpointManager extends KafkaServerTestHarness {
     val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
     val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
     val checkPointManager = new KafkaCheckpointManager(spec, new MockSystemFactory, false, config, new NoOpMetricsRegistry)
-    checkPointManager.MaxRetryDurationMs = 1
+    checkPointManager.MaxRetryDurationInMillis = 1
 
     checkPointManager.register(taskName)
     checkPointManager.start
@@ -124,7 +155,7 @@ class TestKafkaCheckpointManager extends KafkaServerTestHarness {
   }
 
   @Test
-  def testFailOnTopicValidation {
+  def testFailOnTopicValidation(): Unit = {
     // By default, should fail if there is a topic validation error
     val checkpointTopic = "eight-partition-topic";
     val kcm1 = createKafkaCheckpointManager(checkpointTopic)
@@ -153,7 +184,7 @@ class TestKafkaCheckpointManager extends KafkaServerTestHarness {
   }
 
   @After
-  override def tearDown() {
+  override def tearDown(): Unit = {
     if (servers != null) {
       servers.foreach(_.shutdown())
       servers.foreach(server => CoreUtils.delete(server.config.logDirs))
@@ -161,17 +192,7 @@ class TestKafkaCheckpointManager extends KafkaServerTestHarness {
     super.tearDown
   }
 
-  private def getCheckpointProducerProperties() : Properties = {
-    val defaultSerializer = classOf[ByteArraySerializer].getCanonicalName
-    val props = new Properties()
-    props.putAll(ImmutableMap.of(
-      ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList,
-      ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, defaultSerializer,
-      ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, defaultSerializer))
-    props
-  }
-
-  private def getConfig() : Config = {
+  private def getConfig(): Config = {
     new MapConfig(new ImmutableMap.Builder[String, String]()
       .put(JobConfig.JOB_NAME, "some-job-name")
       .put(JobConfig.JOB_ID, "i001")
@@ -207,7 +228,7 @@ class TestKafkaCheckpointManager extends KafkaServerTestHarness {
     checkpoint
   }
 
-  private def writeCheckpoint(checkpointTopic: String, taskName: TaskName, checkpoint: Checkpoint) = {
+  private def writeCheckpoint(checkpointTopic: String, taskName: TaskName, checkpoint: Checkpoint): Unit = {
     val kcm = createKafkaCheckpointManager(checkpointTopic)
     kcm.register(taskName)
     kcm.start