# Copyright (c) 2023 VMware Inc.  All rights reserved.
# -- VMware Confidential

"""
VisorFS MountRev utility library.

WARNING: This module requires Python >= 3.7.
Prior Python versions will fail to import the module.

Importers which can run in Python < 3.7 (e.g., `esximage` patch-the-patcher on
ESXi < 7.0) should avoid importing this module when using those older runtimes.
(Such ESXi versions would not have MountRev functionality, anyway.)
"""

from contextlib import contextmanager
from dataclasses import dataclass
import logging
import os.path
from typing import (
   Iterable,
   Mapping,
   Optional,
)

from vmware import vsi

# Module-level logger; messages from this module are emitted under the
# 'MountRev' logger name.
log = logging.getLogger('MountRev')

class _FrozenMapping(dict):
   """Read-only mapping.

   Behaves like a `dict` for lookups and iteration, but every in-place
   mutation raises `NotImplementedError`.  The contents must therefore be
   supplied at construction time.
   """

   def _readOnly(self, *args, **kwargs):
      """Reject any attempt to modify the mapping."""
      raise NotImplementedError('This mapping is read-only!')

   # Blocking item assignment/deletion alone is not enough: the other `dict`
   # mutators below do not go through `__setitem__`/`__delitem__`, so each of
   # them has to be disabled explicitly.
   __setitem__ = _readOnly
   __delitem__ = _readOnly
   __ior__ = _readOnly
   clear = _readOnly
   pop = _readOnly
   popitem = _readOnly
   setdefault = _readOnly
   update = _readOnly

