summaryrefslogtreecommitdiffstats
path: root/Process/Supervisor.hs
blob: bc2a5a7d34fea12b7f2f740f0195e3604918b361 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
{-|
Module      : Process.Supervisor
Description : Hardened supervisor with single-coordinator state machine to avoid races.
-}
module Process.Supervisor
  ( Process
  , ProcessOptions(..)
  , ProcessError(..)
  , SupervisionResult(..)
  , Termination(..)
  , GracePeriod(..)
  , defaultOptions
  , withProcess
  , waitProcess
  , pollResult
  , getAdvisoryPid
  , sendSignal
  , requestStdinClose
  ) where

import Control.Concurrent (MVar, newEmptyMVar, tryPutMVar, tryReadMVar, readMVar)
import Control.Concurrent.Async (async, asyncWithUnmask, cancel, race, waitCatch)
import Control.Exception
  ( bracket, catch, try, SomeException, throwIO, Exception
  , fromException, mask, asyncExceptionFromException, SomeAsyncException
  )
import Control.Monad (void, when)
import Data.ByteString qualified as BS
import Data.Int (Int64)
import Data.Maybe (isJust)
import Foreign.C.Error (Errno(Errno), ePIPE)
import GHC.Generics (Generic)
import GHC.IO.Exception (IOErrorType(..), IOException(..))
import System.Exit (ExitCode(..))
import System.IO (Handle, hClose, hFlush, hIsEOF, hPutStrLn, hWaitForInput, stderr)
import System.IO.Error (ioeGetErrorType, isEOFError)
import System.Posix.Signals (Signal, signalProcess, sigKILL)
import System.Posix.Types (CPid)
import System.Process qualified as P
import System.Timeout (timeout)

-------------------------------------------------------------------------------
-- Types
-------------------------------------------------------------------------------

-- | How supervision of the child ended, as classified by the coordinator.
--
--   * 'Clean': an exit code was observed and no TERM/KILL request had been
--     issued successfully.
--   * 'Escalated': a TERM and/or KILL request had been issued successfully and
--     an exit code was then observed, or KILL was issued even though no exit
--     code was observed afterwards.
--   * 'Indeterminate': any other outcome, e.g. no exit code was observed and
--     no KILL had been issued.
data Termination = Clean | Escalated | Indeterminate
  deriving (Show, Eq, Generic)

-- | Final outcome of supervising one child. The coordinator publishes it at
-- most once (via 'tryPutMVar'); callers read it with 'waitProcess' or
-- 'pollResult'.
data SupervisionResult = SupervisionResult
  { srExitCode    :: !(Maybe ExitCode)
  -- ^ exit code if one was observed; 'Nothing' if none was seen
  , srTermination :: !Termination
  -- ^ whether TERM/KILL escalation was involved (see 'Termination')
  -- optional exception observed by the coordinator; may be used for diagnostics
  , srException   :: !(Maybe SomeException)
  } deriving (Show, Generic)

-- | Field-wise equality. 'SomeException' has no 'Eq' instance, so the
-- optional exceptions are compared by their 'show' rendering instead.
instance Eq SupervisionResult where
  SupervisionResult code1 term1 exc1 == SupervisionResult code2 term2 exc2 =
    code1 == code2
      && term1 == term2
      && (show <$> exc1) == (show <$> exc2)

-- | Grace period in microseconds (unwrap with 'getGracePeriodMicros').
-- 'validateOptions' rejects negative values; the coordinator converts the
-- value to 'Int' with 'fromIntegral' before passing it to 'timeout'.
newtype GracePeriod = GracePeriod { getGracePeriodMicros :: Int64 }
  deriving (Show, Eq, Ord)

-- | Errors raised by this module.
data ProcessError
  = PipeCreationFailed String
    -- ^ 'P.createProcess' did not return all three requested pipes.
  | WorkerDied String String SomeException
    -- ^ A worker failed: worker name ("stdin", "stdout", "stderr"), the stage
    --   that failed ("input", "output", "read", "write"), and the cause.
  | ProcessStartFailed SomeException
    -- ^ 'P.createProcess' threw a synchronous exception.
  | InvalidOptions String
    -- ^ 'ProcessOptions' rejected by 'validateOptions'; the message names the field.
  deriving (Show, Generic)

