Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 143 additions & 44 deletions clippy_lints/src/missing_asserts_for_indexing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::mem;
use std::ops::ControlFlow;

use clippy_utils::comparisons::{Rel, normalize_comparison};
use clippy_utils::consts::{ConstEvalCtxt, Constant};
use clippy_utils::diagnostics::span_lint_and_then;
use clippy_utils::higher::{If, Range};
use clippy_utils::macros::{find_assert_eq_args, first_node_macro_backtrace, root_macro_call};
Expand Down Expand Up @@ -91,35 +92,133 @@ enum LengthComparison {
/// `v.len() == 5`
LengthEqualInt,
}
#[derive(Copy, Clone, Debug)]
struct EvaluatedIntExpr<'hir> {
expr: &'hir Expr<'hir>,
value: usize,
}
#[derive(Copy, Clone, Debug)]
enum AssertionSide<'hir> {
/// `v.len()` in `v.len() > 5`
SliceLen {
/// `v` in `v.len()`
slice: &'hir Expr<'hir>,
},
/// `5` in `v.len() > 5`
AssertedLen(EvaluatedIntExpr<'hir>),
}
impl<'hir> AssertionSide<'hir> {
pub fn from_expr(cx: &LateContext<'_>, expr: &'hir Expr<'hir>) -> Option<Self> {
Self::asserted_len_from_int_lit_expr(expr)
.or_else(|| Self::slice_len_from_expr(cx, expr))
.or_else(|| Self::asserted_len_from_possibly_const_expr(cx, expr))
}
pub fn slice_len_from_expr(cx: &LateContext<'_>, expr: &'hir Expr<'hir>) -> Option<Self> {
if let ExprKind::MethodCall(method, recv, [], _) = expr.kind
// checking method name first rather than receiver's type could improve performance
&& method.ident.name == sym::len
&& cx.typeck_results().expr_ty_adjusted(recv).peel_refs().is_slice()
{
Some(Self::SliceLen { slice: recv })
} else {
None
}
}
pub fn asserted_len_from_expr(cx: &LateContext<'_>, expr: &'hir Expr<'hir>) -> Option<Self> {
Self::asserted_len_from_int_lit_expr(expr).or_else(|| Self::asserted_len_from_possibly_const_expr(cx, expr))
}
pub fn asserted_len_from_int_lit_expr(expr: &'hir Expr<'hir>) -> Option<Self> {
if let ExprKind::Lit(Spanned {
node: LitKind::Int(Pu128(x), _),
..
}) = expr.kind
{
Some(Self::AssertedLen(EvaluatedIntExpr {
expr,
value: x as usize,
}))
} else {
None
}
}
pub fn asserted_len_from_possibly_const_expr(cx: &LateContext<'_>, expr: &'hir Expr<'hir>) -> Option<Self> {
if let Some(Constant::Int(x)) = ConstEvalCtxt::new(cx).eval(expr) {
Some(Self::AssertedLen(EvaluatedIntExpr {
expr,
value: x as usize,
}))
} else {
None
}
}
}

