SAMZA-1261: Fix TestProcessJob flaky test
[samza.git] / samza-core / src / main / scala / org / apache / samza / job / local / ProcessJob.scala
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19
20 package org.apache.samza.job.local
21
22 import java.util.concurrent.CountDownLatch
23
24 import org.apache.samza.coordinator.JobModelManager
25 import org.apache.samza.job.ApplicationStatus.{New, Running, SuccessfulFinish, UnsuccessfulFinish}
26 import org.apache.samza.job.{ApplicationStatus, CommandBuilder, StreamJob}
27 import org.apache.samza.util.Logging
28
29 import scala.collection.JavaConverters._
30
31 object ProcessJob {
32   private def createProcessBuilder(commandBuilder: CommandBuilder): ProcessBuilder = {
33     val processBuilder = new ProcessBuilder(commandBuilder.buildCommand.split(" ").toList.asJava)
34     processBuilder.environment.putAll(commandBuilder.buildEnvironment)
35
36     // Pipe all output to this process's streams.
37     processBuilder.redirectOutput(ProcessBuilder.Redirect.INHERIT)
38     processBuilder.redirectError(ProcessBuilder.Redirect.INHERIT)
39
40     processBuilder
41   }
42 }
43
44 class ProcessJob(commandBuilder: CommandBuilder, val jobModelManager: JobModelManager) extends StreamJob with Logging {
45
46   import ProcessJob._
47
48   val lock = new Object
49   val processBuilder: ProcessBuilder = createProcessBuilder(commandBuilder)
50   var jobStatus: ApplicationStatus = New
51   var processThread: Option[Thread] = None
52
53
54   def submit: StreamJob = {
55     val threadStartCountDownLatch = new CountDownLatch(1)
56
57     // Create a non-daemon thread to make job runner block until the job finishes.
58     // Without this, the proc dies when job runner ends.
59     processThread = Some(new Thread {
60       override def run {
61         var processExitCode = -1
62         var process: Option[Process] = None
63
64         setStatus(Running)
65
66         try {
67           threadStartCountDownLatch.countDown
68           process = Some(processBuilder.start)
69           processExitCode = process.get.waitFor
70         } catch {
71           case _: InterruptedException => process foreach { p => p.destroyForcibly }
72           case e: Exception => error("Encountered an error during job start: %s".format(e.getMessage))
73         } finally {
74           jobModelManager.stop
75           setStatus(if (processExitCode == 0) SuccessfulFinish else UnsuccessfulFinish)
76         }
77       }
78     })
79
80     info("Starting process job")
81
82     processThread.get.start
83     threadStartCountDownLatch.await
84     ProcessJob.this
85   }
86
87   def kill: StreamJob = {
88     getStatus match {
89       case Running => {
90         info("Attempting to kill running process job")
91
92         processThread foreach { thread =>
93           thread.interrupt
94           thread.join
95
96           info("Process job killed successfully")
97         }
98       }
99       case status => warn("Ignoring attempt to kill a process job that is not running. Job status is %s".format(status))
100     }
101
102     ProcessJob.this
103   }
104
105   def waitForFinish(timeoutMs: Long): ApplicationStatus = {
106     require(timeoutMs >= 0, "Timeout values must be non-negative")
107
108     processThread foreach { thread => thread.join(timeoutMs) }
109     getStatus
110   }
111
112   def waitForStatus(status: ApplicationStatus, timeoutMs: Long): ApplicationStatus = lock.synchronized {
113     require(timeoutMs >= 0, "Timeout values must be non-negative")
114
115     timeoutMs match {
116       case 0 => {
117         info("Waiting for application status %s indefinitely".format(status))
118
119         while (getStatus != status) lock.wait(0)
120       }
121       case _ => {
122         info("Waiting for application status %s for %d ms".format(status, timeoutMs))
123
124         val startTimeMs = System.currentTimeMillis
125         var remainingTimeoutMs = timeoutMs
126
127         while (getStatus != status && remainingTimeoutMs > 0) {
128           lock.wait(remainingTimeoutMs)
129
130           val elapsedWaitTimeMs = System.currentTimeMillis - startTimeMs
131           remainingTimeoutMs = timeoutMs - elapsedWaitTimeMs
132         }
133       }
134     }
135     getStatus
136   }
137
138   def getStatus: ApplicationStatus = lock.synchronized {
139     jobStatus
140   }
141
142   private def setStatus(status: ApplicationStatus): Unit = lock.synchronized {
143     info("Changing process job status from %s to %s".format(jobStatus, status))
144
145     jobStatus = status
146     lock.notify
147   }
148 }