Coverage for strongcoca/calculators/base_calculator.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-10-26 18:44 +0000

1from abc import abstractmethod 

2from typing import Optional 

3import numpy as np 

4from .. import CoupledSystem 

5from ..response.base import BaseResponse 

6from ..response.utilities import Broadening 

7from ..units import au_to_eV 

8 

9 

class BaseCalculator(BaseResponse):

    """Children of this class enable the calculation of correlation energy and
    spectrum of a coupled system.

    The calculator caches the correlation energy and keeps a snapshot of the
    coupled system (positions, orientations and responses of its units). The
    snapshot is compared against the live system before cached data is
    reused; see :meth:`is_dirty`.

    Parameters
    ----------
    coupled_system
        Coupled system for which to carry out calculations.
    broadening
        Broadening used.
    name
        Name of calculator.

    Raises
    ------
    TypeError
        If ``coupled_system`` is not an instance of ``CoupledSystem``.
    """

    def __init__(self,
                 coupled_system: CoupledSystem,
                 broadening: Broadening,
                 name: str = 'BaseCalculator') -> None:
        super().__init__(broadening=broadening, pbc=False, name=name)
        if not isinstance(coupled_system, CoupledSystem):
            raise TypeError(
                f'coupled_system must be of type CoupledSystem is {type(coupled_system)}.')
        self._coupled_system = coupled_system
        # Take the initial snapshot so that is_dirty() has a reference state.
        self._update_state()
        # Cached correlation energy in atomic units; None means "not computed".
        self._correlation_energy: Optional[float] = None

    def _update_state(self) -> None:
        """Update the internal variables that are used for determining whether
        the object 'is dirty'. This method is used internally to determine
        whether previous calculation results have become invalid.

        The positions are stored as a copy, whereas the per-unit orientation
        and response objects are stored by reference.
        """
        self._positions = self._coupled_system.positions.copy()
        self._orientations = [pu.orientation for pu in self._coupled_system]
        self._responses = [pu.response for pu in self._coupled_system]

    def _verify_state(self) -> None:
        """Update state and discard previous calculation data if the object
        'is dirty'. Otherwise do nothing.
        """
        if not self.is_dirty():
            return
        self._update_state()
        self._discard_data()

    def _discard_data(self) -> None:
        """Discard all the data that becomes invalid when the system gets dirty."""
        self._correlation_energy = None

    @property
    def coupled_system(self) -> CoupledSystem:
        """Coupled system for which calculations are carried out."""
        return self._coupled_system

    def is_dirty(self) -> bool:
        """Returns True if the state of the underlying coupled system has
        changed since the last calculation.

        The live system is compared against the snapshot taken by
        :meth:`_update_state`: positions and orientation quaternions are
        compared with ``np.allclose``, responses with ``==``. A change in the
        shape of the positions or in the number of units also counts as dirty.
        """
        current_positions = self._coupled_system.positions
        # np.allclose broadcasts its arguments: mismatching shapes would either
        # raise ValueError or compare unrelated rows, so check the shape first.
        if np.shape(current_positions) != np.shape(self._positions):
            return True
        if not np.allclose(self._positions, current_positions):  # type: ignore
            return True

        units = list(self._coupled_system)
        # zip() silently truncates to the shorter input, which would hide
        # units that were added or removed after the snapshot was taken.
        if len(units) != len(self._orientations) or len(units) != len(self._responses):
            return True
        if any(not np.allclose(pu.orientation.as_quat(), o.as_quat())
               for pu, o in zip(units, self._orientations)):
            return True
        if not np.all([pu.response == rf
                       for pu, rf in zip(units, self._responses)]):
            return True
        return False

    def get_correlation_energy(self) -> float:
        """Return the correlation energy of the coupled system in units of eV."""
        return self._get_correlation_energy() * au_to_eV

    def _get_correlation_energy(self) -> float:
        """Return the correlation energy of the coupled system in atomic units.

        The value is cached; it is recomputed only if no cached value exists
        or if the cache was discarded because the system became dirty.
        """
        self._verify_state()
        if self._correlation_energy is None:
            self._correlation_energy = float(self._calculate_correlation_energy())
        return self._correlation_energy

    @abstractmethod
    def _calculate_correlation_energy(self) -> float:
        """Compute the correlation energy in atomic units.

        Must be implemented by subclasses; the result is cached by
        :meth:`_get_correlation_energy`.
        """
        raise NotImplementedError()