-- | Uses the default 'Exception' methods; thrown with 'throwIO'.
instance Exception ProcessError

-- | Handle passed to the user action of 'withProcess'.
data Process = Process
  { phExitVar :: !(MVar SupervisionResult)  -- authoritative final result (written only by coordinator)
  , phStopIn  :: !(MVar ())                 -- request stdin close
  , phHandle  :: !P.ProcessHandle           -- underlying handle; used by sendSignal and getAdvisoryPid
  }

-- | Configuration for 'withProcess'. Build one with 'defaultOptions' and
-- override fields as needed; 'validateOptions' checks it before spawning.
data ProcessOptions = ProcessOptions
  { optFlushStdin   :: !Bool                        -- hFlush the child's stdin after every chunk written
  , optGracePeriod  :: !GracePeriod                 -- timeout used by the coordinator's wait phases (must be >= 0)
  , optBufferSize   :: !Int                         -- max bytes per stdout/stderr read (validated: 1 .. 16MB)
  , optCmd          :: !FilePath                    -- executable passed to 'P.proc' (must be non-empty)
  , optArgs         :: ![String]                    -- arguments passed to 'P.proc'
  , optEnv          :: !(Maybe [(String, String)])  -- passed to 'P.env' (per System.Process, 'Nothing' inherits)
  , optCwd          :: !(Maybe FilePath)            -- passed to 'P.cwd'
  }

-- | Baseline options for running @cmd@ with @args@.
--
-- Stdin is flushed after every chunk, the grace period is 5 seconds, output is
-- read in chunks of at most 4096 bytes, and the environment / working
-- directory are left unset ('Nothing').
defaultOptions :: FilePath -> [String] -> ProcessOptions
defaultOptions cmd args =
  ProcessOptions
    { optCmd         = cmd
    , optArgs        = args
    , optEnv         = Nothing
    , optCwd         = Nothing
    , optFlushStdin  = True
    , optBufferSize  = 4096
    , optGracePeriod = GracePeriod 5_000_000
    }

-------------------------------------------------------------------------------
-- withProcess: start process, workers, coordinator
-------------------------------------------------------------------------------

-- | Spawn the configured child with piped stdin/stdout/stderr, supervise it,
-- and run the user action with a 'Process' handle.
--
-- While the action runs:
--
--   * a stdin worker feeds chunks from @getIn@ until it returns 'Nothing' or
--     'requestStdinClose' is called;
--   * two stream workers hand stdout/stderr chunks to @putOut@ / @putErr@;
--   * the coordinator waits (without any deadline) for the child to exit on
--     its own and publishes the result for 'waitProcess' / 'pollResult'.
--
-- When the action returns or throws, release runs:
--
--   1. the stdin worker is stopped and the child's stdin is closed;
--   2. the coordinator is told to shut down. Only now does the grace period
--      start: wait for a natural exit, then TERM, wait again, then KILL.
--      Release waits until the coordinator has published the final result;
--   3. the stream workers get up to one more grace period to drain to EOF, so
--      output written during shutdown is delivered. They are then cancelled
--      and the remaining handles are closed.
--
-- Note: 'waitProcess' called inside the action blocks until the child exits
-- on its own; no timeout applies before the action finishes.
--
-- Throws 'InvalidOptions' (from 'validateOptions'), 'ProcessStartFailed' or
-- 'PipeCreationFailed'. Exceptions from the action are rethrown after release.
withProcess
  :: ProcessOptions
  -> IO (Maybe BS.ByteString)         -- get input to feed to child's stdin
  -> (BS.ByteString -> IO ())         -- stdout consumer
  -> (BS.ByteString -> IO ())         -- stderr consumer
  -> (Process -> IO a)                -- user action
  -> IO a
