// neo_solidity/solidity/analyse/modifiers/constructors.rs

1fn apply_base_constructors_and_modifiers(
2    contract: &ContractIR,
3    constructor: &FunctionIR,
4    modifier_defs: &std::collections::HashMap<(String, usize), FunctionIR>,
5    contract_map: &std::collections::HashMap<String, ContractIR>,
6) -> Result<Statement, SolidityError> {
7    let chain = inheritance_contract_chain(contract, contract_map)?;
8    let base_contracts: std::collections::HashSet<String> = chain
9        .iter()
10        .filter(|name| *name != &contract.name)
11        .cloned()
12        .collect();
13
14    fn find_contract_constructor(contract: &ContractIR) -> Option<&FunctionIR> {
15        contract
16            .functions
17            .iter()
18            .find(|func| matches!(func.ty, FunctionTy::Constructor))
19    }
20
21    fn args_from_base_spec(base: &Base) -> Vec<Expression> {
22        base.args.clone().unwrap_or_default()
23    }
24
25    fn find_base_args_in_invocations(invocations: &[Base], base_name: &str) -> Option<Vec<Expression>> {
26        invocations
27            .iter()
28            .find(|b| base_last_name(b).as_deref() == Some(base_name))
29            .map(args_from_base_spec)
30    }
31
32    fn find_base_args_in_bases(bases: &[Base], base_name: &str) -> Option<Vec<Expression>> {
33        bases
34            .iter()
35            .find(|b| base_last_name(b).as_deref() == Some(base_name))
36            .map(args_from_base_spec)
37    }
38
39    fn contract_directly_inherits(contract: &ContractIR, base_name: &str) -> bool {
40        contract
41            .bases
42            .iter()
43            .any(|b| base_last_name(b).as_deref() == Some(base_name))
44    }
45
46    fn resolve_base_constructor_args(
47        base_name: &str,
48        contract: &ContractIR,
49        constructor: &FunctionIR,
50        chain: &[String],
51        contract_map: &std::collections::HashMap<String, ContractIR>,
52    ) -> Result<Vec<Expression>, SolidityError> {
53        // 1) Explicit base invocation in the most-derived constructor.
54        if let Some(args) = find_base_args_in_invocations(&constructor.base_or_modifiers, base_name) {
55            return Ok(args);
56        }
57        // 2) Base arguments on the most-derived contract inheritance specifier.
58        if let Some(args) = find_base_args_in_bases(&contract.bases, base_name) {
59            return Ok(args);
60        }
61
62        // 3) Walk down the linearized chain to find the contract that directly inherits this base,
63        // then use its constructor invocation / inheritance specifier.
64        let base_pos = chain
65            .iter()
66            .position(|name| name == base_name)
67            .ok_or_else(|| {
68                SolidityError::Analysis(format!(
69                    "internal error: base '{base_name}' missing from linearization"
70                ))
71            })?;
72
73        let mut candidates: Vec<Vec<Expression>> = Vec::new();
74        for name in chain.iter().skip(base_pos + 1) {
75            let Some(child) = contract_map.get(name) else {
76                continue;
77            };
78            if !contract_directly_inherits(child, base_name) {
79                continue;
80            }
81
82            if let Some(child_ctor) = find_contract_constructor(child) {
83                if let Some(args) =
84                    find_base_args_in_invocations(&child_ctor.base_or_modifiers, base_name)
85                {
86                    candidates.push(args);
87                    continue;
88                }
89            }
90            if let Some(args) = find_base_args_in_bases(&child.bases, base_name) {
91                candidates.push(args);
92            }
93        }
94
95        if candidates.is_empty() {
96            return Ok(Vec::new());
97        }
98
99        // Detect conflicts (e.g., diamond inheritance specifying different constructor args).
100        let first = candidates[0].clone();
101        for other in candidates.iter().skip(1) {
102            if *other != first {
103                return Err(SolidityError::Analysis(format!(
104                    "conflicting constructor arguments specified for base contract '{base_name}'"
105                )));
106            }
107        }
108
109        Ok(first)
110    }
111
112    // Base constructor prologue (runs before constructor modifiers).
113    let mut prologue: Vec<Statement> = Vec::new();
114    for base_name in chain
115        .iter()
116        .filter(|name| *name != &contract.name)
117        .cloned()
118        .collect::<Vec<_>>()
119    {
120        let Some(base_contract) = contract_map.get(&base_name) else {
121            continue;
122        };
123        let Some(base_ctor) = find_contract_constructor(base_contract) else {
124            continue;
125        };
126        let Some(base_body) = base_ctor.body.as_ref() else {
127            continue;
128        };
129
130        let args = resolve_base_constructor_args(
131            base_name.as_str(),
132            contract,
133            constructor,
134            &chain,
135            contract_map,
136        )?;
137
138        // Apply any modifiers used on the base constructor itself before substituting parameters.
139        let base_ctor_modifiers: Vec<Base> = base_ctor
140            .base_or_modifiers
141            .iter()
142            .filter(|b| {
143                let Some(name) = base_last_name(b) else {
144                    return false;
145                };
146                let arg_count = b.args.as_ref().map(|args| args.len()).unwrap_or(0);
147                modifier_defs.contains_key(&(name, arg_count))
148            })
149            .cloned()
150            .collect();
151
152        let base_wrapped = if base_ctor_modifiers.is_empty() {
153            base_body.clone()
154        } else {
155            apply_modifier_calls_to_body(base_body, &base_ctor_modifiers, modifier_defs)?
156        };
157
158        let substitutions = build_parameter_substitutions(&base_ctor.parameters, &args)?;
159        let rewritten = rewrite_statement(&base_wrapped, &substitutions, None);
160        prologue.extend(statement_list_from_body(&rewritten));
161    }
162
163    // Constructor modifiers exclude base constructor invocations.
164    let constructor_modifiers: Vec<Base> = constructor
165        .base_or_modifiers
166        .iter()
167        .filter(|b| {
168            let name = base_last_name(b);
169            match name {
170                Some(n) => !base_contracts.contains(&n),
171                None => true,
172            }
173        })
174        .cloned()
175        .collect();
176
177    let Some(body) = constructor.body.as_ref() else {
178        return Ok(Statement::Block {
179            loc: Loc::Implicit,
180            unchecked: false,
181            statements: prologue,
182        });
183    };
184
185    let wrapped = apply_modifier_calls_to_body(body, &constructor_modifiers, modifier_defs)?;
186    let mut statements = prologue;
187    match wrapped {
188        Statement::Block { statements: inner, .. } => statements.extend(inner),
189        other => statements.push(other),
190    }
191
192    Ok(Statement::Block {
193        loc: Loc::Implicit,
194        unchecked: false,
195        statements,
196    })
197}