Coverage for strongcoca/calculators/base_calculator.py: 100%
46 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-10-26 18:44 +0000
1from abc import abstractmethod
2from typing import Optional
3import numpy as np
4from .. import CoupledSystem
5from ..response.base import BaseResponse
6from ..response.utilities import Broadening
7from ..units import au_to_eV
class BaseCalculator(BaseResponse):
    """Children of this class enable the calculation of the correlation energy
    and the spectrum of a coupled system.

    The correlation energy is cached after it has been computed. The cache is
    discarded whenever the coupled system is found to have changed since its
    state was last recorded (see :meth:`is_dirty`).

    Parameters
    ----------
    coupled_system
        Coupled system for which to carry out calculations.
    broadening
        Broadening used.
    name
        Name of calculator.
    """

    def __init__(self,
                 coupled_system: CoupledSystem,
                 broadening: Broadening,
                 name: str = 'BaseCalculator') -> None:
        super().__init__(broadening=broadening, pbc=False, name=name)
        if not isinstance(coupled_system, CoupledSystem):
            raise TypeError(
                f'coupled_system must be of type CoupledSystem is {type(coupled_system)}.')
        self._coupled_system = coupled_system
        # Record a snapshot of the system so later changes can be detected.
        self._update_state()
        # Cached correlation energy in atomic units; None means "not computed".
        self._correlation_energy: Optional[float] = None

    def _update_state(self) -> None:
        """Record the snapshot used by :meth:`is_dirty` to decide whether
        previous calculation results have become invalid.

        Positions are copied. Orientation and response objects are stored by
        reference, one entry per polarizable unit.
        """
        self._positions = self._coupled_system.positions.copy()
        self._orientations = [pu.orientation for pu in self._coupled_system]
        self._responses = [pu.response for pu in self._coupled_system]

    def _verify_state(self) -> None:
        """Refresh the snapshot and discard cached results if the object
        'is dirty'; otherwise do nothing.
        """
        if not self.is_dirty():
            return
        self._update_state()
        self._discard_data()

    def _discard_data(self) -> None:
        """Discard all the data that becomes invalid when the system gets dirty."""
        self._correlation_energy = None

    @property
    def coupled_system(self) -> CoupledSystem:
        """Coupled system for which calculations are carried out."""
        return self._coupled_system

    def is_dirty(self) -> bool:
        """Return True if the underlying coupled system has changed since its
        state was last recorded.

        The system counts as changed if any of the following holds:

        - polarizable units were added or removed;
        - any unit's position differs from the recorded snapshot;
        - any unit's orientation differs from the recorded snapshot;
        - any unit's response differs from the recorded snapshot.
        """
        units = list(self._coupled_system)
        # Compare the unit count first. zip() below would silently truncate,
        # and np.allclose would either raise on non-broadcastable shapes or
        # broadcast e.g. a single stored position row against several rows.
        if len(units) != len(self._responses):
            return True
        current_positions = self._coupled_system.positions
        if np.shape(current_positions) != np.shape(self._positions):
            return True
        if not np.allclose(self._positions, current_positions):  # type: ignore
            return True
        for pu, old_orientation in zip(units, self._orientations):
            new_quat = np.asarray(pu.orientation.as_quat())
            old_quat = np.asarray(old_orientation.as_quat())
            # Unit quaternions q and -q encode the same rotation. Only a real
            # change in orientation should invalidate cached results.
            if not (np.allclose(new_quat, old_quat)
                    or np.allclose(new_quat, -old_quat)):
                return True
        if not all(pu.response == old_response
                   for pu, old_response in zip(units, self._responses)):
            return True
        return False

    def get_correlation_energy(self) -> float:
        """Return the correlation energy of the coupled system in units of eV."""
        return self._get_correlation_energy() * au_to_eV

    def _get_correlation_energy(self) -> float:
        """Return the correlation energy of the coupled system in atomic units.

        The value is computed on first access and cached. The cache is cleared
        when the coupled system has changed since the last calculation.
        """
        self._verify_state()
        if self._correlation_energy is None:
            self._correlation_energy = float(self._calculate_correlation_energy())
        return self._correlation_energy

    @abstractmethod
    def _calculate_correlation_energy(self) -> float:
        """Compute the correlation energy of the coupled system in atomic units.

        Subclasses must implement this. The result is converted to float and
        cached by :meth:`_get_correlation_energy`.
        """
        raise NotImplementedError()