Skip to content

Commit 304ddbf

Browse files
btf: Add support for bpf_core_type_matches()
This commit adds support for the latest edition to CO-RE. The `bpf_core_type_matches()` function allows you to check if a given type matches. This is a stricter check than the normal compatibility check. An example use case for this feature is to implement fallback code for cases where kernel types changed over time, such as the `block_rq_insert` tracepoint. The tracepoint lost an argument in the v5.11 kernel, so this feature can be used to handle the change in provided context in a CO-RE manner. Signed-off-by: Dylan Reimerink <[email protected]>
1 parent 6f21619 commit 304ddbf

File tree

6 files changed

+476
-18
lines changed

6 files changed

+476
-18
lines changed

btf/core.go

Lines changed: 270 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,11 @@ const (
127127
reloTypeSize /* type size in bytes */
128128
reloEnumvalExists /* enum value existence in target kernel */
129129
reloEnumvalValue /* enum value integer value */
130+
reloTypeMatches /* type matches kernel type */
130131
)
131132

132133
func (k coreKind) checksForExistence() bool {
133-
return k == reloEnumvalExists || k == reloTypeExists || k == reloFieldExists
134+
return k == reloEnumvalExists || k == reloTypeExists || k == reloFieldExists || k == reloTypeMatches
134135
}
135136

136137
func (k coreKind) String() string {
@@ -159,8 +160,10 @@ func (k coreKind) String() string {
159160
return "enumval_exists"
160161
case reloEnumvalValue:
161162
return "enumval_value"
163+
case reloTypeMatches:
164+
return "type_matches"
162165
default:
163-
return "unknown"
166+
return fmt.Sprintf("unknown (%d)", k)
164167
}
165168
}
166169

