0 | module NN.Utils
 1 |
 2 | import Data.Nat
 3 | import Misc
 4 |
 5 | public export
 6 | runActionUntilMaxSteps : Show p => Show l =>
 7 |   {default 100 printEvery : Nat} ->
 8 |   (action : p -> IO p) ->
 9 |   (maxSteps : Nat) ->
10 |   (currentStep : Nat) -> (currentValue : p) ->
11 |   (loss : p -> IO l) ->
12 |   IO p
13 | runActionUntilMaxSteps action maxSteps currStep currVal lossIO
14 |   = case currStep < maxSteps of
15 |     True => do
16 |       runIf (currStep `mod` printEvery == 0 || currStep < 10) $ do
17 |         loss <- lossIO currVal
18 |         putStrLn "Current step: \{show currStep}, loss: \{show (loss)}"
19 |         -- putStrLn "Current step: \{show currStep}, value: \{show currVal}, loss: \{show (loss)}"
20 |       result <- action currVal
21 |       runActionUntilMaxSteps {printEvery=printEvery} action maxSteps (currStep + 1) result lossIO
22 |     False => do
23 |       loss <- lossIO currVal
24 |       -- putStrLn "Max steps (\{show maxSteps}) reached. Final loss: \{show (loss)}"
25 |       putStrLn "--------------------------------------------------"
26 |       putStrLn "Max steps (\{show maxSteps}) reached. Final loss: \{show (loss)}."
27 |       putStrLn "Final parameter values: \{show currVal}."
28 |       putStrLn "--------------------------------------------------"
29 |       pure currVal
30 |