withProcess opts getIn putOut putErr action = do
  validateOptions opts
  let cp = (P.proc (optCmd opts) (optArgs opts))
             { P.env = optEnv opts, P.cwd = optCwd opts
             , P.std_in = P.CreatePipe, P.std_out = P.CreatePipe, P.std_err = P.CreatePipe
             }
      GracePeriod gpMicros = optGracePeriod opts

  -- coordination vars
  finalVar    <- newEmptyMVar        -- authoritative final result (coordinator writes exactly once)
  stopIn      <- newEmptyMVar        -- stdin-close request (requestStdinClose or release)
  stdinClosed <- newEmptyMVar        -- worker sets this when it closes stdin
  shutdownVar <- newEmptyMVar        -- release sets this; starts the coordinator's grace period

  -- protect createProcess + immediate registration of the cleanup
  mask \restore -> do
    procs <- P.createProcess cp `catchSync` (throwIO . ProcessStartFailed)
    case procs of
      (Just hin, Just hout, Just herr, ph) -> do
        -- Threads forked with plain 'async' here would inherit this mask: user
        -- callbacks would run masked and 'cancel' could only reach a worker at
        -- interruptible points. Fork them unmasked instead.
        let spawn act = asyncWithUnmask \unmask -> unmask act

        coordA  <- spawn (coordinator opts ph shutdownVar finalVar)
        stdinA  <- spawn (stdinWorker (optFlushStdin opts) hin getIn stopIn stdinClosed)
        stdoutA <- spawn (streamWorker "stdout" hout (optBufferSize opts) putOut)
        stderrA <- spawn (streamWorker "stderr" herr (optBufferSize opts) putErr)

        let closeLogged label hdl = do
              r <- trySync (hClose hdl)
              case r of
                Left e  -> hPutStrLn stderr $ "Process.release: hClose " ++ label ++ " failed: " ++ show e
                Right _ -> pure ()

            -- Runs under bracket's mask once the action returns or throws.
            release = do
              -- 1) Stop feeding input; cancelling interrupts a blocked getIn/write.
              void $ tryPutMVar stopIn ()
              cancel stdinA
              void (waitCatch stdinA)
              -- close stdin (EOF for the child) unless the worker already did
              tryReadMVar stdinClosed >>= \case
                Just _  -> pure ()
                Nothing -> closeLogged "hin" hin

              -- 2) Start the grace period now and wait for the coordinator to
              --    observe the exit (escalating if needed) and publish finalVar.
              --    The stream workers keep reading meanwhile, so the child never
              --    blocks on a full stdout/stderr pipe while shutting down.
              void $ tryPutMVar shutdownVar ()
              void (waitCatch coordA)

              -- 3) Let the stream workers drain to EOF, bounded in case another
              --    process inherited the pipes, then stop them.
              void $ timeout (fromIntegral gpMicros) (waitCatch stdoutA >> waitCatch stderrA)
              cancel stdoutA
              cancel stderrA
              void (waitCatch stdoutA)
              void (waitCatch stderrA)

              -- close remaining handles (best-effort)
              closeLogged "hout" hout
              closeLogged "herr" herr

        -- run user action under bracket so release always runs; the coordinator
        -- is never cancelled, so the child is always shut down and reaped
        bracket (pure ()) (const release) \_ ->
          restore (action (Process finalVar stopIn ph))

      (mhin, mhout, mherr, ph) -> do
        -- Not expected with CreatePipe, but never leak the child or the pipes
        -- that were created.
        void $ trySync (P.cleanupProcess (mhin, mhout, mherr, ph))
        throwIO $ PipeCreationFailed "createProcess returned missing pipe(s)"

-------------------------------------------------------------------------------
-- Coordinator state machine
--
-- Responsibilities:
--  - Observe process exit directly via waitForProcess
--  - Leave the child undisturbed until it exits or shutdownVar is set
--  - Only after shutdown is requested: enforce the grace period, then perform
--    TERM -> wait -> KILL escalation using ProcessHandle APIs
--  - Atomically publish the authoritative final result into finalVar, also
--    when the coordinator itself fails (Indeterminate + srException)
-------------------------------------------------------------------------------

