Source code for NEDAS.job_submitters.slurm

from math import log
import os
import subprocess
import tempfile
from time import sleep
from NEDAS.utils.conversion import seconds_to_timestr
from NEDAS.utils.progress import find_keyword_in_file
from .hpc import HPCJobSubmitter

class SLURMJobSubmitter(HPCJobSubmitter):
    """JobSubmitter Class customized for SLURM schedulers

    On top of the options handled by the parent class, the constructor reads
    these keyword arguments:

    - ``mem_per_cpu``: if set, written as ``#SBATCH --mem-per-cpu`` in the
      job script.
    - ``log_file``: path of the job log; overwritten with the actual SLURM
      output path once a job has been submitted.
    - ``stagnant_log_timeout``: seconds (default 600) a running job may go
      without producing new log output before it is cancelled.
    """

    # Limits reported by the *_avail properties when not running inside a
    # SLURM allocation (no SLURM_* environment variables to read from).
    MAX_NTASKS = 1000000
    MAX_NNODES = 1000
    MAX_PPN = 1000

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # additional slurm options
        self.mem_per_cpu = kwargs.get('mem_per_cpu')
        self.log_file = kwargs.get('log_file', None)
        self.stagnant_log_timeout = kwargs.get('stagnant_log_timeout', 600)

    @property
    def nproc_avail(self) -> int:
        """Number of tasks available: ``SLURM_NTASKS`` inside an allocation, else ``MAX_NTASKS``."""
        if self.in_job_allocation:
            return int(os.environ['SLURM_NTASKS'])
        return self.MAX_NTASKS

    @property
    def nnode_avail(self) -> int:
        """Number of nodes available: ``SLURM_NNODES`` inside an allocation, else ``MAX_NNODES``."""
        if self.in_job_allocation:
            return int(os.environ['SLURM_NNODES'])
        return self.MAX_NNODES

    @property
    def ppn_avail(self) -> int:
        """Tasks per node: first entry of ``SLURM_TASKS_PER_NODE`` inside an allocation, else ``MAX_PPN``."""
        if self.in_job_allocation:
            # SLURM_TASKS_PER_NODE is a comma-separated list whose entries may
            # carry a repetition suffix, e.g. "2(x3),1" or "4,2(x2)". Take the
            # count of the first node group and strip the "(xN)" suffix, so
            # that values without a suffix on the first entry also parse.
            first_group = os.environ['SLURM_TASKS_PER_NODE'].split(',')[0]
            return int(first_group.split('(')[0])
        return self.MAX_PPN

    @property
    def execute_command(self) -> str:
        """Launcher prefix (``srun ...``) for running a command with the requested resources.

        Returns an empty string for single-process or serial runs.

        Raises:
            ValueError: inside an allocation, if ``parallel_mode`` is neither
                ``'mpi'`` nor ``'openmp'``.
        """
        if self.nproc == 1 or self.parallel_mode == 'serial':
            return ""

        if self.in_job_allocation:
            # inside an allocation: place the job step on the requested nodes,
            # starting at node offset `offset_node` (srun -r)
            if self.parallel_mode == 'mpi':
                return f"srun -n {self.nproc} -N {self.nnode} -r {self.offset_node} --exact --unbuffered"
            elif self.parallel_mode == 'openmp':
                return f"export OMP_NUM_THREADS={self.nproc}; srun -N 1 -r {self.offset_node} -n 1 --cpus-per-task={self.nproc} --unbuffered"
            else:
                raise ValueError(f"unknown parallel_mode '{self.parallel_mode}'")

        return f"srun -n {self.nproc} --unbuffered"

    @property
    def job_array_index_name(self) -> str:
        """Shell variable that holds the array task index inside a SLURM job array."""
        return '$SLURM_ARRAY_TASK_ID'

    @property
    def in_job_allocation(self) -> bool:
        """True if ``SLURM_JOB_ID`` is set, i.e. the process runs inside a SLURM job."""
        return 'SLURM_JOB_ID' in os.environ

    def submit_job_and_monitor(self, commands):
        """Write a SLURM batch script for ``commands``, submit it and block until it ends.

        The script is created in ``self.run_dir`` and kept after submission
        (its path is stored in ``self.job_script``). For a single job, new log
        output is streamed to stdout while the job runs, and the job is
        cancelled if the log stays unchanged for longer than
        ``self.stagnant_log_timeout`` seconds. For a job array, the method
        only waits until none of the array tasks is left in the queue.

        Args:
            commands: commands to run; passed through the parent class'
                ``parse_commands`` before being written to the script.

        Raises:
            RuntimeError: if ``sbatch`` fails, if the job reaches a state other
                than R/PD/CG, if the log stagnates for too long, or if a log
                file does not contain the completion message.
        """
        if self.use_job_array:
            # '--array=1-N' below creates tasks numbered 1..N; the monitoring
            # and the log check must use the same numbering
            array_task_ids = range(1, self.array_size + 1)
            log_file = os.path.join(self.run_dir, f"{self.job_name}-%A_%a.out")
        else:
            array_task_ids = None
            log_file = os.path.join(self.run_dir, f"{self.job_name}-%j.out")

        with tempfile.NamedTemporaryFile(mode='w+', delete=False, dir=self.run_dir,
                                         prefix=self.job_name+'.', suffix='.sh') as job_script:
            job_script.write("#!/bin/bash\n")

            # slurm job header
            job_script.write(f"#SBATCH --job-name={self.job_name}\n")
            job_script.write(f"#SBATCH --account={self.project}\n")
            job_script.write(f"#SBATCH --time={seconds_to_timestr(self.walltime)}\n")
            job_script.write(f"#SBATCH --nodes={self.nnode}\n")
            job_script.write(f"#SBATCH --ntasks-per-node={self.ppn}\n")
            if self.queue and self.queue != 'normal':
                job_script.write(f"#SBATCH --qos={self.queue}\n")
            if self.mem_per_cpu:
                job_script.write(f"#SBATCH --mem-per-cpu={self.mem_per_cpu}\n")
            job_script.write(f"#SBATCH --output={log_file}\n")
            if self.use_job_array:
                job_script.write(f"#SBATCH --array=1-{self.array_size}\n")

            # add the commands
            commands = super().parse_commands(commands)
            job_script.write(commands)
            job_script.write('\n')

            self.job_script = job_script.name

        # submit the job script (the file is closed, hence flushed, at this point)
        p = subprocess.run(['sbatch', self.job_script], capture_output=True, text=True)
        if p.returncode != 0:
            raise RuntimeError(f"Failed to submit job: {p.stderr}")
        # the job ID is taken as the last word printed by sbatch
        self.job_id = int(p.stdout.split()[-1])
        self.log_file = log_file.replace('%j', str(self.job_id))
        if self.debug:
            print(f"JobSubmitter: job '{self.job_name}' submitted with ID {self.job_id} to SLURM scheduler", flush=True)

        # monitor job status
        if self.use_job_array:
            while True:
                sleep(self.check_dt)
                any_task_queued = False
                for task_id in array_task_ids:
                    p = subprocess.run(['squeue', '-h', '-j', f'{self.job_id}_{task_id}'],
                                       capture_output=True, text=True)
                    if p.stdout:
                        # this task is still listed, no need to query the rest
                        any_task_queued = True
                        break
                if not any_task_queued:
                    break

        else:
            elapsed_time = 0
            file_pointer = 0
            while True:
                sleep(self.check_dt)
                p = subprocess.run(['squeue', '-h', '-j', f'{self.job_id}'], capture_output=True, text=True)
                if not p.stdout:
                    # job no longer in queue
                    break

                # assumes the default squeue output layout, in which the fifth
                # column is the job state (ST)
                job_status = p.stdout.split()[4]
                if job_status not in ['R', 'PD', 'CG']:
                    # job not running, pending, or cleaning up
                    raise RuntimeError(f"job {self.job_name} failed with status {job_status}")

                if job_status == 'PD':
                    # if job is pending in queue, keep waiting
                    continue

                # if self.log_file is specified
                if self.log_file is None:
                    continue

                elapsed_time += self.check_dt

                # open log file and seek to the last position
                try:
                    with open(self.log_file, 'r') as f:
                        f.seek(file_pointer)
                        new_content = f.read()
                        end_position = f.tell()
                except FileNotFoundError:
                    # the log file may not be visible yet right after the job
                    # starts; count this as "no new output" instead of crashing
                    new_content = ''
                    end_position = file_pointer

                if new_content:
                    print(new_content, end='', flush=True)  # stream the new content to tty
                    file_pointer = end_position  # update file pointer to the new position
                    elapsed_time = 0  # reset elapsed time since we have new log output

                # kill the job if log file remain stagnant for too long
                if elapsed_time > self.stagnant_log_timeout:
                    subprocess.run(['scancel', str(self.job_id)])
                    raise RuntimeError(f"job {self.job_name} killed: {self.log_file} stagnent for {elapsed_time} seconds")

        if self.debug:
            print(f"JobSubmitter: job '{self.job_name}' finished", flush=True)

        # check log file and report errors
        # NOTE(review): the "Job <id> completed" marker is presumably written by
        # the script that parse_commands generates -- confirm in the parent class
        if self.use_job_array:
            for task_id in array_task_ids:
                log_file = os.path.join(self.run_dir, f"{self.job_name}-{self.job_id}_{task_id}.out")
                if not find_keyword_in_file(log_file, f"Job {self.job_id} completed"):
                    raise RuntimeError(f"job {self.job_name} failed, check {log_file}")
        else:
            log_file = os.path.join(self.run_dir, f"{self.job_name}-{self.job_id}.out")
            if not find_keyword_in_file(log_file, f"Job {self.job_id} completed"):
                raise RuntimeError(f"job {self.job_name} failed, check {log_file}")