@dataclass(init=False, order=True, frozen=True)
class MountRev:
   """Represents a MountRev's metadata.
   - revNum: Revision number.
   - tardisks: Tardisks mounted in the revision.

   WARNING: Each instance's data is frozen.
   To read new data, construct a new instance.
   """

   # NOTE: dataclass attribute ordering affects instance sort order!
   revNum: int
   tardisks: Mapping[str, 'MountRev.Tardisk']

   @dataclass(init=False, order=True, frozen=True)
   class Tardisk:
      """Metadata of a single tardisk mounted in a revision.
      - revNum: Revision in which the tardisk is mounted.
      - name: Mounted tardisk name.
      - hash: Tardisk's hash, rendered as a hexadecimal string.

      WARNING: Instances are frozen snapshots.
      To observe newer system state, construct a new instance.
      """

      # NOTE: Field declaration order doubles as the comparison/sort order!
      revNum: int
      name: str
      hash: str

      def __init__(self, revNum: int, name: str):
         """INTERNAL: Snapshot the tardisk's current state in a revision.
         Reads the tardisk's `sha256hash` VSI node and stores it as hex.
         """
         hashNode = self._vsi(revNum, name, 'sha256hash')
         digest = self.hexDigestFromVSI(vsi.get(hashNode))
         # NOTE: A frozen dataclass rejects regular attribute assignment, so
         # the fields are written through `object.__setattr__` instead.
         object.__setattr__(self, 'revNum', revNum)
         object.__setattr__(self, 'name', name)
         object.__setattr__(self, 'hash', digest)

      @classmethod
      def hexDigestFromVSI(cls, vsiHash):
         """Render a VSI hash value as a hexadecimal digest string.
         Every element of `vsiHash` is formatted as lowercase hex, zero-padded
         to two digits, and the pieces are concatenated in order.
         """
         return ''.join(format(octet, '02x') for octet in vsiHash)

      @classmethod
      def _vsi(cls, rev, tardiskName, *args):
         """Build the VSI node path of a tardisk inside a mount revision.
         Extra `args` are appended as further path components.
         """
         components = ('rev', str(rev), 'tardisks', tardiskName) + args
         return MountRev._vsi(*components)

   def __init__(self, revNum: int):
      """INTERNAL: Snapshot the current system state of a given revision.
      Lists the revision's `tardisks` VSI node and builds one `Tardisk`
      per listed name; the result is stored as a read-only mapping.
      """
      tardisksNode = self._vsi('rev', str(revNum), 'tardisks')
      mounted = {}
      for tardiskName in vsi.list(tardisksNode):
         mounted[tardiskName] = self.Tardisk(revNum, tardiskName)
      # NOTE: A frozen dataclass rejects regular attribute assignment, so
      # the fields are written through `object.__setattr__` instead.
      object.__setattr__(self, 'revNum', revNum)
      object.__setattr__(self, 'tardisks', _FrozenMapping(mounted))

   @classmethod
   def _vsi(cls, *args):
      """Build a VSI node path below `/system/visorfs/MountRev`.
      Each positional argument becomes one further path component.
      """
      components = ['/system', 'visorfs', 'MountRev']
      components.extend(args)
      return '/'.join(components)

   @classmethod
   def revLatest(cls):
      """Read the latest revision from the `revLatest` VSI node."""
      latestNode = cls._vsi('revLatest')
      return vsi.get(latestNode)

   @classmethod
   def revVisible(cls):
      """Read the visible revision from the `revVisible` VSI node."""
      visibleNode = cls._vsi('revVisible')
      return vsi.get(visibleNode)

   @classmethod
   def getVibInfo(cls, rev, tardisk):
      """Read the VIB information recorded for a tardisk in a revision.
      - rev: Revision holding the tardisk.
      - tardisk: Tardisk name, used verbatim as a VSI path component.
      """
      vibInfoNode = cls.Tardisk._vsi(rev, tardisk, 'vibInfo')
      return vsi.get(vibInfoNode)

   @classmethod
   def setVibInfo(cls, vibObj, tardisk, rev=None):
      """Record the VIB information of a tardisk.
      - vibObj: Object providing `vendor`, `name` and `version` attributes.
      - tardisk: Tardisk name or path; only its basename is used.
      - rev: Revision holding the tardisk; defaults to the latest revision.
      """
      # The build number is the last dot-separated piece of the release string.
      releaseString = vibObj.version.release.versionstring
      buildNumber = releaseString.rsplit('.', 1)[-1]
      vibInfo = dict(
         vendor=vibObj.vendor,
         name=vibObj.name,
         version=vibObj.version.versionstring,
         buildNumber=buildNumber,
      )
      targetRev = cls.revLatest() if rev is None else rev
      tardiskName = os.path.basename(tardisk)
      vsi.set(cls.Tardisk._vsi(targetRev, tardiskName, 'vibInfo'), vibInfo)

   # Cache of existing, published revisions.
   # Revisions strictly below the visible rev cannot be changed, so it is safe
   # to cache their instances here (indexed by revision number).
   # NOTE: This is a class attribute, so one list is shared by every use of
   # the class.  It carries no type annotation, hence `@dataclass` does not
   # treat it as an instance field.
   _REVS = []

   @classmethod
   def revs(cls, onlyVisible: bool=False) -> Iterable['MountRev']:
      """Iterate over all revisions in ascending order.
      If `onlyVisible` is `True`, stops at last visible revision.
      Otherwise, continues to the latest revision.

      NOTE: This is a generator.  The visible revision and the list of
      revision numbers are read from VSI once, when iteration starts;
      each uncached revision's contents are read as it is reached.
      Revisions strictly below the visible revision are served from (and
      added to) the class-wide `_REVS` cache.
      """
      revVisible = cls.revVisible()
      # Sort numerically instead of trusting the order of the VSI listing:
      # both the early `break` below and the index-based cache rely on the
      # revisions being visited in ascending order.
      revNums = sorted(int(revNum) for revNum in vsi.list(cls._vsi('rev')))
      for revNum in revNums:
         if onlyVisible and revNum > revVisible:
            break   # Only iterating visible revisions.

         if revNum < len(cls._REVS):
            yield cls._REVS[revNum]    # Cache hit.
            continue

         rev = cls(revNum)   # Cache miss.
         if revNum < revVisible:
            # Cacheable (before current visible revision).
            # The cache is a list indexed by revision number, so the new
            # entry must land exactly at index `revNum`.
            assert len(cls._REVS) == revNum
            # Lazy %-formatting: the message is only built if DEBUG is on.
            log.debug('cached rev #%d', revNum)
            cls._REVS.append(rev)
         yield rev

   @classmethod
   def cartelRevNum(cls, cartelID: int) -> int:
      """Read the minimum revision number of a given cartel.
      The value comes from the cartel's `visorFSMountRevMin` VSI node.
      """
      components = ['/userworld', 'cartel', str(cartelID), 'visorFSMountRevMin']
      cartelNode = '/'.join(components)
      return vsi.get(cartelNode)

   # MountRev statistics related methods
   @classmethod
   def stats(cls):
      """Read the MountRev statistics from the `stats` VSI node."""
      statsNode = cls._vsi('stats')
      return vsi.get(statsNode)

   @classmethod
   def cancelNumAttempts(cls) -> int:
      """Return the `cancelledAttempts` counter from the statistics."""
      counters = cls.stats()
      return counters['cancelledAttempts']

   @classmethod
   def cancelNumSuccess(cls) -> int:
      """Return the `cancelledSuccess` counter from the statistics."""
      counters = cls.stats()
      return counters['cancelledSuccess']

   @classmethod
   def _vfatNameForPayload(cls, payloadName: str) -> str:
      """Derive the VFAT filename (sans extension) from a payload name.
      The extension is dropped, the stem is cut to 8 characters, dashes
      become underscores and the result is lowercased.
      See `vmware.esximage.ImageProfile.ImageProfile.GenerateVFATNames()`.
      """
      stem = os.path.splitext(payloadName)[0]
      shortStem = stem[:8]
      return shortStem.replace('-', '_').lower()

   @classmethod
   def findMinRevWithPayload(
      cls,
      payloadName: str,
      payloadHash: str
   ) -> Optional['MountRev']:
      """Finds earliest revision containing payload with given name and hash.
      Returns the first revision yielded by `revs()` that mounts a matching
      tardisk, or `None` if no revision does.
      NOTE: Only the VFAT filename, sans extension, is compared between the
      `payloadName` and mounted tardisks' names.  Hashes are compared as
      plain strings.
      """
      wantedName = cls._vfatNameForPayload(payloadName)
      for rev in cls.revs():
         # The (cheaper) name check short-circuits the hash comparison.
         hasPayload = any(
            cls._vfatNameForPayload(mountedName) == wantedName
            and mounted.hash == payloadHash
            for mountedName, mounted in rev.tardisks.items()
         )
         if hasPayload:
            return rev  # Matching tardisk found!
      return None

   @classmethod
   def _create(cls):
      """INTERNAL: Request a new mount revision via the `create` VSI node.
      Returns whatever the VSI set operation returns.
      """
      createNode = cls._vsi('create')
      return vsi.set(createNode, [])

   @classmethod
   def _publish(cls):
      """INTERNAL: Publish a mount revision via the `publish` VSI node.
      Returns whatever the VSI set operation returns.
      """
      publishNode = cls._vsi('publish')
      return vsi.set(publishNode, [])

   @classmethod
   def _cancel(cls):
      """INTERNAL: Cancel a pending mount revision via the `cancel` VSI node.
      Returns whatever the VSI set operation returns.
      """
      cancelNode = cls._vsi('cancel')
      return vsi.set(cancelNode, [])

   @classmethod
   @contextmanager
   def transact(cls, shouldPublish=True):
      """Attempts to create a MountRev.
      Upon successful MountRev creation, enters the context, yielding the
      value returned by `_create()`.  If any exception causes a
      context-exit, cancels the MountRev and re-raises the exception.
      Otherwise (context-exit w/o exception), depending on the
      'shouldPublish' value either publishes the MountRev (default) or
      cancels it (mainly for testing purposes).

      NOTE: "Any exception" deliberately includes non-`Exception` exits such
      as `KeyboardInterrupt` and `SystemExit`: an interrupted transaction
      must never be published.
      """
      revCreated = cls._create()
      log.info('Created rev, new latest: %r', revCreated)
      try:
         yield revCreated
      except BaseException:
         # Catch `BaseException`, not just `Exception`: otherwise Ctrl-C,
         # `sys.exit()` or closing the generator would skip this branch and
         # the `finally` clause would publish a half-built revision.
         shouldPublish = False
         raise
      finally:
         if shouldPublish:
            res = cls._publish()
            log.info('Published rev, new visible: %r', res)
         else:
            res = cls._cancel()
            log.info('Cancelled rev, new latest: %r', res)