coordinator
  :: ProcessOptions
  -> P.ProcessHandle
  -> MVar ()                    -- shutdownVar: set by release once the user action is done
  -> MVar SupervisionResult     -- finalVar (coordinator writes authoritative final result here)
  -> IO ()
coordinator opts ph shutdownVar finalVar =
  mask \restore -> do
    let GracePeriod gpMicros = optGracePeriod opts
        grace = fromIntegral gpMicros :: Int

        -- tryPutMVar never blocks: the first publication wins, later ones are no-ops
        publishFinal mExit term mExc =
          void $ tryPutMVar finalVar (SupervisionResult mExit term mExc)

        -- classify termination based on whether TERM/KILL were sent and whether an exit code is known
        classifyTerm didSendTerm didKill mExit =
          case (didSendTerm || didKill, didKill, mExit) of
            (False, False, Just _) -> Clean
            (True,  _,     Just _) -> Escalated
            (_,     True, Nothing) -> Escalated
            _                      -> Indeterminate

        -- bounded wait for exit; only the blocking wait is interruptible
        waitExitFor micros = restore (timeout micros (P.waitForProcess ph))

        -- 0) Run phase: no deadline while the user action is still running.
        supervise = do
          phase0 <- restore (race (P.waitForProcess ph) (readMVar shutdownVar))
          case phase0 of
            Left ec  -> publishFinal (Just ec) (classifyTerm False False (Just ec)) Nothing
            Right () -> shutDown

        shutDown = do
          -- 1) Shutdown requested (stdin already closed): wait for natural exit within grace period
          mExit1 <- waitExitFor grace
          case mExit1 of
            Just ec -> publishFinal (Just ec) (classifyTerm False False (Just ec)) Nothing
            Nothing -> do
              -- 2) No exit observed: attempt graceful termination
              termResult <- restore (trySync (P.terminateProcess ph))
              let didSendTerm = either (const False) (const True) termResult

              -- 3) Wait again for exit within grace period
              mExit2 <- waitExitFor grace
              case mExit2 of
                Just ec2 -> publishFinal (Just ec2) (classifyTerm didSendTerm False (Just ec2)) Nothing
                Nothing -> do
                  -- 4) Still no exit: attempt SIGKILL if safe
                  didKill <- restore (safeKillUsingHandle ph)
                  -- 5) Wait short time for exit to appear
                  mExit3 <- waitExitFor 1_000_000
                  case mExit3 of
                    Just ec3 -> publishFinal (Just ec3) (classifyTerm didSendTerm didKill (Just ec3)) Nothing
                    Nothing -> do
                      -- 6) No exit ever observed: fall back to getProcessExitCode
                      mExit <- restore (P.getProcessExitCode ph)
                      publishFinal mExit (classifyTerm didSendTerm didKill mExit) Nothing

    supervise `catch` \(e :: SomeException) -> do
      -- Never leave waitProcess callers blocked on an empty finalVar.
      publishFinal Nothing Indeterminate (Just e)
      throwIO e

-------------------------------------------------------------------------------
-- Reaper: removed in favor of a single deterministic coordinator that owns
-- waitForProcess and classification. This avoids timing races between threads.
-------------------------------------------------------------------------------

-------------------------------------------------------------------------------
-- Workers
-------------------------------------------------------------------------------

