#!/usr/bin/python

# Copyright (c) 2018-2024 Broadcom. All Rights Reserved.
# Broadcom Confidential. The term "Broadcom" refers to Broadcom Inc.
# and/or its subsidiaries.

"""Unit tests for the ImageManager.DepotMgr module.
"""

import json
import logging
import os
import platform
import ssl

from . import DepotInfo

from .Constants import BASEIMAGE_UI_NAME
from .StagingArea import STAGING_ROOT

from ..BaseImage import versionSpecListToDictOrStr
from ..Bulletin import ComponentCollection
from ..DepotCollection import DepotCollection
from ..Errors import DepotConnectError, ReleaseUnitSchemaVersionError
from ..Manifest import Manifest

# True when running on ESXi itself (platform.system() reports 'VMkernel').
IS_ESX = (platform.system() == 'VMkernel')

# Set to True below when borautils.slock can be imported; loading and
# persisting the depot spec file requires openWithLockAndRetry from it.
HAVE_SLOCK = False
if IS_ESX:
   from ..HostImage import HostImage
   try:
      # There is no borautils on older ESXi, depot spec operations will be
      # unavailable in an upgrade scenario.
      from borautils.slock import openWithLockAndRetry
      HAVE_SLOCK = True
   except ImportError:
      pass
   # Only defined on ESXi; its users in this module check IS_ESX first.
   DEPOT_SPECS_FILE = os.path.join(STAGING_ROOT, 'depots.json')

# NOTE(review): this replaces the default HTTPS context factory with the
# unverified one, so stdlib HTTPS clients that rely on the default context
# skip certificate verification process-wide once this module is imported.
# Confirm this is still intended.
if hasattr(ssl, '_create_unverified_context') and\
      hasattr(ssl, '_create_default_https_context'):
   ssl._create_default_https_context = ssl._create_unverified_context

log = logging.getLogger(__name__)

def joinList(x):
   """Return the str() of every item in iterable x joined with ', '."""
   return ', '.join(str(i) for i in x)

def _logAndRaise(baseMsg, urls, errors):
   """Log an error message and raise DepotConnectError.

      The message is baseMsg followed by " <urls>: <errors>", each list
      rendered with joinList(). DepotConnectError receives the original
      errors list and the full message.
   """
   details = " %s: %s" % (joinList(urls), joinList(errors))
   fullMsg = baseMsg + details
   log.error(fullMsg)
   raise DepotConnectError(errors, fullMsg)

class DepotParsingError(Exception):
   """Raised when a depot spec or the persisted depots file is malformed."""

class DuplicateDepotError(Exception):
   """Raised when adding a depot whose name or URL is already managed."""

class DepotNotFoundError(Exception):
   """Raised when a requested depot is not managed by the depot manager."""

