feat(rosbag_trans): add support to convert multiple aimrt bags into a single ROS bag (#90)

* feat: add functionality to convert multiple aimrt bags into a single ROS bag

* refactor(rosbag_trans): add DatabaseManager class to unify database operations

* fix: format the code

* fix: format code

* feat(bagtrans_tool): Update Command Line Tool Documentation
This commit is contained in:
ATT_POWER 2024-11-08 15:53:21 +08:00 committed by GitHub
parent 29eb541fd6
commit e01333313f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 303 additions and 237 deletions

View File

@ -23,10 +23,10 @@ aimrt_cli trans -h, --help 会显示参数说明:
options:
-h, --help show this help message and exit
-s SRC_DIR, --src_dir SRC_DIR
aimrtbag source directory.
-s SRC_DIR [SRC_DIR ...], --src_dir SRC_DIR [SRC_DIR ...]
aimrtbag source directories (support multiple directories)
-o OUTPUT_DIR, --output_dir OUTPUT_DIR
directory you want to output your files.
directory you want to output your files
```
其中 `-s` 参数为必填参数,表示 aimrtbag 的源目录,`-o` 参数为必填参数表示转换后的bag的输出目录如果输出目录不存在则会自动创建如果输出目录存在则会覆盖。
其中 `-s` 参数为必填参数,表示 aimrtbag 的源目录,支持多个目录,`-o` 参数为必填参数表示转换后的bag的输出目录如果输出目录不存在则会自动创建如果输出目录存在则会覆盖。

View File

@ -3,7 +3,7 @@
from aimrt_cli.command import CommandBase
from aimrt_cli.generator.project_generator import ProjectGenerator
from aimrt_cli.trans.rosbag_trans import RosbagTrans
from aimrt_cli.trans.rosbag_trans import AimrtbagToRos2
class GenCommand(CommandBase):
@ -25,5 +25,5 @@ class GenCommand(CommandBase):
generator = ProjectGenerator(cfg_path=args.project_cfg, output_dir=args.output_dir)
generator.generate()
trans = RosbagTrans(args.src_dir, args.output_dir)
trans = AimrtbagToRos2(args.src_dir, args.output_dir)
trans.trans()

View File

@ -2,7 +2,7 @@
# All rights reserved.
from aimrt_cli.command import CommandBase
from aimrt_cli.trans.rosbag_trans import RosbagTrans
from aimrt_cli.trans.rosbag_trans import AimrtbagToRos2
class TransCommand(CommandBase):
@ -13,12 +13,16 @@ class TransCommand(CommandBase):
def add_arguments(self, parser, cmd_name):
if cmd_name == "trans":
self.parser_ = parser
parser.add_argument("-s", "--src_dir", help="aimrtbag source directory.")
parser.add_argument(
"-s",
"--src_dir",
nargs='+',
help="aimrtbag source directories (support multiple directories).")
parser.add_argument("-o", "--output_dir", help="directory you want to output your files.")
def main(self, *, args=None):
if args is None:
self.parser_.print_help()
return 0
trans = RosbagTrans(args.src_dir, args.output_dir)
trans = AimrtbagToRos2(args.src_dir, args.output_dir)
trans.trans()

View File

@ -21,7 +21,11 @@ def main(description=None):
# bag trans sub command
trans_parser = subparsers.add_parser('trans', help='Transform bag files')
trans_parser.add_argument("-s", "--src_dir", help="aimrtbag source directory")
trans_parser.add_argument(
"-s",
"--src_dir",
nargs='+',
help="aimrtbag source directories (support multiple directories)")
trans_parser.add_argument("-o", "--output_dir", help="directory you want to output your files")
args = parser.parse_args()

View File

@ -15,14 +15,74 @@ class IndentDumper(yaml.Dumper):
return super(IndentDumper, self).increase_indent(flow, False)
class SingleBagProcess:
class DatabaseManager:
def __init__(self, db_path: str):
self.db_path = db_path
self.conn = None
self.cursor = None
def connect(self):
self.conn = sqlite3.connect(self.db_path)
self.cursor = self.conn.cursor()
return self.conn, self.cursor
def create_tables(self):
try:
# create messages table
self.cursor.execute("""
CREATE TABLE messages(
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
topic_id INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
data BLOB NOT NULL)
""")
# create topics table
self.cursor.execute("""
CREATE TABLE topics(
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
type TEXT NOT NULL,
serialization_format TEXT NOT NULL,
offered_qos_profiles TEXT NOT NULL)
""")
# create schema table
self.cursor.execute("""
CREATE TABLE "schema" (
"schema_version" INTEGER,
"ros_distro" TEXT NOT NULL,
PRIMARY KEY("schema_version")
);
""")
self.cursor.execute("""
INSERT INTO schema (schema_version, ros_distro)
VALUES (?, ?)
""", (3, "humble"))
# create metadata table
self.cursor.execute("""
CREATE TABLE metadata(id INTEGER PRIMARY KEY,metadata_version INTEGER NOT NULL,metadata TEXT NOT NULL)
""")
self.conn.commit()
except sqlite3.Error as e:
self.conn.rollback()
raise e
def close(self):
if self.cursor:
self.cursor.close()
if self.conn:
self.conn.close()
class SingleDbProcess:
def __init__(self, topic_info_dict: dict, db_path: Path):
self.message_count = 0
self.duration_nanoseconds = 0
self.starting_time_nanoseconds = 100000000000000000000
self.topic_with_message_count = {}
self.topic_info_dict = topic_info_dict
self.starting_time_nanoseconds = int(1e20)
self.end_time_nanoseconds = 0
self.db_path = db_path
self.get_info()
@ -30,21 +90,25 @@ class SingleBagProcess:
try:
cursor.execute("SELECT topic_id, timestamp FROM messages")
rows = sorted(cursor.fetchall())
self.starting_time_nanoseconds = min(self.starting_time_nanoseconds, rows[0][1])
self.duration_nanoseconds = rows[-1][1] - self.starting_time_nanoseconds
self.message_count = len(rows)
for row in rows:
self.topic_with_message_count[self.topic_info_dict[row[0]].topic_name] = self.topic_with_message_count.get(
self.topic_info_dict[row[0]].topic_name, 0) + 1
if rows:
self.starting_time_nanoseconds = rows[0][1]
self.end_time_nanoseconds = rows[-1][1]
self.message_count = len(rows)
for row in rows:
topic_name = self.topic_info_dict[row[0]].topic_name
self.topic_with_message_count[topic_name] = \
self.topic_with_message_count.get(topic_name, 0) + 1
except Exception as e:
print(f"Error getting single bag info: {e}")
conn.rollback()
def get_info(self):
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
self.get_bag_info(conn, cursor)
db_manager = DatabaseManager(str(self.db_path))
conn, cursor = db_manager.connect()
try:
self.get_bag_info(conn, cursor)
finally:
db_manager.close()
@dataclass
@ -53,152 +117,116 @@ class TopicInfo:
topic_name: str
msg_type: str
serialization_type: str
message_count: int
def encode_topic_name(topic_name: str, msg_type: str):
if msg_type.startswith("pb"):
return topic_name + '/' + msg_type.replace('/', '_2F').replace(':', '_3A').replace('.', '_2E')
else:
return topic_name
class RosbagTrans(TransBase):
def __init__(self, input_dir: str, output_dir: str):
class SingleBagTrans(TransBase):
def __init__(self, input_dir: str, output_dir: str, conn: sqlite3.Connection, cursor: sqlite3.Cursor, id: int):
super().__init__(output_dir)
self.input_dir_ = input_dir
self.input_dir = input_dir
self.output_dir = output_dir
self.topics_list = {}
self.topic_info_dict = {}
self.files_list = {}
self.bag_info_list = []
self.message_count = 0
self.all_duration = 0
self.topic_with_message_count = {}
self.starting_time_nanoseconds = 100000000000000000000
self.duration_nanoseconds = 0
self.rosbag_yaml_data = {
"version": 5,
"storage_identifier": "sqlite3",
"duration": {
"nanoseconds": 0
},
"starting_time": {
"nanoseconds_since_epoch": 0
},
"message_count": 0,
"topics_with_message_count": [],
"compression_format": "",
"compression_mode": "",
"relative_file_paths": [],
"files": []
}
def copy_file(self):
if os.path.exists(self.output_dir_):
shutil.rmtree(self.output_dir_)
try:
shutil.copytree(self.input_dir_, self.output_dir_)
print(f"Directory successfully copied from {self.input_dir_} to {self.output_dir_}")
except shutil.Error as e:
print(f"Copy error: {e}")
except OSError as e:
print(f"System error: {e}")
self.starting_time_nanoseconds = int(1e20)
self.end_time_nanoseconds = 0
self.id = id # target db message id
self.conn = conn # target db connection
self.cursor = cursor # target db cursor
def parse_yaml(self):
with open(os.path.join(self.output_dir_, "metadata.yaml"), "r") as f:
with open(os.path.join(self.input_dir, "metadata.yaml"), "r") as f:
data = yaml.load(f, Loader=yaml.FullLoader)
if data["aimrt_bagfile_information"] is not None:
if data["aimrt_bagfile_information"]["topics"] is not None:
self.topics_list = data["aimrt_bagfile_information"]["topics"]
for topic in self.topics_list:
self.topic_info_dict[topic["id"]] = TopicInfo(
topic["id"], topic["topic_name"], topic["msg_type"], topic["serialization_type"])
else:
raise Exception("No topics found in metadata.yaml")
if data["aimrt_bagfile_information"]["files"] is not None:
self.files_list = data["aimrt_bagfile_information"]["files"]
else:
raise Exception("No files found in metadata.yaml")
else:
if data is None or data["aimrt_bagfile_information"] is None or data["aimrt_bagfile_information"]["topics"] is None:
raise Exception("No aimrt_bagfile_information found in metadata.yaml")
self.rosbag_yaml_data = {
"version": 5,
"storage_identifier": "sqlite3",
}
return data
def update_rosbag_yaml(self):
self.rosbag_yaml_data = {
"version": 5,
"storage_identifier": "sqlite3",
"duration": {
"nanoseconds": self.all_duration
},
"starting_time": {
"nanoseconds_since_epoch": self.starting_time_nanoseconds
},
"message_count": self.message_count,
"topics_with_message_count": [],
"compression_format": "",
"compression_mode": "",
"relative_file_paths": [],
"files": [],
}
def transfertopic_msg_type(msg_type):
if msg_type.startswith("pb"):
return "ros2_plugin_proto/msg/RosMsgWrapper"
elif msg_type.startswith("ros2"):
return msg_type.replace("ros2:", "")
else:
return msg_type
self.topics_list = data["aimrt_bagfile_information"]["topics"]
for topic in self.topics_list:
self.topic_info_dict[topic["id"]] = TopicInfo(
topic["id"], topic["topic_name"], topic["msg_type"], topic["serialization_type"], 0)
topic_message_count = self.topic_with_message_count.get(topic["topic_name"], 0)
topic_entry = {
"topic_metadata": {
"name": encode_topic_name(topic["topic_name"], topic["msg_type"]),
"type": transfertopic_msg_type(topic["msg_type"]),
"serialization_format": "cdr",
"offered_qos_profiles": self.format_qos_profiles()
},
"message_count": topic_message_count
}
self.rosbag_yaml_data["topics_with_message_count"].append(topic_entry)
if data["aimrt_bagfile_information"]["files"] is not None:
self.files_list = data["aimrt_bagfile_information"]["files"]
else:
raise Exception("No db files found in metadata.yaml")
for file_info in self.bag_info_list:
self.rosbag_yaml_data["relative_file_paths"].append(file_info.db_path.name)
file_entry = {
"path": file_info.db_path.name,
"starting_time": {
"nanoseconds_since_epoch": file_info.starting_time_nanoseconds
},
"duration": {
"nanoseconds": file_info.duration_nanoseconds
},
"message_count": file_info.message_count
}
self.rosbag_yaml_data["files"].append(file_entry)
def trans_single_db(self, source_path: Path, topic_map: dict):
single_bag_info = SingleDbProcess(self.topic_info_dict, source_path)
self.message_count += single_bag_info.message_count
self.starting_time_nanoseconds = min(self.starting_time_nanoseconds, single_bag_info.starting_time_nanoseconds)
self.end_time_nanoseconds = max(self.end_time_nanoseconds, single_bag_info.end_time_nanoseconds)
final_yaml_data = {
"rosbag2_bagfile_information": self.rosbag_yaml_data
}
conn = sqlite3.connect(source_path)
print(f" processing db file: {source_path}")
cursor = conn.cursor()
abs_output_dir = os.path.abspath(self.output_dir_)
with open(os.path.join(abs_output_dir, "metadata.yaml"), "w") as f:
yaml_str = yaml.dump(
final_yaml_data,
Dumper=IndentDumper,
default_flow_style=False,
sort_keys=False,
indent=2,
width=1000000)
yaml_str = yaml_str.replace("\'", "\"")
f.write(yaml_str)
print(f"{os.path.join(abs_output_dir, 'metadata.yaml')} has been updated")
try:
select_sql = "SELECT id,topic_id, timestamp, data FROM messages"
cursor.execute(select_sql)
rows = cursor.fetchall()
self.cursor.executemany("""
INSERT INTO messages (id, topic_id, timestamp, data)
VALUES (?, ?, ?, ?)
""", [(self.id + row[0], topic_map[self.topic_info_dict[row[1]].topic_name].topic_id, row[2], row[3]) for row in rows])
for row in rows:
topic_map[self.topic_info_dict[row[1]].topic_name].message_count += 1
self.conn.commit()
print(f" size of data inserted: {len(rows)} done")
except Exception as e:
print(f" Error updating messages table: {e}")
self.conn.rollback()
self.id += len(rows)
def trans_single_bag(self, topic_map: dict):
self.parse_yaml()
print(f"there are {len(self.files_list)} db files in {self.input_dir}")
for db_path in self.files_list:
trans_path = Path(self.output_dir) / db_path['path']
self.trans_single_db(Path(self.input_dir) / db_path['path'], topic_map)
print(f" trans_path: {trans_path} done")
print(f"all db files in {self.input_dir} done\n")
class AimrtbagToRos2:
def __init__(self, input_dir: list, output_dir: str):
self.input_dir = input_dir
self.output_dir = output_dir
self.topic_map = {}
self.id = 0
self.message_count = 0
self.starting_time_nanoseconds = int(1e20)
self.end_time_nanoseconds = 0
self.topics_list = []
self.db_manager = None
self.conn = None
self.cursor = None
def create_output_dir(self):
if os.path.exists(self.output_dir):
shutil.rmtree(self.output_dir)
os.makedirs(self.output_dir)
# initialize database
db_path = os.path.join(self.output_dir, "rosbag.db3")
self.db_manager = DatabaseManager(db_path)
self.conn, self.cursor = self.db_manager.connect()
self.db_manager.create_tables()
def parse_yaml(self, input_dir: str):
with open(os.path.join(input_dir, "metadata.yaml"), "r") as f:
data = yaml.load(f, Loader=yaml.FullLoader)
if data["aimrt_bagfile_information"] is None or data["aimrt_bagfile_information"]["topics"] is None:
raise Exception("No topics information found in metadata.yaml")
topics_list = data["aimrt_bagfile_information"]["topics"]
for topic in topics_list:
if topic["topic_name"] not in self.topic_map:
self.id += 1
self.topic_map[topic["topic_name"]] = TopicInfo(
self.id, topic["topic_name"], topic["msg_type"], topic["serialization_type"], 0)
else:
print(f"warning: topic {topic['topic_name']} already exists")
def format_qos_profiles(self):
qos_dict = {
@ -221,26 +249,8 @@ class RosbagTrans(TransBase):
return qos_string
def update_messages_table(self, conn, cursor):
def insert_topics_table(self):
try:
cursor.execute("UPDATE messages SET topic_id = topic_id + 1")
conn.commit()
except Exception as e:
print(f"Error update messages table, error: {e}")
conn.rollback()
def insert_topics_table(self, conn, cursor):
try:
cursor.execute("""
CREATE TABLE IF NOT EXISTS "topics" (
"id" INTEGER,
"name" TEXT NOT NULL,
"type" TEXT NOT NULL,
"serialization_format" TEXT NOT NULL,
"offered_qos_profiles" TEXT NOT NULL,
PRIMARY KEY("id")
)
""")
qos_dict = [{
'history': 3,
'depth': 0,
@ -264,83 +274,131 @@ class RosbagTrans(TransBase):
qos_json = yaml.dump(qos_dict, Dumper=IndentDumper, sort_keys=False)
# Populate the topics table from self.topics_list
for topic in self.topics_list:
topic['offered_qos_profiles'] = self.format_qos_profiles()
cursor.execute("""
for topic in self.topic_map.values():
self.cursor.execute("""
INSERT INTO topics (id, name, type, serialization_format, offered_qos_profiles)
VALUES (?, ?, ?, ?, ?)
""", (
topic['id'] + 1,
topic['topic_name'],
topic['msg_type'].replace('ros2:', ''),
topic.topic_id,
topic.topic_name,
topic.msg_type.replace('ros2:', ''),
'cdr', # Use 'cdr' as the default serialization format
qos_json
))
conn.commit()
self.conn.commit()
except Exception as e:
print(f"Error create topics table or insert topics table data, error: {e}")
conn.rollback()
self.conn.rollback()
def insert_schema_version(self, conn, cursor):
def update_rosbag_yaml_data(self):
self.rosbag_yaml_data = {
"version": 5,
"storage_identifier": "sqlite3",
"duration": {
"nanoseconds": self.end_time_nanoseconds - self.starting_time_nanoseconds
},
"starting_time": {
"nanoseconds_since_epoch": self.starting_time_nanoseconds
},
"message_count": self.message_count,
"topics_with_message_count": [],
"compression_format": "",
"compression_mode": "",
"relative_file_paths": [],
"files": []
}
for topic in self.topic_map.values():
topic_entry = {
"topic_metadata": {
"name": topic.topic_name,
"type": topic.msg_type.replace('ros2:', ''),
"serialization_format": "cdr",
"offered_qos_profiles": self.format_qos_profiles(),
},
"message_count": topic.message_count
}
self.rosbag_yaml_data["topics_with_message_count"].append(topic_entry)
file_entry = {
"path": "rosbag.db3",
"starting_time": {
"nanoseconds_since_epoch": self.starting_time_nanoseconds
},
"duration": {
"nanoseconds": self.end_time_nanoseconds - self.starting_time_nanoseconds
},
"message_count": self.message_count
}
self.rosbag_yaml_data["relative_file_paths"].append("rosbag.db3")
self.rosbag_yaml_data["files"].append(file_entry)
final_yaml_data = {
"rosbag2_bagfile_information": self.rosbag_yaml_data
}
with open(os.path.join(self.output_dir, "metadata.yaml"), "w") as f:
yaml_str = yaml.dump(
final_yaml_data,
Dumper=IndentDumper,
default_flow_style=False,
sort_keys=False,
indent=2,
width=1000000)
yaml_str = yaml_str.replace("\'", "\"")
f.write(yaml_str)
def sort_db_data(self):
print("start sorting messages table by timestamp")
try:
cursor.execute("""
CREATE TABLE "schema" (
"schema_version" INTEGER,
"ros_distro" TEXT NOT NULL,
PRIMARY KEY("schema_version")
);
self.cursor.execute("""
CREATE TABLE messages_temp(
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
topic_id INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
data BLOB NOT NULL)
""")
cursor.execute("""
INSERT INTO schema (schema_version, ros_distro)
VALUES (?, ?)
""", (3, "humble"))
conn.commit()
except Exception as e:
print(f"Error create schema version, error: {e}")
conn.rollback()
def insert_metadata_table(self, conn, cursor):
try:
cursor.execute("""
CREATE TABLE "metadata" (
"id" INTEGER,
"metadata_version" INTEGER NOT NULL,
"metadata" TEXT NOT NULL,
PRIMARY KEY("id")
);
self.cursor.execute("""
INSERT INTO messages_temp (topic_id, timestamp, data)
SELECT topic_id, timestamp, data
FROM messages
ORDER BY timestamp ASC
""")
self.cursor.execute("DROP TABLE messages")
self.cursor.execute("ALTER TABLE messages_temp RENAME TO messages")
self.conn.commit()
except Exception as e:
print(f"Error create metadata table, error: {e}")
conn.rollback()
def trans_single_db(self, db_path: Path):
single_bag_info = SingleBagProcess(self.topic_info_dict, db_path)
self.all_duration += single_bag_info.duration_nanoseconds
self.message_count += single_bag_info.message_count
self.starting_time_nanoseconds = min(self.starting_time_nanoseconds, single_bag_info.starting_time_nanoseconds)
for topic in single_bag_info.topic_with_message_count:
self.topic_with_message_count[topic] = self.topic_with_message_count.get(
topic, 0) + single_bag_info.topic_with_message_count[topic]
self.bag_info_list.append(single_bag_info)
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
try:
self.insert_schema_version(conn, cursor)
self.insert_metadata_table(conn, cursor)
self.insert_topics_table(conn, cursor)
self.update_messages_table(conn, cursor)
except Exception as e:
print(f"Error updating messages table: {e}")
conn.rollback()
print(f"Error sorting messages table: {e}")
self.conn.rollback()
def trans(self):
self.copy_file()
self.parse_yaml()
print(f"thers is : {len(self.files_list)} files")
for db_path in self.files_list:
trans_path = Path(self.output_dir_) / db_path['path']
self.trans_single_db(trans_path)
print(f"trans_path: {trans_path} done")
self.update_rosbag_yaml()
print(f"transing {self.input_dir} to {self.output_dir} \n")
try:
self.create_output_dir()
for input_dir in self.input_dir:
self.parse_yaml(input_dir)
self.insert_topics_table()
for input_dir in self.input_dir:
single_bag_trans = SingleBagTrans(
input_dir,
self.output_dir,
self.conn,
self.cursor,
self.message_count
)
single_bag_trans.trans_single_bag(self.topic_map)
self.message_count = single_bag_trans.id
self.starting_time_nanoseconds = single_bag_trans.starting_time_nanoseconds
self.end_time_nanoseconds = single_bag_trans.end_time_nanoseconds
self.sort_db_data()
self.update_rosbag_yaml_data()
finally:
if self.db_manager:
self.db_manager.close()
print(f"transing {self.input_dir} to {self.output_dir} done\n")