-- stdinWorker:
--   - Repeatedly obtains input chunks from getIn and writes them to the child's
--     stdin, flushing after every chunk when flushStdin is True.
--   - Honors requestStdinClose cooperatively. stopVar is checked before each call
--     to getIn and again before writing the chunk getIn returned. A chunk obtained
--     while a stop was pending is discarded, and stdin is closed instead.
--   - getIn itself is never interrupted by a stop request (only by cancellation).
--   - When getIn returns Nothing (or fails with a benign read error), stdin is
--     closed to signal EOF.
--   - A benign write error (EPIPE: the child closed its end of the pipe) also
--     closes stdin and ends the worker, since further input is undeliverable.
--   - Every voluntary exit closes stdin and records that in stdinClosed. Other
--     synchronous failures are rethrown as WorkerDied, leaving stdin for the
--     caller to close. Asynchronous exceptions (e.g. cancel) end the worker
--     quietly.
stdinWorker :: Bool -> Handle -> IO (Maybe BS.ByteString) -> MVar () -> MVar () -> IO ()
stdinWorker flushStdin h getIn stopVar stdinClosed = loop `catch` handleAsync where
  stopRequested :: IO Bool
  stopRequested = isJust <$> tryReadMVar stopVar

  loop = do
    stop <- stopRequested
    if stop
      then closeStdin  -- stop requested before reading input
      else do
        mbs <- getIn `catchSync` \e ->
          if isBenignReadError e then pure Nothing else throwIO (WorkerDied "stdin" "input" e)
        case mbs of
          Nothing -> closeStdin  -- EOF from producer
          Just bs -> do
            -- a stop may have arrived while getIn was running: honour it before writing
            stopNow <- stopRequested
            if stopNow then closeStdin else writeChunk bs

  writeChunk bs = do
    r <- trySync (BS.hPut h bs >> when flushStdin (hFlush h))
    case r of
      Right () -> loop
      Left e
        | isBenignWriteError e -> closeStdin  -- child stopped reading; stop pulling input
        | otherwise            -> throwIO (WorkerDied "stdin" "output" e)

  -- Close stdin and record that it was closed
  closeStdin = do
    void $ trySync (hClose h)
    void $ tryPutMVar stdinClosed ()

  handleAsync :: SomeException -> IO ()
  handleAsync e = case fromException e of
    Just (_ :: SomeAsyncException) -> pure ()
    Nothing -> throwIO e

-- streamWorker: copies the child's output handle to putAct in chunks of at most
-- bufSize bytes until the end of the stream.
--   - BS.hGetSome blocks until data is available and returns an empty chunk
--     only at EOF.
--   - A benign read error (see isBenignReadError: EOF / ResourceVanished) is
--     treated as end of stream. The worker stops instead of touching the broken
--     handle again, where hIsEOF/hWaitForInput would re-raise the same error.
--   - An empty chunk is confirmed with hIsEOF. If the handle is somehow not at
--     EOF, wait up to 10ms for input before retrying, so the loop never spins.
--   - Other read failures are rethrown as WorkerDied name "read", and failures
--     of putAct as WorkerDied name "write". Asynchronous exceptions (e.g.
--     cancel) end the worker quietly.
streamWorker :: String -> Handle -> Int -> (BS.ByteString -> IO ()) -> IO ()
streamWorker name h bufSize putAct = loop `catch` handleAsync where
  loop = do
    mChunk <- (Just <$> BS.hGetSome h bufSize) `catchSync` \e ->
      if isBenignReadError e then pure Nothing else throwIO (WorkerDied name "read" e)
    case mChunk of
      Nothing -> pure ()  -- stream gone (benign read error): nothing more to deliver
      Just chunk
        | BS.null chunk -> do
            eof <- hIsEOF h
            if eof
              then pure ()
              else do
                -- hWaitForInput returns immediately if input is available
                _ <- hWaitForInput h 10  -- wait up to 10ms
                loop
        | otherwise -> do
            putAct chunk `catchSync` \e -> throwIO (WorkerDied name "write" e)
            loop

  handleAsync :: SomeException -> IO ()
  handleAsync e = case fromException e of
    Just (_ :: SomeAsyncException) -> pure ()
    Nothing -> throwIO e

-------------------------------------------------------------------------------
-- API helpers
-------------------------------------------------------------------------------