/// Extracts parts out of a length comparison expression.
///
/// E.g. for `v.len() > 5` this returns `Some((LengthComparison::IntLessThanLength, 5, v.len()))`
fn len_comparison<'hir>(
cx: &LateContext<'_>,
bin_op: BinOpKind,
left: &'hir Expr<'hir>,
right: &'hir Expr<'hir>,
) -> Option<(LengthComparison, usize, &'hir Expr<'hir>)> {
macro_rules! int_lit_pat {
($id:ident) => {
ExprKind::Lit(Spanned {
node: LitKind::Int(Pu128($id), _),
..
})
};
left_expr: &'hir Expr<'hir>,
right_expr: &'hir Expr<'hir>,
) -> Option<(LengthComparison, EvaluatedIntExpr<'hir>, &'hir Expr<'hir>)> {
fn sniff_operands<'hir>(
cx: &LateContext<'_>,
left: &'hir Expr<'hir>,
right: &'hir Expr<'hir>,
) -> Option<(AssertionSide<'hir>, AssertionSide<'hir>)> {
// sniff as cheap as possible
if let Some(left) = AssertionSide::asserted_len_from_int_lit_expr(left) {
Some((left, AssertionSide::from_expr(cx, right)?))
} else if let Some(left) = AssertionSide::slice_len_from_expr(cx, left) {
Some((left, AssertionSide::asserted_len_from_expr(cx, right)?))
} else {
Some((
AssertionSide::asserted_len_from_possibly_const_expr(cx, left)?,
AssertionSide::slice_len_from_expr(cx, right)?,
))
}
}

type Side<'hir> = AssertionSide<'hir>;

// normalize comparison, `v.len() > 4` becomes `4 < v.len()`
// this simplifies the logic a bit
let (op, left, right) = normalize_comparison(bin_op, left, right)?;
match (op, left.kind, right.kind) {
(Rel::Lt, int_lit_pat!(left), _) => Some((LengthComparison::IntLessThanLength, left as usize, right)),
(Rel::Lt, _, int_lit_pat!(right)) => Some((LengthComparison::LengthLessThanInt, right as usize, left)),
(Rel::Le, int_lit_pat!(left), _) => Some((LengthComparison::IntLessThanOrEqualLength, left as usize, right)),
(Rel::Le, _, int_lit_pat!(right)) => Some((LengthComparison::LengthLessThanOrEqualInt, right as usize, left)),
(Rel::Eq, int_lit_pat!(left), _) => Some((LengthComparison::LengthEqualInt, left as usize, right)),
(Rel::Eq, _, int_lit_pat!(right)) => Some((LengthComparison::LengthEqualInt, right as usize, left)),
_ => None,
let (op, left_expr, right_expr) = normalize_comparison(bin_op, left_expr, right_expr)?;

let (left, right) = sniff_operands(cx, left_expr, right_expr)?;
let (swapped, asserted_len, slice) = match (left, right) {
// `A > B` (e.g. `5 > 4`)
| (Side::AssertedLen(_), Side::AssertedLen(_))
// `v.len() > w.len()`
| (Side::SliceLen { .. }, Side::SliceLen { .. }) => return None,
(Side::AssertedLen(asserted_len), Side::SliceLen { slice }) => {
(false, asserted_len, slice)
},
(Side::SliceLen { slice }, Side::AssertedLen(asserted_len)) => {
(true, asserted_len, slice)
},
};

match op {
Rel::Lt => {
let cmp = if swapped {
LengthComparison::LengthLessThanInt
} else {
LengthComparison::IntLessThanLength
};
Some((cmp, asserted_len, slice))
},
Rel::Le => {
let cmp = if swapped {
LengthComparison::LengthLessThanOrEqualInt
} else {
LengthComparison::IntLessThanOrEqualLength
};
Some((cmp, asserted_len, slice))
},
Rel::Eq => Some((LengthComparison::LengthEqualInt, asserted_len, slice)),
Rel::Ne => None,
}
}

Expand All @@ -132,15 +231,15 @@ fn len_comparison<'hir>(
fn assert_len_expr<'hir>(
cx: &LateContext<'_>,
expr: &'hir Expr<'hir>,
) -> Option<(LengthComparison, usize, &'hir Expr<'hir>, Symbol)> {
let ((cmp, asserted_len, slice_len), macro_call) = if let Some(If { cond, then, .. }) = If::hir(expr)
) -> Option<(LengthComparison, EvaluatedIntExpr<'hir>, &'hir Expr<'hir>, Symbol)> {
let ((cmp, asserted_len, slice), macro_call) = if let Some(If { cond, then, .. }) = If::hir(expr)
&& let ExprKind::Unary(UnOp::Not, condition) = &cond.kind
&& let ExprKind::Binary(bin_op, left, right) = &condition.kind
// check if `then` block has a never type expression
&& let ExprKind::Block(Block { expr: Some(then_expr), .. }, _) = then.kind
&& cx.typeck_results().expr_ty(then_expr).is_never()
{
(len_comparison(bin_op.node, left, right)?, sym::assert_macro)
(len_comparison(cx, bin_op.node, left, right)?, sym::assert_macro)
} else if let Some((macro_call, bin_op)) = first_node_macro_backtrace(cx, expr).find_map(|macro_call| {
match cx.tcx.get_diagnostic_name(macro_call.def_id) {
Some(sym::assert_eq_macro) => Some((macro_call, BinOpKind::Eq)),
Expand All @@ -150,7 +249,7 @@ fn assert_len_expr<'hir>(
}) && let Some((left, right, _)) = find_assert_eq_args(cx, expr, macro_call.expn)
{
(
len_comparison(bin_op, left, right)?,
len_comparison(cx, bin_op, left, right)?,
root_macro_call(expr.span)
.and_then(|macro_call| cx.tcx.get_diagnostic_name(macro_call.def_id))
.unwrap_or(sym::assert_macro),
Expand All @@ -159,21 +258,14 @@ fn assert_len_expr<'hir>(
return None;
};

if let ExprKind::MethodCall(method, recv, [], _) = &slice_len.kind
&& cx.typeck_results().expr_ty_adjusted(recv).peel_refs().is_slice()
&& method.ident.name == sym::len
{
Some((cmp, asserted_len, recv, macro_call))
} else {
None
}
Some((cmp, asserted_len, slice, macro_call))
}

#[derive(Debug)]
enum IndexEntry<'hir> {
/// `assert!` without any indexing (so far)
StrayAssert {
asserted_len: usize,
asserted_len: EvaluatedIntExpr<'hir>,
comparison: LengthComparison,
assert_span: Span,
slice: &'hir Expr<'hir>,
Expand All @@ -186,7 +278,7 @@ enum IndexEntry<'hir> {
AssertWithIndex {
highest_index: usize,
is_first_highest: bool,
asserted_len: usize,
asserted_len: EvaluatedIntExpr<'hir>,
assert_span: Span,
slice: &'hir Expr<'hir>,
indexes: Vec<Span>,
Expand Down Expand Up @@ -367,22 +459,29 @@ fn report_indexes(cx: &LateContext<'_>, map: UnindexMap<u64, Vec<IndexEntry<'_>>
Some(format!("assert!({slice_str}.len() > {highest_index})",))
},
// `5 < v.len()` == `v.len() > 5`
LengthComparison::IntLessThanLength if asserted_len < highest_index => {
LengthComparison::IntLessThanLength if asserted_len.value < highest_index => {
Some(format!("assert!({slice_str}.len() > {highest_index})",))
},
// `5 <= v.len() == `v.len() >= 5`
LengthComparison::IntLessThanOrEqualLength if asserted_len <= highest_index => {
Some(format!("assert!({slice_str}.len() > {highest_index})",))
LengthComparison::IntLessThanOrEqualLength if asserted_len.value == highest_index => {
let asserted_len_str =
snippet_with_applicability(cx, asserted_len.expr.span, "_", &mut app);
Some(format!("assert!({slice_str}.len() > {asserted_len_str})",))
},
LengthComparison::IntLessThanOrEqualLength if asserted_len.value < highest_index => {
Some(format!("assert!({slice_str}.len() >= {})", highest_index + 1))
},
// `highest_index` here is rather a length, so we need to add 1 to it
LengthComparison::LengthEqualInt if asserted_len < highest_index + 1 => match macro_call {
sym::assert_eq_macro => {
Some(format!("assert_eq!({slice_str}.len(), {})", highest_index + 1))
},
sym::debug_assert_eq_macro => {
Some(format!("debug_assert_eq!({slice_str}.len(), {})", highest_index + 1))
},
_ => Some(format!("assert!({slice_str}.len() == {})", highest_index + 1)),
LengthComparison::LengthEqualInt if asserted_len.value < highest_index + 1 => {
match macro_call {
sym::assert_eq_macro => {
Some(format!("assert_eq!({slice_str}.len(), {})", highest_index + 1))
},
sym::debug_assert_eq_macro => {
Some(format!("debug_assert_eq!({slice_str}.len(), {})", highest_index + 1))
},
_ => Some(format!("assert!({slice_str}.len() == {})", highest_index + 1)),
}
},
_ => None,
};
Expand Down
20 changes: 20 additions & 0 deletions tests/ui/missing_asserts_for_indexing.fixed
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,24 @@ mod issue15988 {
}
}

fn assert_cmp_to_const(v1: &[u8], v2: &[u8], v3: &[u8], v4: &[u8]) {
const N: usize = 2;

assert!(v1.len() >= N);
assert!(v2.len() > N);
assert!(v3.len() >= 4);
assert!(v4.len() > 3);

let _ = v1[0] + v1[1];

let _ = v2[0] + v2[1] + v2[2];
//~^ missing_asserts_for_indexing

let _ = v3[0] + v3[1] + v3[2] + v3[3];
//~^ missing_asserts_for_indexing

let _ = v4[0] + v4[1] + v4[2] + v4[3];
//~^ missing_asserts_for_indexing
}

fn main() {}
20 changes: 20 additions & 0 deletions tests/ui/missing_asserts_for_indexing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,24 @@ mod issue15988 {
}
}

fn assert_cmp_to_const(v1: &[u8], v2: &[u8], v3: &[u8], v4: &[u8]) {
const N: usize = 2;

assert!(v1.len() >= N);
assert!(v2.len() >= N);
assert!(v3.len() >= N);
assert!(v4.len() > N);

let _ = v1[0] + v1[1];

let _ = v2[0] + v2[1] + v2[2];
//~^ missing_asserts_for_indexing

let _ = v3[0] + v3[1] + v3[2] + v3[3];
//~^ missing_asserts_for_indexing

let _ = v4[0] + v4[1] + v4[2] + v4[3];
//~^ missing_asserts_for_indexing
}

fn main() {}
38 changes: 37 additions & 1 deletion tests/ui/missing_asserts_for_indexing.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -187,5 +187,41 @@ LL - debug_assert_eq!(v.len(), 2);
LL + debug_assert_eq!(v.len(), 3);
|

error: aborting due to 15 previous errors
error: indexing into a slice multiple times with an `assert` that does not cover the highest index
--> tests/ui/missing_asserts_for_indexing.rs:193:13
|
LL | let _ = v2[0] + v2[1] + v2[2];
| ^^^^^ ^^^^^ ^^^^^
|
help: provide the highest index that is indexed with
|
LL - assert!(v2.len() >= N);
LL + assert!(v2.len() > N);
|

error: indexing into a slice multiple times with an `assert` that does not cover the highest index
--> tests/ui/missing_asserts_for_indexing.rs:196:13
|
LL | let _ = v3[0] + v3[1] + v3[2] + v3[3];
| ^^^^^ ^^^^^ ^^^^^ ^^^^^
|
help: provide the highest index that is indexed with
|
LL - assert!(v3.len() >= N);
LL + assert!(v3.len() >= 4);
|

error: indexing into a slice multiple times with an `assert` that does not cover the highest index
--> tests/ui/missing_asserts_for_indexing.rs:199:13
|
LL | let _ = v4[0] + v4[1] + v4[2] + v4[3];
| ^^^^^ ^^^^^ ^^^^^ ^^^^^
|
help: provide the highest index that is indexed with
|
LL - assert!(v4.len() > N);
LL + assert!(v4.len() > 3);
|

error: aborting due to 18 previous errors