class DepotMgr(object):
   """ Class for abstracting depot management.
       This class tries not to duplicate work done by
       esximage.DepotCollection: connecting to depots and parsing their
       metadata is delegated to a DepotCollection instance (self._dc).

       Unlike the VAPI representation of the list of depotSpecs,
       [ { 'name' : 'foo', 'url' : 'http://foo.com/' } ],
       the structure of the file where depots are persisted is
       modeled after the structure of DepotMgr._depots, which is a
       simple map of depotName:depotUrl and allows O(1) insert and delete.

       Apart from self.components, the release unit collections (vibs,
       addons, baseimages, solutions, manifests) are only set once
       _updateCollections() has run, i.e. after connecting, upserting,
       processing notifications or deleting depots.
   """

   def __init__(self, depotSpecs=None, connect=False,
                ignoreError=True, validate=False, schemaVersionCheck=False):
      """ Initialize the DepotMgr class by loading
          depotSpecs from storage and conditionally
          connecting to the URLs.

          depotSpecs - An optional list of name:url maps to initialize with.
                       When given, the persisted depot spec is not loaded.
          connect - Specifies whether we should connect to the URLs.
                    This option will be set to true in apply and
                    set{SoftwareSpec,Component} workflows.
          ignoreError - If true, a depot connect exception will be logged
                        and the problematic depot URL is skipped.
          validate - Flag to enforce schema validation
          schemaVersionCheck - Flag to enforce schema version check

          Raises DepotParsingError if a depot spec lacks 'name' or 'url' or
          the persisted depots file cannot be parsed; see _connect() for the
          errors raised when connect is True.
      """
      self._depots = {}
      self.components = ComponentCollection()
      self._dc = DepotCollection()

      self._createStagingSpec()

      if depotSpecs is not None:
         for depot in depotSpecs:
            if 'name' not in depot or 'url' not in depot:
               raise DepotParsingError(
                  "Depot spec must contain 'name' and 'url': %s" % (depot,))
            self._depots[depot['name']] = depot['url']
      else:
         self._loadStagingSpec()

      if connect:
         self._connect(ignoreError=ignoreError, validate=validate,
                       schemaVersionCheck=schemaVersionCheck)

   @property
   def componentsWithVibs(self):
      """Get components with full VIB information. This excludes reserved
         components on pre-U2 hosts that do not come with reserved VIBs.
         Returns a ComponentCollection.

         Requires self.vibs, which is only set by _updateCollections().
      """
      return self.components.GetComponentsFromVibIds(set(self.vibs.keys()))

   def _createStagingSpec(self):
      """ Make sure the depot spec file exists on ESXi.

          Creates STAGING_ROOT if needed and writes an empty JSON object to
          DEPOT_SPECS_FILE unless that file already exists. No-op off ESXi.
      """
      if not IS_ESX:
         # Depot spec staging is only enabled on ESXi.
         return
      # exist_ok avoids a race between checking for and creating the dir.
      os.makedirs(STAGING_ROOT, exist_ok=True)
      try:
         # Mode 'x' fails if the file exists, so an existing spec is kept.
         with open(DEPOT_SPECS_FILE, 'x') as f:
            json.dump({}, f)
      except FileExistsError:
         pass

   def _loadStagingSpec(self):
      """ Load depots from the persisted spec file into self._depots.

          No-op off ESXi and when borautils.slock is unavailable (legacy
          upgrade scenarios). Raises DepotParsingError if the file cannot
          be read or does not contain a JSON object.
      """
      if not IS_ESX or not HAVE_SLOCK:
         return
      try:
         if os.path.isfile(DEPOT_SPECS_FILE):
            with openWithLockAndRetry(DEPOT_SPECS_FILE, 'r') as f:
               data = json.load(f)
            if not isinstance(data, dict):
               raise ValueError('expected a JSON object, got %s'
                                % type(data).__name__)
            self._depots.update(data)
      except (ValueError, IOError) as e:
         raise DepotParsingError("Unable to parse depots file %s: %s" %
                                 (DEPOT_SPECS_FILE, str(e))) from e

   def _saveStagingSpec(self):
      """ Write self._depots to DEPOT_SPECS_FILE (ESXi only).

          Raises RuntimeError when borautils.slock is unavailable. The JSON
          is encoded before the file is opened for writing, so an encoding
          failure cannot leave a truncated file behind.
      """
      if not HAVE_SLOCK:
         raise RuntimeError('borautil.slock is not available')
      encoded = json.dumps(self._depots)
      with openWithLockAndRetry(DEPOT_SPECS_FILE, 'w') as f:
         f.write(encoded)

   def _updateCollections(self):
      """ Refresh the release unit collection references from self._dc.
      """
      # find the components
      self.components = ComponentCollection(self._dc.bulletins, True)
      self.vibs = self._dc.vibs
      self.addons = self._dc.addons
      self.baseimages = self._dc.baseimages
      self.solutions = self._dc.solutions
      self.manifests = self._dc.manifests

   def _connect(self, url=None, ignoreError=True, validate=False,
                schemaVersionCheck=False):
      """ Connect to a specified URL or all known URLs, then refresh the
          release unit collections (plus local image metadata on ESXi).

          url - Optional single depot URL; defaults to all managed depots.
          ignoreError, validate, schemaVersionCheck - Passed through to
             DepotCollection.ConnectDepots().

          Raises ReleaseUnitSchemaVersionError if any error is of that type,
          otherwise DepotConnectError when ConnectDepots raises or reports
          errors.
      """
      def _raiseError(depotUrls, errors):
         # Raise specific exception ReleaseUnitSchemaVersionError for schema
         # version check issue, rather than generic exception.
         schemaErrors = [e for e in errors if
                         isinstance(e, ReleaseUnitSchemaVersionError)]
         if schemaErrors:
            raise schemaErrors[0]
         depotStr = ','.join(depotUrls)
         exMsgs = '\n' + '\n'.join([str(e) for e in errors])
         msg = "Unable to connect to depot(s) %s: %s" % (depotStr, exMsgs)
         log.exception(msg)
         raise DepotConnectError(errors, msg)

      depotUrls = [url] if url else list(self._depots.values())
      try:
         _, errors = self._dc.ConnectDepots(
               depotUrls, ignoreerror=ignoreError, validate=validate,
               schemaVersionCheck=schemaVersionCheck)
      except Exception as e:
         _raiseError(depotUrls, [e])

      if errors:
         _raiseError(depotUrls, errors)

      self._updateCollections()
      if IS_ESX:
         # Includes local metadata when running on ESXi.
         self._loadLocalMetadata()

   def _loadLocalMetadata(self):
      """Load local live/bootbank/staged metadata on ESXi.
         Scan requires current image metadata to check compliance, see
         PR 2164400. This method works on ESXi only.
      """
      hostImage = HostImage()
      # Latest host image.
      profiles = [hostImage.GetProfile()]
      if hostImage.imgstate == hostImage.IMGSTATE_BOOTBANK_UPDATED:
         # If pending reboot, load live image.
         profiles.append(hostImage.GetProfile(database=hostImage.DB_VISORFS))
      # Staged image.
      profiles.append(hostImage.stagedimageprofile)

      for p in profiles:
         self._loadProfileMetadata(p)

   def _loadProfileMetadata(self, profile):
      """Merge components, vibs, reserved components, base image, addon and
         manifests from an image profile into the collections. A None
         profile only logs a warning.
      """
      if profile is not None:
         for comp in profile.components.IterComponents():
            self.components.AddComponent(comp)
         for vib in profile.vibs.values():
            self.vibs.AddVib(vib)
            # Installed VIBs will not be re-downloaded, thus they need to
            # persist their signatures from the current database.
            # In addition, AddVib will not copy signature over when the VIB
            # is present in the form of metadata. Manual copy is required.
            self.vibs[vib.id].SetSignature(vib.GetSignature())
            self.vibs[vib.id].SetOrigDescriptor(vib.GetOrigDescriptor())

         for comp in profile.reservedComponents.IterComponents():
            if not self.components.HasComponent(comp.id):
               self.components.AddComponent(comp)

         if profile.baseimageID and profile.baseimageID not in self.baseimages:
            self.baseimages[profile.baseimageID] = profile.baseimage
         if profile.addonID and profile.addonID not in self.addons:
            self.addons[profile.addonID] = profile.addon
         self.manifests += profile.manifests
      else:
         log.warning("Couldn't extract the ImageProfile.")

   def deleteDepot(self, name):
      """ Delete a depot specified by name and, on ESXi, persist the change.

          Raises DepotNotFoundError if the name is not managed, and
          RuntimeError on ESXi when borautils.slock is unavailable.
          Encoding/write failures of the spec file are only logged.
      """
      if name not in self._depots:
         raise DepotNotFoundError("Depot %s not found" % name)
      del self._depots[name]
      if IS_ESX:
         # Delete from the depot spec on ESXi.
         try:
            self._saveStagingSpec()
         except ValueError as e:
            log.exception("Cannot encode depots.json file: %s", str(e))
         except IOError as e:
            log.error("Cannot write out depots.json file: %s", str(e))

   def addDepot(self, depotSpec):
      """ Add a new depot to storage.
          Perform some validation on the URL.

          depotSpec is a dict with two keys: 'name' and 'url'

          Raises DuplicateDepotError if the name or URL is already managed,
          DepotConnectError if the depot cannot be connected to, and
          RuntimeError on ESXi when borautils.slock is unavailable.
      """
      if depotSpec['name'] in self._depots:
         raise DuplicateDepotError("A depot with this Name already exists")
      elif depotSpec['url'] in self._depots.values():
         raise DuplicateDepotError("A depot with this URL already exists")

      depotUrl = depotSpec['url']
      try:
         # A throwaway collection validates the depot without touching _dc.
         dc = DepotCollection()
         _, errors = dc.ConnectDepots([depotUrl])
      except Exception as e:
         # Problems with processing the metadata result in exceptions.
         # In both cases we cannot add this depot.
         msg = "Unable to connect to depot %s: %s" % (depotUrl, str(e))
         log.exception(msg)
         raise DepotConnectError([e], msg)
      if errors:
         # Problems with connection and getting the metadata result in
         # errors being populated. Not inside an except block, so there is
         # no traceback to attach.
         msg = "Unable to connect to depot %s: %s" % (depotUrl, errors)
         log.error(msg)
         raise DepotConnectError(errors, msg)

      self._depots[depotSpec['name']] = depotUrl

      if IS_ESX:
         # Add the depot to depot spec on ESXi.
         self._saveStagingSpec()

   def upsertDepots(self, depotSpecs, ignoreError=True, validate=False):
      """ For each depot in the depot spec list, add it if it is not managed by
          this depot manager yet; otherwise, update it: remove and load again.

          The depots in the current depot collection but not in depot spec
          list are kept.

          depotSpecs is a list of dicts with two keys: 'name' and 'url'

          Raises DepotConnectError (via _logAndRaise) if disconnecting or
          connecting fails.
      """
      toBeRemoved = set()
      for spec in depotSpecs:
         if spec['name'] in self._depots:
            # The name may currently point to a different URL; drop it.
            toBeRemoved.add(self._depots[spec['name']])

         if (spec['url'] in self._depots.values() or
             spec['url'] in self._dc._urlToChannelMap):
            toBeRemoved.add(spec['url'])

      if toBeRemoved:
         try:
            self._dc.DisconnectDepots(toBeRemoved, isPman=True)
         except Exception as e:
            _logAndRaise("Unable to disconnect depot(s)", toBeRemoved, [e])

      depotUrls = [spec['url'] for spec in depotSpecs]
      try:
         _, errors = self._dc.ConnectDepots(depotUrls, ignoreerror=ignoreError,
                                            validate=validate)
      except Exception as e:
         # Problems with processing the metadata result in exceptions.
         # In both cases we cannot add this depot.
         _logAndRaise("Unable to connect to depot(s)", depotUrls, [e])
      if errors:
         # Problems with connection and getting the metadata result in
         # errors being populated.
         _logAndRaise("Unable to connect to depot(s)", depotUrls, errors)

      for spec in depotSpecs:
         self._depots[spec['name']] = spec['url']
      self._updateCollections()

   def processNotification(self):
      """ Process notifications, including removing components that are
          recalled from the collection and related MetadataNode object(s).
          Then update the collection.
      """
      self._dc.ProcessNotification()
      self._updateCollections()

   def deleteDepots(self, depots):
      """ Disconnect the given depot URLs, forget every depot name mapped to
          one of them, and refresh the collections.

          depots - Iterable of depot URLs that must all be managed.

          Raises DepotConnectError (via _logAndRaise) when some URLs are not
          managed or disconnecting fails.
      """
      depotSet = set(depots)
      missingDepots = depotSet - set(self._depots.values())
      if missingDepots:
         _logAndRaise('Missing depots', missingDepots,
                      [DepotNotFoundError('Depots not found',
                                          ', '.join(missingDepots))])
      try:
         self._dc.DisconnectDepots(depots, isPman=True)
         names = [n for n, u in self._depots.items() if u in depotSet]
         for n in names:
            del self._depots[n]
         self._updateCollections()
      except Exception as e:
         _logAndRaise("Unable to disconnect depot(s)", depots, [e])

   def getAllDepotURLs(self):
      """ Return a list (a snapshot, not a live view) of all depot URLs.
      """
      return list(self._depots.values())

   def getAllDepots(self):
      """ Return the list of depots in 'depotSpec' format.
      """
      return [{'name': n, 'url': u} for n, u in self._depots.items()]

   def _GetReleaseUnitComponentsInfo(self, relUnit):
      """Get name/version/display info for each component of a release
         unit, looked up in self.components.
      """
      componentsIds = relUnit.components
      compInfoList = []
      for name in componentsIds:
         version = componentsIds[name]
         comp = self.components.GetComponent(name, version)
         compInfo = {}
         compInfo['name'] = name
         compInfo['version'] = version
         compInfo['display_name'] = comp.componentnamespec['uistring']
         compInfo['display_version'] = comp.componentversionspec['uistring']
         compInfoList.append(compInfo)
      return compInfoList

   def _GetAddOnRemovedComponentsInfo(self, addon):
      """Get name/display_name for each component removed by an addon.

         When the component cannot be found (lookup raises or yields
         nothing), its name is used as the display name so that no
         removed component is dropped from the result.
      """
      compInfoList = []
      for name in addon.removedComponents:
         displayName = name
         try:
            for comp in self.components.GetComponents(name=name):
               # We just need one component to get the display_name.
               displayName = comp.componentnamespec['uistring']
               break
         except (KeyError, ValueError) as e:
            log.warning("Removed component %s not found in depot. Error:%s",
               name, str(e))
            displayName = name
         compInfoList.append({'name': name, 'display_name': displayName})

      return compInfoList

   def GetBaseImageInfoList(self):
      """Get information of all base images.
      """
      baseImageInfoList = []
      for bi in self.baseimages.values():
         biInfo = {}
         biInfo['display_name'] = BASEIMAGE_UI_NAME
         biInfo['version'] = bi.versionSpec.version.versionstring
         biInfo['display_version'] = bi.versionSpec.uiString
         biInfo['summary'] = bi.summary
         biInfo['description'] = bi.description
         biInfo['category'] = bi.category.upper()
         biInfo['kb'] = bi.docURL
         biInfo['release_date'] = bi.releaseDate
         biInfo['components'] = self._GetReleaseUnitComponentsInfo(bi)
         # convert to JSON string in preparation for storing it in db
         biInfo['quick_patch_compatible_versions'] = \
            versionSpecListToDictOrStr(bi.quickPatchCompatibleVersions,
                                       toStr=True)
         baseImageInfoList.append(biInfo)
      return baseImageInfoList

   def _GetAddonInfoList(self, addons):
      """Build info dicts for addons or manifests.

         addons - dict-like whose values are addon or Manifest objects;
                  Manifest entries also get hardware support package info.
      """
      addonInfoList = []
      for addon in addons.values():
         aInfo = {}
         aInfo['version'] = addon.versionSpec.version.versionstring
         aInfo['display_version'] = addon.versionSpec.uiString
         aInfo['vendor'] = addon.vendor
         aInfo['summary'] = addon.summary
         aInfo['description'] = addon.description
         aInfo['category'] = addon.category.upper()
         aInfo['kb'] = addon.docURL
         aInfo['release_date'] = addon.releaseDate
         aInfo['components'] = self._GetReleaseUnitComponentsInfo(addon)
         aInfo['removed_components'] = \
            self._GetAddOnRemovedComponentsInfo(addon)
         aInfo['base_image_versions'] = addon.supportedBaseImageVersions
         if isinstance(addon, Manifest):
            aInfo['manager_name'] = addon.hardwareSupportInfo.manager.name
            aInfo['package_name'] = addon.hardwareSupportInfo.package.name
            aInfo['package_version'] = addon.hardwareSupportInfo.package.version
         aInfo['name'] = addon.nameSpec.name
         aInfo['display_name'] = addon.nameSpec.uiString
         addonInfoList.append(aInfo)
      return addonInfoList

   def GetAddonInfoList(self):
      """Get information of all addons.
      """
      return self._GetAddonInfoList(self.addons)

   def GetManifestInfoList(self):
      """Get the information of all hardware support packages.
      """
      return self._GetAddonInfoList(self.manifests)

   def GetComponentInfoList(self):
      """Get the information of all components. A component matched by any
         solution is typed 'SOLUTION', every other one 'DRIVER'.
      """
      componentInfoList = []

      # Collect solution component versions, keyed by component name.
      solutionComps = {}
      for sol in self.solutions.values():
         solCompDict = sol.MatchComponents(self.components)
         for name in solCompDict:
            for comp in solCompDict[name]:
               solutionComps.setdefault(name, set()).add(comp.compVersionStr)

      for name in self.components:
         sameName = self.components[name]
         for version in sameName:
            comp = sameName[version]
            compInfo = {}
            compInfo['name'] = name
            compInfo['version'] = version
            compInfo['display_name'] = comp.componentnamespec['uistring']
            compInfo['display_version'] = comp.componentversionspec['uistring']
            compInfo['vendor'] = comp.vendor
            if version in solutionComps.get(name, ()):
               compInfo['type'] = 'SOLUTION'
            else:
               compInfo['type'] = 'DRIVER'
            compInfo['summary'] = comp.summary
            compInfo['description'] = comp.description
            compInfo['category'] = comp.category.upper()
            compInfo['urgency'] = comp.urgency.upper()
            compInfo['kb'] = comp.kburl
            compInfo['contact'] = comp.contact
            compInfo['release_date'] = comp.releasedate
            componentInfoList.append(compInfo)
      return componentInfoList

   def GetVibInfo(self, vibids):
      """Get the information of all vibs. The vibs are classified into
         solution component vibs, non solution component vibs and standalone
         vibs.

         vibids - Iterable of VIB IDs; IDs not present in self.vibs are
                  reported as standalone vibs with empty name and version.

         The result data structure is a dict:
         {
             non_solution_vibs: componentInfoList,
             solutions_vibs: componentInfoList,
             standalone_vibs: vibInfoList
         }
         with a componentInfoList is a list of objects of:
         {
             name: component_name,
             version: component_version,
             vibInfo: component vibInfoList
         }
         and a vibInfoList is a list of VIB info objects of:
         {
             vib: VIB ID,
             name: VIB name,
             version: VIB version
         }
      """
      requestedIds = list(vibids)
      requestedSet = set(requestedIds)

      allVibInfoMap = {}
      for vibid, vib in self.vibs.items():
         if vibid in requestedSet:
            allVibInfoMap[vibid] = {'vib': vibid,
                                    'name': vib.name,
                                    'version': vib.version.versionstring}
      # Keep input order; an existing ID that was requested twice is not
      # misreported as missing.
      notExistedVibs = [vibid for vibid in requestedIds
                        if vibid not in allVibInfoMap]

      # Handle vibs in components.
      compInfoDict = dict()
      compVibIds = set()
      relatedComps = ComponentCollection()
      for comp in self.components.IterComponents():
         compVibIds.update(comp.vibids)
         # Only VIBs from the input list are reported.
         vibInfoList = [allVibInfoMap[vibId] for vibId in comp.vibids
                        if vibId in allVibInfoMap]
         if vibInfoList:
            compKey = (comp.compNameStr, comp.compVersionStr)
            compInfoDict[compKey] = vibInfoList
            relatedComps.AddComponent(comp)

      # Handle standalone vibs in depot
      standaloneVibs = [vibInfo for vibInfo in allVibInfoMap.values()
                        if vibInfo['vib'] not in compVibIds]

      # Collect solution vib info.
      solutionCompInfoList = []
      for sol in self.solutions.values():
         solCompDict = sol.MatchComponents(relatedComps)
         for (name, version), vibInfo in compInfoDict.items():
            if vibInfo is None:
               # Already claimed by an earlier solution; don't report twice.
               continue
            comp = relatedComps.GetComponent(name, version)
            if name in solCompDict and comp in solCompDict[name]:
               solutionCompInfoList.append(
                  dict(name=name, version=version, vibInfo=vibInfo))
               # Differentiate solution and non solution components.
               compInfoDict[(name, version)] = None

      # Collect non solution vib info.
      compInfoList = []
      for (name, version), vibInfoList in compInfoDict.items():
         if vibInfoList:
            compDict = dict(name=name, version=version, vibInfo=vibInfoList)
            compInfoList.append(compDict)

      # Handle vibs that don't exist in depot
      for vibid in notExistedVibs:
         standaloneVibs.append({'vib': vibid, 'name': '', 'version': ''})

      finalVibMap = {'non_solution_vibs' : compInfoList,
                     'solutions_vibs' : solutionCompInfoList,
                     'standalone_vibs': standaloneVibs}
      return finalVibMap

   def CalculateMicroDepots(self, imageProfile):
      """ Calculate the micro depots that contains all the image related
          objects in the provided image profile.
      """
      return self._dc.CalculateMicroDepots(imageProfile)

   def GetRelatedVibs(self, imageProfile):
      """ Generate a VibCollection that only contains the vibs from the
          micro depots that overlap with the provided image profile.
      """
      return self._dc.GetRelatedVibs(imageProfile)

   def deepcopy(self):
      """ Return a deep copy of this DepotMgr.

          threading.RLock objects cannot be copied, so while holding this
          object's DepotCollection lock, _dc._lock is temporarily set to
          None, the DepotMgr is deep-copied, the copy's DepotCollection gets
          a fresh RLock, and finally the original lock is restored and
          released.
      """
      from copy import deepcopy
      import threading
      self._dc._Lock()
      lock = self._dc._lock
      try:
         self._dc._lock = None
         depotMgrCopy = deepcopy(self)
         depotMgrCopy._dc._lock = threading.RLock()
         return depotMgrCopy
      finally:
         self._dc._lock = lock
         lock.release()

   def GetDepotInfo(self, depotUrls):
      """ The wrapper to get release object info for the provided depots.
      """
      return DepotInfo.GetDepotInfo(self._dc, depotUrls)

   def GetDepotUniqueInfo(self, depotUrls):
      """ The wrapper to get info of unique released objects for the
          provided depots.
      """
      return DepotInfo.GetDepotUniqueInfo(self._dc, depotUrls)

   def GetVibConfigSchemas(self, vibs):
      """ The wrapper to retrieve config schemas for the given vibs from
          the contained DepotCollection "_dc".
      """
      return self._dc.GetVibConfigSchemas(vibs)

   def GetVibExports(self, vibs):
      """ The wrapper to retrieve vib exports for the given vibs from
          the contained DepotCollection "_dc".
      """
      return self._dc.GetVibExports(vibs)


def getDepotSpecFromUrls(depotUrls):
   """Given a list of depot URLs, form a depot spec for DepotMgr use.
      The depots are named in 'depot(depotUrl)'.

      URLs are stripped of surrounding whitespace before de-duplication, so
      ' u ' and 'u' yield a single entry. Entries keep the order in which
      each URL is first seen.
   """
   depotSpec = []
   seen = set()
   for depotUrl in depotUrls:
      url = depotUrl.strip()
      if url in seen:
         continue
      seen.add(url)
      depotSpec.append(dict(name='depot(%s)' % url, url=url))
   return depotSpec