-- | Best-effort delivery of @sig@ to the child.
--
-- Nothing is sent if the handle already reports an exit code or exposes no
-- PID. Synchronous failures of 'signalProcess' are discarded. The PID comes
-- from 'getAdvisoryPid', so the PID-recycling caveats documented there apply.
sendSignal :: Process -> Signal -> IO ()
sendSignal p sig = do
  mExit <- P.getProcessExitCode (phHandle p)
  case mExit of
    Just _  -> pure ()
    Nothing -> getAdvisoryPid p >>= maybe (pure ()) deliver
  where
    deliver pid = void (trySync (signalProcess sig pid))

-- | Block until the coordinator has published the final 'SupervisionResult'.
-- Uses 'readMVar', so the value stays in place for other readers.
waitProcess :: Process -> IO SupervisionResult
waitProcess = readMVar . phExitVar

-- | Non-blocking variant of 'waitProcess': 'Nothing' until the coordinator
-- has published the final result.
pollResult :: Process -> IO (Maybe SupervisionResult)
pollResult = tryReadMVar . phExitVar

-- | The OS PID reported by the underlying 'P.ProcessHandle', if any.
--
-- Treat the result as advisory only:
--
--   * It is 'Nothing' when the backend does not expose a PID.
--   * The process it names may already have exited.
--   * Because PIDs are recycled, it may even name an unrelated process once
--     the original child has exited.
--
-- Never use it as a stable or authoritative identity. It is suitable only
-- for best-effort diagnostics, or for signalling where PID recycling is an
-- acceptable, documented risk.
getAdvisoryPid :: Process -> IO (Maybe CPid)
getAdvisoryPid p = P.getPid (phHandle p)

-- | Ask the stdin worker to stop feeding input and close the child's stdin.
-- Never blocks; repeated calls are harmless because only the first one fills
-- the flag.
requestStdinClose :: Process -> IO ()
requestStdinClose p = do
  _firstRequest <- tryPutMVar (phStopIn p) ()
  pure ()

-------------------------------------------------------------------------------
-- safeKill: re-check handle liveness and use ProcessHandle's PID if available.
-- This reduces but cannot eliminate PID-recycling risk; documented limitation remains.
-------------------------------------------------------------------------------

-- | Best-effort @SIGKILL@ for the child behind a 'P.ProcessHandle'.
--
-- Returns 'True' only if 'signalProcess' was called and did not throw. It
-- returns 'False' in these cases:
--
--   * the handle reports an exit code, checked both before and after fetching
--     the PID;
--   * no PID is available;
--   * signalling failed synchronously.
--
-- Limitation: the signal is addressed by PID. If the child exits and the OS
-- reuses its PID between the last liveness check and 'signalProcess', an
-- unrelated process could be killed. This race is inherent to PID-based
-- signalling and cannot be fully eliminated without kernel support (pidfds).
safeKillUsingHandle :: P.ProcessHandle -> IO Bool
safeKillUsingHandle ph = do
  exited <- hasExited
  if exited
    then pure False
    else P.getPid ph >>= maybe (pure False) killIfAlive
  where
    hasExited = isJust <$> P.getProcessExitCode ph

    killIfAlive pid = do
      -- Re-check liveness right before signalling to narrow the PID-reuse window.
      exitedMeanwhile <- hasExited
      if exitedMeanwhile
        then pure False
        else either (const False) (const True) <$> trySync (signalProcess sigKILL pid)

-- TODO: once P.getPidFd and pidfdSendSignal :: Fd -> Signal -> IO () exist, switch to this PID-race-free version:
--safeKillUsingHandle ph = do
--  mExit <- P.getProcessExitCode ph
--  case mExit of
--    Just _  -> pure False
--    Nothing ->
--      case P.getPidFd ph of
--        Nothing     -> pure False
--        Just pidfd  -> do
--          r <- trySync (pidfdSendSignal pidfd sigKILL)
--          pure (either (const False) (const True) r)

-------------------------------------------------------------------------------
-- Error helpers and benign classification
-------------------------------------------------------------------------------

