#!/usr/bin/python
#
# (C) 2006-2007 XenSource Ltd.
# Copyright (C) 2008-2010 Citrix Ltd.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation; version 2.1 only.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
#

import os
import sys
import errno
import XenAPIPlugin
sys.path.append("/opt/xensource/sm/")
import util
import time
from threading import Thread
import xmlrpclib
import datetime
import XenAPI
import random
import syslog
import flock
from contextlib import contextmanager

INITIAL_RUN_TIME = u'19700101T00:00:00Z'
VMSS_SYSLOG_FACILITY = syslog.LOG_LOCAL1
LOG_INFO = syslog.LOG_INFO
POOL_CONF_FILE = "/etc/xensource/pool.conf"
VM_THREAD_MAX = 1        # max VMs snapshotted concurrently per policy
POLICY_THREAD_MAX = 1    # max policies processed concurrently
VERBOSE = True

errorcode_to_error_map = {
    'VMSS_SNAPSHOT_LOCK_FAILED': 'The snapshot phase is already executing for this snapshot policy. Please try again later',
    'VMSS_SNAPSHOT_SUCCEEDED': 'Successfully performed the snapshot phase of the snapshot policy',
    'VMSS_SNAPSHOT_FAILED': 'The snapshot phase of the snapshot policy failed',
    'VMSS_XAPI_LOGON_FAILURE': 'Could not log in to API session',
    'VMSS_SNAPSHOT_MISSED_EVENT': 'A scheduled snapshot event was missed due to another on-going scheduled snapshot run. This is unexpected behaviour, please re-configure your snapshot sub-policy',
}


def log_message(message, ident="VMSS", priority=LOG_INFO):
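    """Log each line of message to syslog under the VMSS facility,
    tagged with the given ident and this process's PID."""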
    for message_line in str(message).split('\n'):
        syslog.openlog(ident, 0, VMSS_SYSLOG_FACILITY)
        syslog.syslog(priority, "[%d] %s" % (os.getpid(), message_line))
        syslog.closelog()

class LockException(Exception):
    pass


class Lock:

    # Simple file-based lock on a local FS, with shared reader/writer
    # attributes. This replicates the SM Lock class: importing it
    # directly would violate design principles, as suggested by Germano.

    BASE_DIR = "/var/lock/vmss"

    def _open(self):
        """Create and open the lockable attribute base, if it doesn't exist.
        (But don't lock it yet.)"""

        # one directory per namespace
        self.nspath = os.path.join(Lock.BASE_DIR, self.ns)

        # the lockfile inside that namespace directory
        self.lockpath = os.path.join(self.nspath, self.name)

        number_of_enoent_retries = 10

        while True:
            self._mkdirs(self.nspath)

            try:
                self._open_lockfile()
            except IOError, e:
                # If another lock within the namespace has already
                # cleaned up the namespace by removing the directory,
                # _open_lockfile raises an ENOENT, in this case we retry.
                if e.errno == errno.ENOENT:
                    if number_of_enoent_retries > 0:
                        number_of_enoent_retries -= 1
                        continue
                raise
            break

        fd = self.lockfile.fileno()
        self.lock = flock.WriteLock(fd)

    def _open_lockfile(self):
        """Provide a seam, so extreme situations could be tested"""
        log_message("lock: opening lock file {0:s}" .format(self.lockpath))
        self.lockfile = open(self.lockpath, "w+")

    def _close(self):
        """Close the lock, which implies releasing the lock."""
        if self.lockfile is not None:
            if self.held():
                self.release()
            self.lockfile.close()
            log_message("lock: closed {0:s}" .format(self.lockpath))
            self.lockfile = None

    def _mknamespace(ns):

        if ns is None:
            return ".nil"

        assert not ns.startswith(".")
        assert ns.find(os.path.sep) < 0
        return ns
    _mknamespace = staticmethod(_mknamespace)

    def __init__(self, name, ns=None):
        self.lockfile = None

        self.ns = Lock._mknamespace(ns)

        assert not name.startswith(".")
        assert name.find(os.path.sep) < 0
        self.name = name

        self._open()

    __del__ = _close

    def cleanup(name, ns = None):
        ns = Lock._mknamespace(ns)
        path = os.path.join(Lock.BASE_DIR, ns, name)
        if os.path.exists(path):
            Lock._unlink(path)

    cleanup = staticmethod(cleanup)

    def cleanupAll(ns = None):
        ns = Lock._mknamespace(ns)
        nspath = os.path.join(Lock.BASE_DIR, ns)

        if not os.path.exists(nspath):
            return

        for lockfile in os.listdir(nspath):
            path = os.path.join(nspath, lockfile)
            Lock._unlink(path)

        Lock._rmdir(nspath)

    cleanupAll = staticmethod(cleanupAll)

    #
    # Lock and attribute file management
    #

    def _mkdirs(path):
        """Concurrent makedirs() catching EEXIST."""
        if os.path.exists(path):
            return
        try:
            os.makedirs(path)
        except OSError, e:
            if e.errno != errno.EEXIST:
                raise LockException("Failed to makedirs({0:s})" .format(path))
    _mkdirs = staticmethod(_mkdirs)

    def _unlink(path):
        """Non-raising unlink()."""
        log_message("lock: unlinking lock file {0:s}" .format(path))
        try:
            os.unlink(path)
        except Exception, e:
            log_message("Failed to unlink({0:s}): {1:s}" .format(path, e))
    _unlink = staticmethod(_unlink)

    def _rmdir(path):
        """Non-raising rmdir()."""
        log_message("lock: removing lock dir {0:s}" .format(path))
        try:
            os.rmdir(path)
        except Exception, e:
            log_message("Failed to rmdir({0:s}): {1:s}" .format(path, e))
    _rmdir = staticmethod(_rmdir)

    #
    # Actual Locking
    #

    def acquire(self):
        """Blocking lock aquisition, with warnings. We don't expect to lock a
        lot. If so, not to collide. Coarse log statements should be ok
        and aid debugging."""
        if not self.lock.trylock():
            log_message("Failed to lock {0:s} on first attempt, blocked by "
                        "PID {1:s}" .format(self.lockpath,
                                                 self.lock.test()))
            self.lock.lock()
        if VERBOSE:
            log_message("lock: acquired {0:s}" .format(self.lockpath))

    def acquireNoblock(self):
        """Acquire lock if possible, or return false if lock already held"""
        exists = os.path.exists(self.lockpath)
        ret = self.lock.trylock()
        if VERBOSE:
            log_message("lock: tried lock {0:s}, acquired: {1:b} (exists: {2:b})"
                    .format(self.lockpath, ret, exists))
        return ret

    def held(self):
        """True if @self acquired the lock, False otherwise."""
        return self.lock.held()

    def release(self):
        """Release a previously acquired lock."""
        self.lock.unlock()
        if VERBOSE:
            log_message("lock: released {0:s}" .format(self.lockpath))

def get_session():
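    """Log in to the local XAPI instance as the internal __dom0__vmss
    user and return the authenticated session."""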
    session = XenAPI.xapi_local()
    try:
        session.xenapi.login_with_password('__dom0__vmss','')
    except Exception, e:
        raise Exception("%s. Error: %s" %
                        (errorcode_to_error_map['VMSS_XAPI_LOGON_FAILURE'],
                        str(e)))
    return session

@contextmanager
def xapi_session():
    # By tightly coupling session creation and session destruction, we
    # prevent session leaks and can add extra intelligence (e.g. retries)
    # in the future.

    session = None
    try:
        session = get_session()
        yield session
        # Do not capture any exception from yield on purpose. Let
        # it be handled outside this session manager.
    finally:
        if session is not None:
            session.xenapi.session.logout()

def destroy_snapshot(session, snap_ref):
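    """Best-effort destruction of a snapshot or checkpoint VM.

    Records the snapshot's disk VDIs, hard-shuts the VM down (required
    for checkpoints; fails harmlessly for plain snapshots), destroys the
    VM record, then the VDIs, and finally the VBDs (which would
    otherwise be garbage-collected later)."""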
    try:
        # Generate a list of VM VDIs so we can verify snapshot VDIs are related
        vdimap = {}
        VBDs = session.xenapi.VM.get_VBDs(snap_ref)
        for vbd in VBDs:
            if session.xenapi.VBD.get_type(vbd) == 'Disk':
                # store the vdi
                vdimap[session.xenapi.VBD.get_VDI(vbd)] = '1'

        # Now destroy the VM
        # First hard_shutdown the VM, this is required for a checkpoint
        try:
            session.xenapi.VM.hard_shutdown(snap_ref)
        except Exception, e:
            # This must be a snapshot, rather than a checkpoint
            pass

        # Now try destroying the VM, this should work for both snapshot
        # and checkpoint
        try:
            session.xenapi.VM.destroy(snap_ref)
        except Exception, e:
            log_message("Could not destroy the VM: {0:s}, error: {1:s}"
                        .format(snap_ref, e))

        for vdi in vdimap.keys():
            try:
                session.xenapi.VDI.destroy(vdi)
            except Exception, e:
                log_message("Could not destroy the vdi: {0:s}. Error: {1:s}"
                            .format(vdi, e))

        # Now attempt to destroy the VBDs; if that fails, don't worry,
        # they will be GC'd later
        for vbd in VBDs:
            try:
                session.xenapi.VBD.destroy(vbd)
            except:
                pass

    except Exception, e:
        log_message("Could not destroy snapshot successfully, please destroy "
                    "the snapshot {0:s} manually. Error: {1:s}" .format(snap_ref,e))

class TakeSnapshot(Thread):
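    """Thread that takes one snapshot/checkpoint of a single VM for a
    policy, then enforces the policy's retention value by deleting the
    oldest policy-created snapshot when it is exceeded."""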
    def __init__(self, session, vmss_ref, vm_ref):
        Thread.__init__(self) # init the thread
        self.session = session
        self.vmss_ref = vmss_ref
        self.vm_ref = vm_ref
        self.ret_val = str(True)

    def run(self):
        try:

            # Identify the snapshot type
            snapshot_type = self.session.xenapi.VMSS.get_type(
                                                     self.vmss_ref)

            # Now create the snapshot name
            vm = self.session.xenapi.VM.get_uuid(self.vm_ref)
            vm_name = self.session.xenapi.VM.get_name_label(self.vm_ref)
            snap_name = ("%s-%s-%s" %
                         (vm_name, vm[0:16],time.strftime("%Y%m%d-%H%M",
                          time.localtime())))
            snap_name = snap_name.replace(' ', '-')

            # Start a snapshot/checkpoint operation for the VM
            log_message("Processing VM: {0:s} " .format(vm))
            if snapshot_type == "snapshot":
                snap_ref = self.session.xenapi.VM.snapshot(
                                                self.vm_ref, snap_name)
            elif snapshot_type == "checkpoint":
                snap_ref = self.session.xenapi.VM.checkpoint(
                                                self.vm_ref, snap_name)
            elif snapshot_type == "snapshot_with_quiesce":
                snap_ref = self.session.xenapi.VM.snapshot_with_quiesce(
                    self.vm_ref, snap_name)

            # Rename the snapshot to YYYYMMDD-HHMM, based on the snapshot time
            timeOfSnap = str(self.session.xenapi.VM.get_snapshot_time(
                                                snap_ref))

            snap_name = ("%s%s%s-%s%s" %
                         (timeOfSnap[0:4], timeOfSnap[4:6], timeOfSnap[6:8],
                          timeOfSnap[9:11], timeOfSnap[12:14]))
            self.session.xenapi.VM.set_name_label(snap_ref, snap_name)

            # Find the oldest snapshot and delete it if required
            snaps = self.session.xenapi.VM.get_snapshots(self.vm_ref)
            oldest = self.session.xenapi.host.get_servertime(
                        util.get_localhost_uuid(self.session))
            noOfSnaps = 0
            oldestSnap = ''
            current_retention_value = int(self.session.xenapi.VMSS.get_retained_snapshots(
                                                         self.vmss_ref))
            for snap_ref in snaps:
                if not self.session.xenapi.VM.get_is_vmss_snapshot(snap_ref):
                    continue
                else:
                    noOfSnaps += 1

                if oldest > self.session.xenapi.VM.get_snapshot_time(snap_ref):
                    oldest = self.session.xenapi.VM.get_snapshot_time(snap_ref)
                    oldestSnap = snap_ref

            # If the number of snapshots has passed the retention value
            # (it should only be past by 1; if not, log a warning)
            if (noOfSnaps > current_retention_value):
                if (noOfSnaps - current_retention_value) != 1:
                    log_message("WARNING: The difference between number of "
                                "snapshots ({0:d}) and the retention value ({1:d}) is more "
                                "than one, this is an inconsistent state. "
                                "Please contact your pool operator." .format(noOfSnaps,current_retention_value))

                if not oldestSnap:
                    raise Exception("no snapshots found older than current snapshot.")

                log_message("Snapshot retention value reached for VM: {0:s}. "
                            "Deleting the oldest snapshot: {1:s}" .format(
                    vm, self.session.xenapi.VM.get_uuid(oldestSnap)))

                destroy_snapshot(self.session, oldestSnap)

            log_message("Completed processing VM: {0:s} " .format(vm))

        except Exception, e:
            log_message("snapshot failed with exception: {0:s}" .format(e))
            self.ret_val = str(e)


    def result(self):
        return (self.ret_val, self.vm_ref)

def process_VMs(session, vmss_ref):
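    """Snapshot every VM assigned to a policy, in batches of
    VM_THREAD_MAX threads, tracked by a XAPI task. Returns a tuple of
    (overall success, per-VM result map, error string)."""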

    vm_to_snapshot_result = {}
    ret_val = True
    task_ref = None
    task_status = "success"
    error = ''

    try:
        snapshot_schedule = session.xenapi.VMSS.get_schedule(vmss_ref)
        snapshot_schedule['frequency'] = \
            session.xenapi.VMSS.get_frequency(
                        vmss_ref)

        # Get the time before snapshots
        before = datetime.datetime.now()

        #create xapi task here
        vmss_uuid = session.xenapi.VMSS.get_uuid(vmss_ref)
        task_name = "Executing policy: " + vmss_uuid
        task_ref = session.xenapi.task.create(task_name,'')
        task_uuid = session.xenapi.task.get_uuid(task_ref)
        log_message("task: {0:s} created for policy: {1:s}" .format(task_uuid,
                                                                    vmss_uuid))

        vms = session.xenapi.VMSS.get_VMs(vmss_ref)
        no_of_vms = len(vms)

        if no_of_vms % VM_THREAD_MAX:
            no_of_batches = no_of_vms/VM_THREAD_MAX + 1
        else:
            no_of_batches = no_of_vms/VM_THREAD_MAX

        entire_list_threads = []
        log_message("No of VMs: {0:d}" .format(no_of_vms))
        log_message("Max number of threads for processing VM: {0:d}" .format(
            VM_THREAD_MAX))

        for batch in range(0, no_of_batches):
            listThreads = []
            log_message("VM Batch: {0:d}"  .format(batch))
            for vmindex in range(0, VM_THREAD_MAX):
                realIndex = batch * VM_THREAD_MAX + vmindex
                if realIndex < no_of_vms:
                    # In each of these threads
                    s = TakeSnapshot(session, vmss_ref, vms[realIndex])
                    listThreads.append(s)
                    entire_list_threads.append(s)
                else:
                    break

            # Start the batch of threads simultaneously.
            for thread in listThreads:
                thread.start()

            # Wait till all the threads in a batch have finished.
            for thread in listThreads:
                thread.join()


        # If the snapshot failed for one or more VMs generate an
        # appropriately formatted error message to return to the caller.

        for thread in entire_list_threads:
            vm_to_snapshot_result[thread.result()[1]] = thread.result()[0]
            if thread.result()[0] != str(True):
                log_message("The snapshot for VM {0:s} failed with exception "
                            "{1:s}" .format(thread.result()[1],
                                            thread.result()[0]))
                ret_val = False
                task_status = "failure"
        # Get the time after the snapshot
        after = datetime.datetime.now()

        # Get the last expected run time according to the schedule
        last_expected_run_time = \
                        get_last_expected_run_time(snapshot_schedule)

        log_message("snapshot start time: {0:s}, end time: {1:s}, "
                    "Last expected run time: {2:s}" .format(str(before),
                                                            str(after),
                                                            str(last_expected_run_time)))

        # When the number of VMs is large and the scheduled interval is
        # short, there is a chance that a scheduled run does not get
        # triggered because the previous invocation of this schedule is
        # still running. We have to notify this event.

        if (before < last_expected_run_time and after > last_expected_run_time):
            create_alert( session, session.xenapi.VMSS.get_uuid(vmss_ref),
                          "warn", create_structured_alert(
                    session, 'warn', {},
                    "VMSS_SNAPSHOT_MISSED_EVENT"), create_email_body(
                    session, {}, 'VMSS_SNAPSHOT_MISSED_EVENT'),
                          "VMSS_SNAPSHOT_MISSED_EVENT")

        # We reach this point only if no exceptions were raised, hence
        # update the snapshot last run time for future reference

        now_utc = xmlrpclib.DateTime(time.mktime(
            datetime.datetime.utcnow().timetuple()))
        session.xenapi.VMSS.set_last_run_time(
            vmss_ref, xmlrpclib.DateTime(str(now_utc) + "Z"))

    except Exception, e:
        log_message("The snapshot for the schedule policy {0:s} failed with "
                    "exception {1:s}" .format(vmss_uuid, e))
        error = str(e)
        ret_val = False
        task_status = "failure"

    finally:
        if task_ref:
            session.xenapi.task.set_status(task_ref, task_status)
        return (ret_val, vm_to_snapshot_result, error)

#
#TODO: remove unwanted arguments passed by xapi
#
def schedule_snapshots(xapi1, xapi2):

    # schedule_snapshots gets called through the xapi host call_plugin
    # mechanism; by default the plugin dispatcher passes two arguments
    # to the function being called, so we accept (xapi1, xapi2) and
    # ignore them

    ret_val = str(True)
    try:
        child_list = []
        with xapi_session() as session:

            # Get all VMSS objects from the system
            vmss_list = session.xenapi.VMSS.get_all()
            vmss_list = random.sample(vmss_list, len(vmss_list))

            for vmss in vmss_list:
                if not session.xenapi.VMSS.get_enabled(vmss):
                    continue
                # Handle each object in a separate thread.
                s = ProcessPolicy(vmss)
                child_list.append(s)

        # In case the list is non-empty, spawn threads in batches from a
        # child process

        if child_list:
            if os.fork() == 0:

                # Place a lock to have only one instance of VMSS running at any
                # given time

                vmss_lock = create_global_lock()
                if not acquire_lock(vmss_lock):
                    raise Exception("%s" %
                            errorcode_to_error_map["VMSS_SNAPSHOT_LOCK_FAILED"])

                no_of_policies = len(child_list)

                if no_of_policies % POLICY_THREAD_MAX:
                    no_of_batches = no_of_policies/POLICY_THREAD_MAX + 1
                else:
                    no_of_batches = no_of_policies/POLICY_THREAD_MAX

                log_message("No of Policies: %s" % no_of_policies)
                log_message("Max number of threads allocated for policy : %d" %
                      POLICY_THREAD_MAX)
                for batch in range(0, no_of_batches):
                    list_threads = []
                    log_message("Policy Batch: %s" % batch)
                    for index in range(0, POLICY_THREAD_MAX):
                        real_index = batch * POLICY_THREAD_MAX + index
                        if real_index < no_of_policies:
                            # In each of these threads
                            list_threads.append(child_list[real_index])
                        else:
                            break

                    # Start all the threads simultaneously.
                    for thread in list_threads:
                        thread.start()

                    # Wait till all the threads have finished.
                    for thread in list_threads:
                        thread.join()

                if vmss_lock:
                    release_lock(vmss_lock)

    except Exception, e:
        log_message("Exception in schedule_snapshots: {0:s}" .format(e))
        ret_val = str(e)

    finally:
        return ret_val

class ProcessPolicy(Thread):
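    """Thread that checks one policy's schedule against its last run
    time and triggers execute_policy when a snapshot run is due."""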
    def __init__(self, vmss_ref):
        Thread.__init__(self) # init the thread
        self.vmss_ref = vmss_ref


    def run(self):
        snapshot = False
        args = {}

        try:
            # Get the last snapshot run time for the policy from XAPI.
            with xapi_session() as session:
                vmss_uuid = session.xenapi.VMSS.get_uuid(self.vmss_ref)
                snapshot_last_run_time = \
                    session.xenapi.VMSS.get_last_run_time(
                                                        self.vmss_ref)

                # Get the snapshot schedule details for the policy from XAPI.

                snapshot_schedule = \
                    session.xenapi.VMSS.get_schedule(
                                                   self.vmss_ref)
                snapshot_schedule['frequency'] = \
                    session.xenapi.VMSS.get_frequency(self.vmss_ref)

            # Use the snapshot schedule, last snapshot run time and the current
            # time to figure out if a snapshot should be executed now.

            snapshot = \
                should_operation_be_run(
                            snapshot_schedule, snapshot_last_run_time)

            # Prepare args for execute_policy

            args['vmss_uuid'] = vmss_uuid
            if snapshot:
                execute_policy("None", args)
            else:
                log_message("Not processing policy: {0:s}" .format(
                    vmss_uuid))

        except Exception, e:
            log_message("ProcessPolicy failed with exception: {0:s}" .format(e))


def get_last_expected_run_time(schedule, inUTC = False):
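    """Work backwards from the current local time to the most recent
    time the schedule should have fired. Returns a local datetime, or a
    UTC struct_time when inUTC is set."""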
    last_expected_run_time = None

    try:
        now = datetime.datetime.now()

        # check operation frequency
        if schedule['frequency'] == 'hourly':
            # calculate the last expected run time, based on the current time.
            if now.minute > int(schedule['min']):
                # current mins are more than schedule mins so no need to
                # change the hour
                last_expected_run_time = \
                    datetime.datetime(now.year, now.month, now.day, now.hour,
                                      int(schedule['min']),0,0)
            else:
                last_expected_run_time = \
                    (datetime.datetime(now.year, now.month, now.day, now.hour,
                                       int(schedule['min']),0,0) -
                     datetime.timedelta(hours = 1))
        elif schedule['frequency'] == 'daily':
            # calculate the last expected run time, based on the
            # current date and time.
            if (now.hour > int(schedule['hour']) or
               ((now.hour == int(schedule['hour'])) and
                (now.minute > int(schedule['min'])))):
                # current hours are more than schedule hours so no need
                # to change the day
                last_expected_run_time = \
                    datetime.datetime(now.year, now.month, now.day,
                                      int(schedule['hour']),
                                      int(schedule['min']),0,0)
            else:
                last_expected_run_time = \
                    (datetime.datetime(now.year, now.month, now.day,
                                       int(schedule['hour']),
                                       int(schedule['min']),0,0) -
                     datetime.timedelta(days = 1))
        elif schedule['frequency'] == 'weekly':
            # First create a map of the days in the schedule for
            # ease of computation later
            dayMap = {}
            for day in schedule['days'].split(','):
                dayMap[day] = '1'

            lastDayFound = False

            # calculate the last expected run time, based on the current
            # date and time.
            # if current time is less than scheduled time
            if (now.hour < int(schedule['hour']) or
                ((now.hour == int(schedule['hour']) and
                 now.minute < int(schedule['min'])))):
                # go to the last day on the scheduled list excluding today
                noOfDays = 1
            else:
                # go to the last day on the scheduled list including today
                noOfDays = 0

            newDate = now
            while not lastDayFound and noOfDays < 8:
                td = datetime.timedelta(days = noOfDays)
                newDate = now - td
                if newDate.strftime("%A") in dayMap:
                    lastDayFound = True
                else:
                    noOfDays += 1

            if not lastDayFound:
                raise Exception("Could not find the last expected execution "
                                "time for the schedule: %s"
                                % schedule)

            # generate a date with this day and the scheduled time
            last_expected_run_time = \
                datetime.datetime(newDate.year, newDate.month, newDate.day,
                                  int(schedule['hour']),
                                  int(schedule['min']),0,0)

        # Now check if this needs to be converted into UTC time
        if inUTC:
            secs = time.mktime(last_expected_run_time.timetuple())
            last_expected_run_time = time.gmtime(secs)

    except Exception, e:
        log_message("There was an exception in finding out the last expected "
                    "run time of a schedule. {0:s}" .format(e))

    return last_expected_run_time

def should_operation_be_run(schedule, last_run_time):
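    """Decide whether the operation should run now: it is due per the
    schedule, or it has never run and its first slot has passed, or its
    recorded last run time is inconsistent with the schedule."""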
    try:
        # Get the current time
        now = datetime.datetime.utcnow()

        # check if the operation is due because of the schedule
        if is_due_for_run(schedule):
            return True
        else:
            # not due for a run yet; if the last run time is the initial
            # time then check whether we still have to run it
            if last_run_time == INITIAL_RUN_TIME:
                last_expected_run_time = \
                xmlrpclib.DateTime(get_last_expected_run_time(schedule, True))
                if (xmlrpclib.DateTime(time.mktime(now.timetuple())) > last_expected_run_time):
                    log_message("scheduling policy for first time")
                    return True
                else:
                    return False


        # if not, check whether it should be run anyway, as it wasn't run
        # in the last timeslot for some reason

        # if the last run time is in the future then run the operation
        if (last_run_time >
                 xmlrpclib.DateTime(time.mktime(now.timetuple()))):
            log_message("The last run time is in the future then run the "
                        "operation, run the operation to be safe.")
            return True

        last_expected_run_time = \
                xmlrpclib.DateTime(get_last_expected_run_time(schedule, True))

        # Now check if the last run time was before the last expected run time
        if last_run_time < last_expected_run_time:
            log_message("The last expected run time was {0:s}, however the "
                        "operation was last run at {1:s}, hence run it "
                        "again." .format(last_expected_run_time, last_run_time))
            return True
        else:
            return False

    except Exception, e:
        log_message("Exception in should_operation_be_run: {0:s}" .format(e))
        return False

def is_due_for_run(schedule):
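    """Return True if the current minute, hour and day match the
    schedule for its frequency (hourly, daily or weekly)."""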
    try:
        # Find the current time and extract required information
        now = datetime.datetime.now()
        day = now.strftime("%A")
        hour = now.hour
        minute = now.minute

        # Now compare with the schedule passed in
        if minute != int(schedule['min']):
            return False

        if schedule['frequency'] == 'hourly':
            return True

        if hour != int(schedule['hour']):
            return False

        if schedule['frequency'] == 'daily':
            return True

        # If we have come to this point the frequency is definitely
        # weekly; we still put in a check in case a monthly frequency is
        # added in a later release

        day_map = {}
        for dayofweek in schedule['days'].split(','):
            day_map[dayofweek] = '1'

        if schedule['frequency'] == 'weekly' and day in day_map:
            return True
        else:
            return False

    except Exception, e:
        log_message("Exception in is_due_for_run: {0:s}" .format(e))
        return False

#
#TODO: remove unwanted arguments passed by xapi
#
def execute_policy(xapi1, args):
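    """Run the snapshot phase for the policy named by args['vmss_uuid']
    under a per-policy lock, raising XAPI alerts on success, failure and
    lock contention."""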
    ret_val = str(True)
    policy_lock = None
    try:
        log_message("Processing policy: {0:s}" .format(args['vmss_uuid']))
        with xapi_session() as session:
            vmss_uuid = args['vmss_uuid']
            vmss_ref = session.xenapi.VMSS.get_by_uuid(vmss_uuid)
            if not session.xenapi.VMSS.get_enabled(vmss_ref):
                log_message("Policy {0:s} is not enabled" .format(args[
                                                                      'vmss_uuid']))
                return ret_val # true

            if len(session.xenapi.VMSS.get_VMs(vmss_ref)) == 0:
                log_message("No VMs assigned to policy: {0:s}" .format(args[
                                'vmss_uuid']))
                return ret_val # true

            # we reach this point only when we need to process a policy
            # therefore acquire a policy lock

            policy_lock = get_snapshot_lock(vmss_uuid)
            if not acquire_lock(policy_lock):
                create_alert(session, vmss_uuid, "warn",
                             create_structured_alert(session, 'warn',{},
                                        "VMSS_SNAPSHOT_LOCK_FAILED"),
                             create_email_body(session, {},"VMSS_SNAPSHOT_LOCK_FAILED"),
                             "VMSS_SNAPSHOT_LOCK_FAILED")
                raise Exception("%s" %
                        errorcode_to_error_map["VMSS_SNAPSHOT_LOCK_FAILED"])

            (ret_val_snapshot, vm_to_snapshot_result, error_snapshot) = \
                process_VMs(session, vmss_ref)

            if ret_val_snapshot:
                create_alert(session, vmss_uuid, "info",
                             create_structured_alert(session, 'info',{},
                                        "VMSS_SNAPSHOT_SUCCEEDED"),
                             create_email_body(session, vm_to_snapshot_result),
                             "VMSS_SNAPSHOT_SUCCEEDED")
            else:
                # Generate snapshot schedule failure alerts here.
                create_alert(session, vmss_uuid, "error",
                             create_structured_alert(session, 'error',
                                                     vm_to_snapshot_result),
                    create_email_body(session, vm_to_snapshot_result),
                    "VMSS_SNAPSHOT_FAILED")
                log_message("process_VMs failed with the following error "
                            "details: {0:s} and {1:s}." .format(
                    vm_to_snapshot_result, error_snapshot))

                raise Exception(errorcode_to_error_map['VMSS_SNAPSHOT_FAILED'])

            log_message("Completed processing policy: {0:s}" .format(vmss_uuid))

    except Exception, e:
        log_message("Exception in execute_policy: {0:s}" .format(e))
        ret_val = '%s.' % str(e)

    finally:
        if policy_lock:
            release_lock(policy_lock)  # release policy lock
        return ret_val

def acquire_lock(l):
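    """Try to take the lock without blocking; returns False on
    contention or error."""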
    try:
        return l.acquireNoblock()

    except Exception, e:
        log_message("There was an exception in acquiring lock. Exception: {"
                    "0:s}" .format(e))
        return False

def create_global_lock():
    return Lock("schedule.all","vmss")

def get_snapshot_lock(vmss_uuid):
    return Lock("%s.running" % vmss_uuid, "vmss")

def release_lock(l):
    try:
        l.release()
        return True

    except Exception, e:
        log_message("There was an exception releasing lock. Exception: {0:s}" .format(e))
        return False

def trigger_schedule_snapshots():

    # This function is the entry point for the cron job
    # Algo:
    # 1. Check if the host is the master; if not, exit
    # 2. Check if at least one VMSS is enabled; if yes, call
    #    schedule_snapshots

    with open(POOL_CONF_FILE, 'r') as f:
        if f.read() != 'master':
            return

    try:
        with xapi_session() as session:
            log_message("===Kicking cron job for VMSS===")
            call_plugin = False
            for vmss in session.xenapi.VMSS.get_all():
                if session.xenapi.VMSS.get_enabled(vmss):
                    call_plugin = True
                    break
            if not call_plugin:
                log_message("No VMSS policy is enabled for this pool, exiting "
                            "cron job.")
            else:
                # Find the local host uuid
                host_ref = util.get_localhost_uuid(session)
                text = session.xenapi.host.call_plugin( host_ref, "vmss",
                                                        "schedule_snapshots", {})

    except Exception, e:
        log_message("Exception in trigger_schedule_snapshots: %s" % str(e))

def create_email_body(session, vm_to_error_map = {}, error_code = '',
                      additional_error_info = ''):
    # This will only be called for errors and warnings, so we do not
    # need an alert type. First handle the case where a vm to error map
    # is passed in.

    failed_VMs = 0
    error_str = ''
    try:
        if vm_to_error_map != {}:
            for vm in vm_to_error_map.keys():
                if vm_to_error_map[vm] != str(True):
                    failed_VMs += 1
                    vm_uuid = session.xenapi.VM.get_uuid(vm)
                    vm_name = session.xenapi.VM.get_name_label(vm)
                    error_str += ("VM: %s UUID: %s Error:%s" %
                                 (vm_name, vm_uuid, vm_to_error_map[vm]))
                    error_str += ',\n'

            error_str = error_str.strip('\n')
            error_str = error_str.strip(',')

            return ("Snapshot failed on {0:d} out of {1:d} VMs with the "
                    "following errors: \n\nDetails:\n{2:s}" .format(failed_VMs,
                                          len(vm_to_error_map.keys()), error_str))

        # Now handle if an error code is passed in
        if error_code != '':
            if additional_error_info != '':
                return ("failed with error: {0:s}. Additional error details: "
                        "{1:s}." .format(errorcode_to_error_map[error_code],
                                errorcode_to_error_map[additional_error_info]))
            else:
                return ("failed with error: {0:s}." .format(errorcode_to_error_map[error_code]))

    except:
        log_message("Exception in create_email_body")
        return ''

def create_structured_alert(session, alert_type,
                          vm_to_error_map = {}, error_code = ''):
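    """Build the <XCData> XML fragment attached to an alert: one
    <error> element per failed VM for errors, or a single <message>
    element for warn/info alerts."""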
    try:
        data_str = ("<XCData><time>%s</time><messagetype>%s</messagetype>" %
                   (datetime.datetime.now(), alert_type))
        if alert_type == 'error':
            if vm_to_error_map != {}:
                # Normal error with a vm to error map
                for vm in vm_to_error_map.keys():
                    if vm_to_error_map[vm] != str(True):
                        vm_uuid = session.xenapi.VM.get_uuid(vm)
                        error = vm_to_error_map[vm].split(',')[0]
                        error = error.lstrip('[')
                        error = error.lstrip('\'')
                        error = error.rstrip(']')
                        error = error.rstrip('\'')
                        data_str += \
                          ("<error><vm>%s</vm><errorcode>%s</errorcode></error>"
                          % (vm_uuid, error))

        if alert_type == 'warn' or alert_type == 'info':
            data_str += '<message>%s</message>' % (error_code)

        data_str += "</XCData>"

        log_message ("RETURN in create_structured_alert: {0:s}" .format(str(
            data_str)))
        return data_str
    except:
        log_message ("Exception in create_structured_alert")
        return ''

def create_alert(session, vmss_uuid, alert_type, structured_alert,
                 email_body = '', error_code = ''):
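    """Create a XAPI message for the policy: priority "1" for error and
    warn alerts, priority "4" for info."""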
    try:
        if alert_type == 'error' or alert_type == 'warn':
            session.xenapi.message.create(error_code, "1", "VMSS", vmss_uuid,
                                          email_body)
        elif alert_type == 'info':
            session.xenapi.message.create(error_code, "4", "VMSS", vmss_uuid,
                                          error_code)
    except Exception, e:
        log_message("Failed to create alerts for vmss {0:s} with alert level: "
                    "{1:s}. Error: {2:s}" .format(
            vmss_uuid, alert_type,str(e)))

if __name__ == "__main__":
    log_message("Entering VMSS")
    XenAPIPlugin.dispatch({"schedule_snapshots": schedule_snapshots,
                           "snapshot_now": execute_policy})