@@ -369,6 +372,21 @@ func coreCalculateFixup(relo *CORERelocation, target Type, targetID TypeID, bo b
369372
local := relo.typ
370373

371374
switch relo.kind {
375+
case reloTypeMatches:
376+
if len(relo.accessor) > 1 || relo.accessor[0] != 0 {
377+
return zero, fmt.Errorf("unexpected accessor %v", relo.accessor)
378+
}
379+
380+
err := coreTypesMatch(local, target, false, nil)
381+
if errors.Is(err, errIncompatibleTypes) {
382+
return poison()
383+
}
384+
if err != nil {
385+
return zero, err
386+
}
387+
388+
return fixup(1, 1)
389+
372390
case reloTypeIDTarget, reloTypeSize, reloTypeExists:
373391
if len(relo.accessor) > 1 || relo.accessor[0] != 0 {
374392
return zero, fmt.Errorf("unexpected accessor %v", relo.accessor)
@@ -1016,19 +1034,6 @@ func coreAreMembersCompatible(localType Type, targetType Type) error {
10161034
localType = UnderlyingType(localType)
10171035
targetType = UnderlyingType(targetType)
10181036

1019-
doNamesMatch := func(a, b string) error {
1020-
if a == "" || b == "" {
1021-
// allow anonymous and named type to match
1022-
return nil
1023-
}
1024-
1025-
if newEssentialName(a) == newEssentialName(b) {
1026-
return nil
1027-
}
1028-
1029-
return fmt.Errorf("names don't match: %w", errImpossibleRelocation)
1030-
}
1031-
10321037
_, lok := localType.(composite)
10331038
_, tok := targetType.(composite)
10341039
if lok && tok {
@@ -1045,13 +1050,261 @@ func coreAreMembersCompatible(localType Type, targetType Type) error {
10451050

10461051
case *Enum:
10471052
tv := targetType.(*Enum)
1048-
return doNamesMatch(lv.Name, tv.Name)
1053+
if !coreNamesMatch(lv.Name, tv.Name) {
1054+
return fmt.Errorf("names %q and %q don't match: %w", lv.Name, tv.Name, errImpossibleRelocation)
1055+
}
1056+
1057+
return nil
10491058

10501059
case *Fwd:
10511060
tv := targetType.(*Fwd)
1052-
return doNamesMatch(lv.Name, tv.Name)
1061+
if !coreNamesMatch(lv.Name, tv.Name) {
1062+
return fmt.Errorf("names %q and %q don't match: %w", lv.Name, tv.Name, errImpossibleRelocation)
1063+
}
1064+
1065+
return nil
10531066

10541067
default:
10551068
return fmt.Errorf("type %s: %w", localType, ErrNotSupported)
10561069
}
10571070
}
1071+
1072+
func coreNamesMatch(a, b string) bool {
1073+
if a == "" || b == "" {
1074+
// allow anonymous and named type to match
1075+
return true
1076+
}
1077+
1078+
return newEssentialName(a) == newEssentialName(b)
1079+
}
1080+
1081+
/* The comment below is from __bpf_core_types_match in relo_core.c:
1082+
*
1083+
* Check that two types "match". This function assumes that root types were
1084+
* already checked for name match.
1085+
*
1086+
* The matching relation is defined as follows:
1087+
* - modifiers and typedefs are stripped (and, hence, effectively ignored)
1088+
* - generally speaking types need to be of same kind (struct vs. struct, union
1089+
* vs. union, etc.)
1090+
* - exceptions are struct/union behind a pointer which could also match a
1091+
* forward declaration of a struct or union, respectively, and enum vs.
1092+
* enum64 (see below)
1093+
* Then, depending on type:
1094+
* - integers:
1095+
* - match if size and signedness match
1096+
* - arrays & pointers:
1097+
* - target types are recursively matched
1098+
* - structs & unions:
1099+
* - local members need to exist in target with the same name
1100+
* - for each member we recursively check match unless it is already behind a
1101+
* pointer, in which case we only check matching names and compatible kind
1102+
* - enums:
1103+
* - local variants have to have a match in target by symbolic name (but not
1104+
* numeric value)
1105+
* - size has to match (but enum may match enum64 and vice versa)
1106+
* - function pointers:
1107+
* - number and position of arguments in local type has to match target
1108+
* - for each argument and the return value we recursively check match
1109+
*/
1110+
func coreTypesMatch(localType Type, targetType Type, behindPtr bool, visited map[pair]struct{}) error {
1111+
localType = UnderlyingType(localType)
1112+
targetType = UnderlyingType(targetType)
1113+
1114+
if !coreNamesMatch(localType.TypeName(), targetType.TypeName()) {
1115+
return fmt.Errorf("type name %q don't match %q: %w", localType.TypeName(), targetType.TypeName(), errIncompatibleTypes)
1116+
}
1117+
1118+
if _, ok := visited[pair{localType, targetType}]; ok {
1119+
return nil
1120+
}
1121+
if visited == nil {
1122+
visited = make(map[pair]struct{})
1123+
}
1124+
visited[pair{localType, targetType}] = struct{}{}
1125+
1126+
switch lv := (localType).(type) {
1127+
case *Void:
1128+
if _, ok := targetType.(*Void); !ok {
1129+
return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1130+
}
1131+
1132+
case *Fwd:
1133+
if behindPtr {
1134+
if tv, ok := targetType.(*Fwd); ok {
1135+
if lv.Kind != tv.Kind {
1136+
return fmt.Errorf("fwd kind mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1137+
}
1138+
1139+
return nil
1140+
}
1141+
1142+
if _, ok := targetType.(*Struct); ok && lv.Kind == FwdStruct {
1143+
return nil
1144+
}
1145+
1146+
if _, ok := targetType.(*Union); ok && lv.Kind == FwdUnion {
1147+
return nil
1148+
}
1149+
1150+
return fmt.Errorf("fwd kind mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1151+
}
1152+
1153+
if tv, ok := targetType.(*Fwd); ok && tv.Kind == lv.Kind {
1154+
return nil
1155+
}
1156+
1157+
return fmt.Errorf("fwd kind mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1158+
1159+
case *Enum:
1160+
tv, ok := targetType.(*Enum)
1161+
if !ok {
1162+
return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1163+
}
1164+
1165+
if err := coreEnumsMatch(lv, tv); err != nil {
1166+
return err
1167+
}
1168+
1169+
case composite:
1170+
if behindPtr {
1171+
if reflect.TypeOf(localType) == reflect.TypeOf(targetType) {
1172+
return nil
1173+
}
1174+
1175+
tv, ok := targetType.(*Fwd)
1176+
if !ok {
1177+
return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1178+
}
1179+
1180+
if _, ok := lv.(*Struct); ok && tv.Kind == FwdStruct {
1181+
return nil
1182+
}
1183+
1184+
if _, ok := lv.(*Union); ok && tv.Kind == FwdUnion {
1185+
return nil
1186+
}
1187+
1188+
return fmt.Errorf("fwd kind mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1189+
}
1190+
1191+
if reflect.TypeOf(localType) != reflect.TypeOf(targetType) {
1192+
return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1193+
}
1194+
tComp, ok := targetType.(composite)
1195+
if !ok {
1196+
return fmt.Errorf("expected composite type, got %T", targetType)
1197+
}
1198+
1199+
if err := coreCompositesMatch(lv, tComp, behindPtr, visited); err != nil {
1200+
return err
1201+
}
1202+
1203+
case *Int:
1204+
tv, ok := targetType.(*Int)
1205+
if !ok {
1206+
return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1207+
}
1208+
1209+
if lv.Size != tv.Size || (lv.Encoding == Signed) != (tv.Encoding == Signed) {
1210+
return fmt.Errorf("int mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1211+
}
1212+
1213+
case *Pointer:
1214+
tv, ok := targetType.(*Pointer)
1215+
if !ok {
1216+
return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1217+
}
1218+
1219+
return coreTypesMatch(lv.Target, tv.Target, true, visited)
1220+
1221+
case *Array:
1222+
tv, ok := targetType.(*Array)
1223+
if !ok {
1224+
return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1225+
}
1226+
1227+
if lv.Nelems != tv.Nelems {
1228+
return fmt.Errorf("array mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1229+
}
1230+
1231+
return coreTypesMatch(lv.Type, tv.Type, behindPtr, visited)
1232+
1233+
case *FuncProto:
1234+
tv, ok := targetType.(*FuncProto)
1235+
if !ok {
1236+
return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1237+
}
1238+
1239+
if len(lv.Params) != len(tv.Params) {
1240+
return fmt.Errorf("function param mismatch: %w", errIncompatibleTypes)
1241+
}
1242+
1243+
for i, lparam := range lv.Params {
1244+
if err := coreTypesMatch(lparam.Type, tv.Params[i].Type, behindPtr, visited); err != nil {
1245+
return err
1246+
}
1247+
}
1248+
1249+
return coreTypesMatch(lv.Return, tv.Return, behindPtr, visited)
1250+
1251+
default:
1252+
return fmt.Errorf("unsupported type %T", localType)
1253+
}
1254+
1255+
return nil
1256+
}
1257+
1258+
// coreEnumsMatch checks two enums match, which is considered to be the case if the following is true:
1259+
// - size has to match (but enum may match enum64 and vice versa)
1260+
// - local variants have to have a match in target by symbolic name (but not numeric value)
1261+
func coreEnumsMatch(local *Enum, target *Enum) error {
1262+
if local.Size != target.Size {
1263+
return fmt.Errorf("size mismatch between %v and %v: %w", local, target, errIncompatibleTypes)
1264+
}
1265+
1266+
// If there are more values in the local than the target, there must be at least one value in the local
1267+
// that isn't in the target, and therefor the types are incompatible.
1268+
if len(local.Values) > len(target.Values) {
1269+
return fmt.Errorf("local has more values than target: %w", errIncompatibleTypes)
1270+
}
1271+
1272+
outer:
1273+
for _, lv := range local.Values {
1274+
for _, rv := range target.Values {
1275+
if coreNamesMatch(lv.Name, rv.Name) {
1276+
continue outer
1277+
}
1278+
}
1279+
1280+
return fmt.Errorf("no match for %v in %v: %w", lv, target, errIncompatibleTypes)
1281+
}
1282+
1283+
return nil
1284+
}
1285+
1286+
func coreCompositesMatch(localType, targetType composite, behindPtr bool, visited map[pair]struct{}) error {
1287+
if reflect.TypeOf(localType) != reflect.TypeOf(targetType) {
1288+
return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes)
1289+
}
1290+
1291+
localMembers := localType.members()
1292+
targetMembers := map[string]Member{}
1293+
for _, member := range targetType.members() {
1294+
targetMembers[member.Name] = member
1295+
}
1296+
1297+
for _, localMember := range localMembers {
1298+
targetMember, found := targetMembers[localMember.Name]
1299+
if !found {
1300+
return fmt.Errorf("no field %q in %v: %w", localMember.Name, targetType, errIncompatibleTypes)
1301+
}
1302+
1303+
err := coreTypesMatch(localMember.Type, targetMember.Type, behindPtr, visited)
1304+
if err != nil {
1305+
return err
1306+
}
1307+
}
1308+
1309+
return nil
1310+
}

0 commit comments

Comments
 (0)