-- | 'True' iff the exception was raised asynchronously, i.e. it is wrapped in
-- 'SomeAsyncException' (as thrown by 'throwTo', 'killThread', async's
-- 'cancel', or 'timeout').
--
-- Note: @asyncExceptionFromException e :: Maybe SomeAsyncException@ is not an
-- equivalent test. That function unwraps the 'SomeAsyncException' and then
-- casts the /inner/ exception to 'SomeAsyncException', which never succeeds,
-- so every exception would be classified as synchronous.
isAsync :: SomeException -> Bool
isAsync e = isJust (fromException e :: Maybe SomeAsyncException)

-- | Like 'catch' at type 'SomeException', but only for synchronous exceptions.
-- Anything 'isAsync' recognises is rethrown unchanged. Both @act@ and the
-- handler run with the caller's masking state restored.
catchSync :: IO a -> (SomeException -> IO a) -> IO a
catchSync act onSync = mask $ \unmask ->
  let dispatch e
        | isAsync e = throwIO e
        | otherwise = unmask (onSync e)
  in unmask act `catch` dispatch

-- | Like 'try' at type 'SomeException', but exceptions recognised by 'isAsync'
-- are rethrown instead of being returned as 'Left'. @act@ runs with the
-- caller's masking state restored.
trySync :: IO a -> IO (Either SomeException a)
trySync act = mask $ \unmask ->
  try (unmask act) >>= \case
    Left e | isAsync e -> throwIO e
    outcome            -> pure outcome

-- Conservative benign classification:
-- - Reads (isBenignReadError): an IOException that is an EOF error or has type
--   ResourceVanished means the stream has ended; everything else is a real failure.
-- - Writes (isBenignWriteError): only an IOException whose errno is EPIPE (the
--   reader closed its end) is benign.
isBenignReadError :: SomeException -> Bool
isBenignReadError = maybe False streamEnded . fromException
  where
    streamEnded :: IOException -> Bool
    streamEnded ioe = isEOFError ioe || ioeGetErrorType ioe == ResourceVanished

-- | 'True' only for an 'IOException' carrying errno EPIPE, i.e. a write to a
-- pipe whose reader has gone away. Missing errno or non-IO exceptions are
-- never benign.
isBenignWriteError :: SomeException -> Bool
isBenignWriteError se
  | Just ioe <- fromException se
  , Just errno <- ioe_errno ioe
  = Errno errno == ePIPE
  | otherwise
  = False

-------------------------------------------------------------------------------
-- Misc helpers
-------------------------------------------------------------------------------

-- | Reject option sets the supervisor cannot honour by throwing
-- 'InvalidOptions'.
--
-- Checks:
--
--   * the command is non-empty;
--   * the buffer size is in 1 .. 16MB;
--   * the grace period is non-negative and fits in 'Int'.
--
-- The last check matters because the grace period is converted with
-- 'fromIntegral' before being passed to 'timeout'. Where 'Int' is narrower
-- than 'Int64', a larger value would wrap, possibly to a negative number,
-- which 'timeout' treats as "wait forever".
validateOptions :: ProcessOptions -> IO ()
validateOptions opts = do
  let gp = getGracePeriodMicros (optGracePeriod opts)
      buf = optBufferSize opts
      maxBuf = 16 * 1024 * 1024  -- 16MB, arbitrary but sane upper bound
      maxGp = fromIntegral (maxBound :: Int) :: Int64  -- largest value 'timeout' can take

  when (null $ optCmd opts) $
    throwIO $ InvalidOptions "optCmd cannot be empty"

  when (buf <= 0) $
    throwIO $ InvalidOptions "optBufferSize must be > 0"

  when (buf > maxBuf) $
    throwIO $ InvalidOptions "optBufferSize is unreasonably large"

  when (gp < 0) $
    throwIO $ InvalidOptions "optGracePeriod must be non-negative"

  when (gp > maxGp) $
    throwIO $ InvalidOptions "optGracePeriod exceeds the largest supported timeout"