feat(rosbag_trans): add support to convert multiple aimrt bags into a single ROS bag (#90)
* feat: add functionality to convert multiple aimrt bags into a single ROS bag * refactor(rosbag_trans): add DatabaseManager class to unify database operations * fix: format the code * fix: format code * feat(bagtrans_tool): Update Command Line Tool Documentation
This commit is contained in:
parent
29eb541fd6
commit
e01333313f
@ -23,10 +23,10 @@ aimrt_cli trans -h, --help 会显示参数说明:
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
-s SRC_DIR, --src_dir SRC_DIR
|
||||
aimrtbag source directory.
|
||||
-s SRC_DIR [SRC_DIR ...], --src_dir SRC_DIR [SRC_DIR ...]
|
||||
aimrtbag source directories (support multiple directories)
|
||||
-o OUTPUT_DIR, --output_dir OUTPUT_DIR
|
||||
directory you want to output your files.
|
||||
directory you want to output your files
|
||||
```
|
||||
|
||||
其中 `-s` 参数为必填参数,表示 aimrtbag 的源目录,`-o` 参数为必填参数,表示转换后的bag的输出目录,如果输出目录不存在,则会自动创建;如果输出目录存在,则会覆盖。
|
||||
其中 `-s` 参数为必填参数,表示 aimrtbag 的源目录,支持多个目录,`-o` 参数为必填参数,表示转换后的bag的输出目录,如果输出目录不存在,则会自动创建;如果输出目录存在,则会覆盖。
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from aimrt_cli.command import CommandBase
|
||||
from aimrt_cli.generator.project_generator import ProjectGenerator
|
||||
from aimrt_cli.trans.rosbag_trans import RosbagTrans
|
||||
from aimrt_cli.trans.rosbag_trans import AimrtbagToRos2
|
||||
|
||||
|
||||
class GenCommand(CommandBase):
|
||||
@ -25,5 +25,5 @@ class GenCommand(CommandBase):
|
||||
generator = ProjectGenerator(cfg_path=args.project_cfg, output_dir=args.output_dir)
|
||||
generator.generate()
|
||||
|
||||
trans = RosbagTrans(args.src_dir, args.output_dir)
|
||||
trans = AimrtbagToRos2(args.src_dir, args.output_dir)
|
||||
trans.trans()
|
||||
|
@ -2,7 +2,7 @@
|
||||
# All rights reserved.
|
||||
|
||||
from aimrt_cli.command import CommandBase
|
||||
from aimrt_cli.trans.rosbag_trans import RosbagTrans
|
||||
from aimrt_cli.trans.rosbag_trans import AimrtbagToRos2
|
||||
|
||||
|
||||
class TransCommand(CommandBase):
|
||||
@ -13,12 +13,16 @@ class TransCommand(CommandBase):
|
||||
def add_arguments(self, parser, cmd_name):
|
||||
if cmd_name == "trans":
|
||||
self.parser_ = parser
|
||||
parser.add_argument("-s", "--src_dir", help="aimrtbag source directory.")
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--src_dir",
|
||||
nargs='+',
|
||||
help="aimrtbag source directories (support multiple directories).")
|
||||
parser.add_argument("-o", "--output_dir", help="directory you want to output your files.")
|
||||
|
||||
def main(self, *, args=None):
|
||||
if args is None:
|
||||
self.parser_.print_help()
|
||||
return 0
|
||||
trans = RosbagTrans(args.src_dir, args.output_dir)
|
||||
trans = AimrtbagToRos2(args.src_dir, args.output_dir)
|
||||
trans.trans()
|
||||
|
@ -21,7 +21,11 @@ def main(description=None):
|
||||
|
||||
# bag trans sub command
|
||||
trans_parser = subparsers.add_parser('trans', help='Transform bag files')
|
||||
trans_parser.add_argument("-s", "--src_dir", help="aimrtbag source directory")
|
||||
trans_parser.add_argument(
|
||||
"-s",
|
||||
"--src_dir",
|
||||
nargs='+',
|
||||
help="aimrtbag source directories (support multiple directories)")
|
||||
trans_parser.add_argument("-o", "--output_dir", help="directory you want to output your files")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -15,14 +15,74 @@ class IndentDumper(yaml.Dumper):
|
||||
return super(IndentDumper, self).increase_indent(flow, False)
|
||||
|
||||
|
||||
class SingleBagProcess:
|
||||
class DatabaseManager:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.conn = None
|
||||
self.cursor = None
|
||||
|
||||
def connect(self):
|
||||
self.conn = sqlite3.connect(self.db_path)
|
||||
self.cursor = self.conn.cursor()
|
||||
return self.conn, self.cursor
|
||||
|
||||
def create_tables(self):
|
||||
try:
|
||||
# create messages table
|
||||
self.cursor.execute("""
|
||||
CREATE TABLE messages(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||
topic_id INTEGER NOT NULL,
|
||||
timestamp INTEGER NOT NULL,
|
||||
data BLOB NOT NULL)
|
||||
""")
|
||||
|
||||
# create topics table
|
||||
self.cursor.execute("""
|
||||
CREATE TABLE topics(
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
serialization_format TEXT NOT NULL,
|
||||
offered_qos_profiles TEXT NOT NULL)
|
||||
""")
|
||||
|
||||
# create schema table
|
||||
self.cursor.execute("""
|
||||
CREATE TABLE "schema" (
|
||||
"schema_version" INTEGER,
|
||||
"ros_distro" TEXT NOT NULL,
|
||||
PRIMARY KEY("schema_version")
|
||||
);
|
||||
""")
|
||||
self.cursor.execute("""
|
||||
INSERT INTO schema (schema_version, ros_distro)
|
||||
VALUES (?, ?)
|
||||
""", (3, "humble"))
|
||||
|
||||
# create metadata table
|
||||
self.cursor.execute("""
|
||||
CREATE TABLE metadata(id INTEGER PRIMARY KEY,metadata_version INTEGER NOT NULL,metadata TEXT NOT NULL)
|
||||
""")
|
||||
self.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self.conn.rollback()
|
||||
raise e
|
||||
|
||||
def close(self):
|
||||
if self.cursor:
|
||||
self.cursor.close()
|
||||
if self.conn:
|
||||
self.conn.close()
|
||||
|
||||
|
||||
class SingleDbProcess:
|
||||
def __init__(self, topic_info_dict: dict, db_path: Path):
|
||||
self.message_count = 0
|
||||
self.duration_nanoseconds = 0
|
||||
self.starting_time_nanoseconds = 100000000000000000000
|
||||
|
||||
self.topic_with_message_count = {}
|
||||
self.topic_info_dict = topic_info_dict
|
||||
self.starting_time_nanoseconds = int(1e20)
|
||||
self.end_time_nanoseconds = 0
|
||||
self.db_path = db_path
|
||||
self.get_info()
|
||||
|
||||
@ -30,21 +90,25 @@ class SingleBagProcess:
|
||||
try:
|
||||
cursor.execute("SELECT topic_id, timestamp FROM messages")
|
||||
rows = sorted(cursor.fetchall())
|
||||
|
||||
self.starting_time_nanoseconds = min(self.starting_time_nanoseconds, rows[0][1])
|
||||
self.duration_nanoseconds = rows[-1][1] - self.starting_time_nanoseconds
|
||||
self.message_count = len(rows)
|
||||
for row in rows:
|
||||
self.topic_with_message_count[self.topic_info_dict[row[0]].topic_name] = self.topic_with_message_count.get(
|
||||
self.topic_info_dict[row[0]].topic_name, 0) + 1
|
||||
if rows:
|
||||
self.starting_time_nanoseconds = rows[0][1]
|
||||
self.end_time_nanoseconds = rows[-1][1]
|
||||
self.message_count = len(rows)
|
||||
for row in rows:
|
||||
topic_name = self.topic_info_dict[row[0]].topic_name
|
||||
self.topic_with_message_count[topic_name] = \
|
||||
self.topic_with_message_count.get(topic_name, 0) + 1
|
||||
except Exception as e:
|
||||
print(f"Error getting single bag info: {e}")
|
||||
conn.rollback()
|
||||
|
||||
def get_info(self):
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
self.get_bag_info(conn, cursor)
|
||||
db_manager = DatabaseManager(str(self.db_path))
|
||||
conn, cursor = db_manager.connect()
|
||||
try:
|
||||
self.get_bag_info(conn, cursor)
|
||||
finally:
|
||||
db_manager.close()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -53,152 +117,116 @@ class TopicInfo:
|
||||
topic_name: str
|
||||
msg_type: str
|
||||
serialization_type: str
|
||||
message_count: int
|
||||
|
||||
|
||||
def encode_topic_name(topic_name: str, msg_type: str):
|
||||
if msg_type.startswith("pb"):
|
||||
return topic_name + '/' + msg_type.replace('/', '_2F').replace(':', '_3A').replace('.', '_2E')
|
||||
else:
|
||||
return topic_name
|
||||
|
||||
|
||||
class RosbagTrans(TransBase):
|
||||
def __init__(self, input_dir: str, output_dir: str):
|
||||
class SingleBagTrans(TransBase):
|
||||
def __init__(self, input_dir: str, output_dir: str, conn: sqlite3.Connection, cursor: sqlite3.Cursor, id: int):
|
||||
super().__init__(output_dir)
|
||||
self.input_dir_ = input_dir
|
||||
self.input_dir = input_dir
|
||||
self.output_dir = output_dir
|
||||
self.topics_list = {}
|
||||
self.topic_info_dict = {}
|
||||
self.files_list = {}
|
||||
self.bag_info_list = []
|
||||
self.message_count = 0
|
||||
self.all_duration = 0
|
||||
self.topic_with_message_count = {}
|
||||
self.starting_time_nanoseconds = 100000000000000000000
|
||||
self.duration_nanoseconds = 0
|
||||
|
||||
self.rosbag_yaml_data = {
|
||||
"version": 5,
|
||||
"storage_identifier": "sqlite3",
|
||||
"duration": {
|
||||
"nanoseconds": 0
|
||||
},
|
||||
"starting_time": {
|
||||
"nanoseconds_since_epoch": 0
|
||||
},
|
||||
"message_count": 0,
|
||||
"topics_with_message_count": [],
|
||||
"compression_format": "",
|
||||
"compression_mode": "",
|
||||
"relative_file_paths": [],
|
||||
"files": []
|
||||
}
|
||||
|
||||
def copy_file(self):
|
||||
if os.path.exists(self.output_dir_):
|
||||
shutil.rmtree(self.output_dir_)
|
||||
try:
|
||||
shutil.copytree(self.input_dir_, self.output_dir_)
|
||||
print(f"Directory successfully copied from {self.input_dir_} to {self.output_dir_}")
|
||||
except shutil.Error as e:
|
||||
print(f"Copy error: {e}")
|
||||
except OSError as e:
|
||||
print(f"System error: {e}")
|
||||
self.starting_time_nanoseconds = int(1e20)
|
||||
self.end_time_nanoseconds = 0
|
||||
self.id = id # target db message id
|
||||
self.conn = conn # target db connection
|
||||
self.cursor = cursor # target db cursor
|
||||
|
||||
def parse_yaml(self):
|
||||
with open(os.path.join(self.output_dir_, "metadata.yaml"), "r") as f:
|
||||
with open(os.path.join(self.input_dir, "metadata.yaml"), "r") as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
if data["aimrt_bagfile_information"] is not None:
|
||||
if data["aimrt_bagfile_information"]["topics"] is not None:
|
||||
self.topics_list = data["aimrt_bagfile_information"]["topics"]
|
||||
for topic in self.topics_list:
|
||||
self.topic_info_dict[topic["id"]] = TopicInfo(
|
||||
topic["id"], topic["topic_name"], topic["msg_type"], topic["serialization_type"])
|
||||
else:
|
||||
raise Exception("No topics found in metadata.yaml")
|
||||
|
||||
if data["aimrt_bagfile_information"]["files"] is not None:
|
||||
self.files_list = data["aimrt_bagfile_information"]["files"]
|
||||
else:
|
||||
raise Exception("No files found in metadata.yaml")
|
||||
else:
|
||||
if data is None or data["aimrt_bagfile_information"] is None or data["aimrt_bagfile_information"]["topics"] is None:
|
||||
raise Exception("No aimrt_bagfile_information found in metadata.yaml")
|
||||
|
||||
self.rosbag_yaml_data = {
|
||||
"version": 5,
|
||||
"storage_identifier": "sqlite3",
|
||||
}
|
||||
return data
|
||||
|
||||
def update_rosbag_yaml(self):
|
||||
self.rosbag_yaml_data = {
|
||||
"version": 5,
|
||||
"storage_identifier": "sqlite3",
|
||||
"duration": {
|
||||
"nanoseconds": self.all_duration
|
||||
},
|
||||
"starting_time": {
|
||||
"nanoseconds_since_epoch": self.starting_time_nanoseconds
|
||||
},
|
||||
"message_count": self.message_count,
|
||||
"topics_with_message_count": [],
|
||||
"compression_format": "",
|
||||
"compression_mode": "",
|
||||
"relative_file_paths": [],
|
||||
"files": [],
|
||||
}
|
||||
|
||||
def transfertopic_msg_type(msg_type):
|
||||
if msg_type.startswith("pb"):
|
||||
return "ros2_plugin_proto/msg/RosMsgWrapper"
|
||||
elif msg_type.startswith("ros2"):
|
||||
return msg_type.replace("ros2:", "")
|
||||
else:
|
||||
return msg_type
|
||||
|
||||
self.topics_list = data["aimrt_bagfile_information"]["topics"]
|
||||
for topic in self.topics_list:
|
||||
self.topic_info_dict[topic["id"]] = TopicInfo(
|
||||
topic["id"], topic["topic_name"], topic["msg_type"], topic["serialization_type"], 0)
|
||||
|
||||
topic_message_count = self.topic_with_message_count.get(topic["topic_name"], 0)
|
||||
topic_entry = {
|
||||
"topic_metadata": {
|
||||
"name": encode_topic_name(topic["topic_name"], topic["msg_type"]),
|
||||
"type": transfertopic_msg_type(topic["msg_type"]),
|
||||
"serialization_format": "cdr",
|
||||
"offered_qos_profiles": self.format_qos_profiles()
|
||||
},
|
||||
"message_count": topic_message_count
|
||||
}
|
||||
self.rosbag_yaml_data["topics_with_message_count"].append(topic_entry)
|
||||
if data["aimrt_bagfile_information"]["files"] is not None:
|
||||
self.files_list = data["aimrt_bagfile_information"]["files"]
|
||||
else:
|
||||
raise Exception("No db files found in metadata.yaml")
|
||||
|
||||
for file_info in self.bag_info_list:
|
||||
self.rosbag_yaml_data["relative_file_paths"].append(file_info.db_path.name)
|
||||
file_entry = {
|
||||
"path": file_info.db_path.name,
|
||||
"starting_time": {
|
||||
"nanoseconds_since_epoch": file_info.starting_time_nanoseconds
|
||||
},
|
||||
"duration": {
|
||||
"nanoseconds": file_info.duration_nanoseconds
|
||||
},
|
||||
"message_count": file_info.message_count
|
||||
}
|
||||
self.rosbag_yaml_data["files"].append(file_entry)
|
||||
def trans_single_db(self, source_path: Path, topic_map: dict):
|
||||
single_bag_info = SingleDbProcess(self.topic_info_dict, source_path)
|
||||
self.message_count += single_bag_info.message_count
|
||||
self.starting_time_nanoseconds = min(self.starting_time_nanoseconds, single_bag_info.starting_time_nanoseconds)
|
||||
self.end_time_nanoseconds = max(self.end_time_nanoseconds, single_bag_info.end_time_nanoseconds)
|
||||
|
||||
final_yaml_data = {
|
||||
"rosbag2_bagfile_information": self.rosbag_yaml_data
|
||||
}
|
||||
conn = sqlite3.connect(source_path)
|
||||
print(f" processing db file: {source_path}")
|
||||
cursor = conn.cursor()
|
||||
|
||||
abs_output_dir = os.path.abspath(self.output_dir_)
|
||||
with open(os.path.join(abs_output_dir, "metadata.yaml"), "w") as f:
|
||||
yaml_str = yaml.dump(
|
||||
final_yaml_data,
|
||||
Dumper=IndentDumper,
|
||||
default_flow_style=False,
|
||||
sort_keys=False,
|
||||
indent=2,
|
||||
width=1000000)
|
||||
yaml_str = yaml_str.replace("\'", "\"")
|
||||
f.write(yaml_str)
|
||||
print(f"{os.path.join(abs_output_dir, 'metadata.yaml')} has been updated")
|
||||
try:
|
||||
select_sql = "SELECT id,topic_id, timestamp, data FROM messages"
|
||||
cursor.execute(select_sql)
|
||||
rows = cursor.fetchall()
|
||||
self.cursor.executemany("""
|
||||
INSERT INTO messages (id, topic_id, timestamp, data)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", [(self.id + row[0], topic_map[self.topic_info_dict[row[1]].topic_name].topic_id, row[2], row[3]) for row in rows])
|
||||
for row in rows:
|
||||
topic_map[self.topic_info_dict[row[1]].topic_name].message_count += 1
|
||||
self.conn.commit()
|
||||
print(f" size of data inserted: {len(rows)} done")
|
||||
except Exception as e:
|
||||
print(f" Error updating messages table: {e}")
|
||||
self.conn.rollback()
|
||||
self.id += len(rows)
|
||||
|
||||
def trans_single_bag(self, topic_map: dict):
|
||||
self.parse_yaml()
|
||||
print(f"there are {len(self.files_list)} db files in {self.input_dir}")
|
||||
for db_path in self.files_list:
|
||||
trans_path = Path(self.output_dir) / db_path['path']
|
||||
self.trans_single_db(Path(self.input_dir) / db_path['path'], topic_map)
|
||||
print(f" trans_path: {trans_path} done")
|
||||
print(f"all db files in {self.input_dir} done\n")
|
||||
|
||||
|
||||
class AimrtbagToRos2:
|
||||
def __init__(self, input_dir: list, output_dir: str):
|
||||
self.input_dir = input_dir
|
||||
self.output_dir = output_dir
|
||||
self.topic_map = {}
|
||||
self.id = 0
|
||||
self.message_count = 0
|
||||
self.starting_time_nanoseconds = int(1e20)
|
||||
self.end_time_nanoseconds = 0
|
||||
self.topics_list = []
|
||||
self.db_manager = None
|
||||
self.conn = None
|
||||
self.cursor = None
|
||||
|
||||
def create_output_dir(self):
|
||||
if os.path.exists(self.output_dir):
|
||||
shutil.rmtree(self.output_dir)
|
||||
os.makedirs(self.output_dir)
|
||||
|
||||
# initialize database
|
||||
db_path = os.path.join(self.output_dir, "rosbag.db3")
|
||||
self.db_manager = DatabaseManager(db_path)
|
||||
self.conn, self.cursor = self.db_manager.connect()
|
||||
self.db_manager.create_tables()
|
||||
|
||||
def parse_yaml(self, input_dir: str):
|
||||
with open(os.path.join(input_dir, "metadata.yaml"), "r") as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
if data["aimrt_bagfile_information"] is None or data["aimrt_bagfile_information"]["topics"] is None:
|
||||
raise Exception("No topics information found in metadata.yaml")
|
||||
|
||||
topics_list = data["aimrt_bagfile_information"]["topics"]
|
||||
|
||||
for topic in topics_list:
|
||||
if topic["topic_name"] not in self.topic_map:
|
||||
self.id += 1
|
||||
self.topic_map[topic["topic_name"]] = TopicInfo(
|
||||
self.id, topic["topic_name"], topic["msg_type"], topic["serialization_type"], 0)
|
||||
else:
|
||||
print(f"warning: topic {topic['topic_name']} already exists")
|
||||
|
||||
def format_qos_profiles(self):
|
||||
qos_dict = {
|
||||
@ -221,26 +249,8 @@ class RosbagTrans(TransBase):
|
||||
|
||||
return qos_string
|
||||
|
||||
def update_messages_table(self, conn, cursor):
|
||||
def insert_topics_table(self):
|
||||
try:
|
||||
cursor.execute("UPDATE messages SET topic_id = topic_id + 1")
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
print(f"Error update messages table, error: {e}")
|
||||
conn.rollback()
|
||||
|
||||
def insert_topics_table(self, conn, cursor):
|
||||
try:
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS "topics" (
|
||||
"id" INTEGER,
|
||||
"name" TEXT NOT NULL,
|
||||
"type" TEXT NOT NULL,
|
||||
"serialization_format" TEXT NOT NULL,
|
||||
"offered_qos_profiles" TEXT NOT NULL,
|
||||
PRIMARY KEY("id")
|
||||
)
|
||||
""")
|
||||
qos_dict = [{
|
||||
'history': 3,
|
||||
'depth': 0,
|
||||
@ -264,83 +274,131 @@ class RosbagTrans(TransBase):
|
||||
qos_json = yaml.dump(qos_dict, Dumper=IndentDumper, sort_keys=False)
|
||||
|
||||
# Populate the topics table from self.topics_list
|
||||
for topic in self.topics_list:
|
||||
topic['offered_qos_profiles'] = self.format_qos_profiles()
|
||||
cursor.execute("""
|
||||
for topic in self.topic_map.values():
|
||||
self.cursor.execute("""
|
||||
INSERT INTO topics (id, name, type, serialization_format, offered_qos_profiles)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (
|
||||
topic['id'] + 1,
|
||||
topic['topic_name'],
|
||||
topic['msg_type'].replace('ros2:', ''),
|
||||
topic.topic_id,
|
||||
topic.topic_name,
|
||||
topic.msg_type.replace('ros2:', ''),
|
||||
'cdr', # Use 'cdr' as the default serialization format
|
||||
qos_json
|
||||
))
|
||||
conn.commit()
|
||||
self.conn.commit()
|
||||
except Exception as e:
|
||||
print(f"Error create topics table or insert topics table data, error: {e}")
|
||||
conn.rollback()
|
||||
self.conn.rollback()
|
||||
|
||||
def insert_schema_version(self, conn, cursor):
|
||||
def update_rosbag_yaml_data(self):
|
||||
self.rosbag_yaml_data = {
|
||||
"version": 5,
|
||||
"storage_identifier": "sqlite3",
|
||||
"duration": {
|
||||
"nanoseconds": self.end_time_nanoseconds - self.starting_time_nanoseconds
|
||||
},
|
||||
"starting_time": {
|
||||
"nanoseconds_since_epoch": self.starting_time_nanoseconds
|
||||
},
|
||||
"message_count": self.message_count,
|
||||
"topics_with_message_count": [],
|
||||
"compression_format": "",
|
||||
"compression_mode": "",
|
||||
"relative_file_paths": [],
|
||||
"files": []
|
||||
}
|
||||
|
||||
for topic in self.topic_map.values():
|
||||
topic_entry = {
|
||||
"topic_metadata": {
|
||||
"name": topic.topic_name,
|
||||
"type": topic.msg_type.replace('ros2:', ''),
|
||||
"serialization_format": "cdr",
|
||||
"offered_qos_profiles": self.format_qos_profiles(),
|
||||
},
|
||||
"message_count": topic.message_count
|
||||
}
|
||||
self.rosbag_yaml_data["topics_with_message_count"].append(topic_entry)
|
||||
|
||||
file_entry = {
|
||||
"path": "rosbag.db3",
|
||||
"starting_time": {
|
||||
"nanoseconds_since_epoch": self.starting_time_nanoseconds
|
||||
},
|
||||
"duration": {
|
||||
"nanoseconds": self.end_time_nanoseconds - self.starting_time_nanoseconds
|
||||
},
|
||||
"message_count": self.message_count
|
||||
}
|
||||
self.rosbag_yaml_data["relative_file_paths"].append("rosbag.db3")
|
||||
self.rosbag_yaml_data["files"].append(file_entry)
|
||||
final_yaml_data = {
|
||||
"rosbag2_bagfile_information": self.rosbag_yaml_data
|
||||
}
|
||||
with open(os.path.join(self.output_dir, "metadata.yaml"), "w") as f:
|
||||
yaml_str = yaml.dump(
|
||||
final_yaml_data,
|
||||
Dumper=IndentDumper,
|
||||
default_flow_style=False,
|
||||
sort_keys=False,
|
||||
indent=2,
|
||||
width=1000000)
|
||||
yaml_str = yaml_str.replace("\'", "\"")
|
||||
f.write(yaml_str)
|
||||
|
||||
def sort_db_data(self):
|
||||
print("start sorting messages table by timestamp")
|
||||
try:
|
||||
cursor.execute("""
|
||||
CREATE TABLE "schema" (
|
||||
"schema_version" INTEGER,
|
||||
"ros_distro" TEXT NOT NULL,
|
||||
PRIMARY KEY("schema_version")
|
||||
);
|
||||
self.cursor.execute("""
|
||||
CREATE TABLE messages_temp(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||
topic_id INTEGER NOT NULL,
|
||||
timestamp INTEGER NOT NULL,
|
||||
data BLOB NOT NULL)
|
||||
""")
|
||||
cursor.execute("""
|
||||
INSERT INTO schema (schema_version, ros_distro)
|
||||
VALUES (?, ?)
|
||||
""", (3, "humble"))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
print(f"Error create schema version, error: {e}")
|
||||
conn.rollback()
|
||||
|
||||
def insert_metadata_table(self, conn, cursor):
|
||||
try:
|
||||
cursor.execute("""
|
||||
CREATE TABLE "metadata" (
|
||||
"id" INTEGER,
|
||||
"metadata_version" INTEGER NOT NULL,
|
||||
"metadata" TEXT NOT NULL,
|
||||
PRIMARY KEY("id")
|
||||
);
|
||||
self.cursor.execute("""
|
||||
INSERT INTO messages_temp (topic_id, timestamp, data)
|
||||
SELECT topic_id, timestamp, data
|
||||
FROM messages
|
||||
ORDER BY timestamp ASC
|
||||
""")
|
||||
|
||||
self.cursor.execute("DROP TABLE messages")
|
||||
|
||||
self.cursor.execute("ALTER TABLE messages_temp RENAME TO messages")
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error create metadata table, error: {e}")
|
||||
conn.rollback()
|
||||
|
||||
def trans_single_db(self, db_path: Path):
|
||||
single_bag_info = SingleBagProcess(self.topic_info_dict, db_path)
|
||||
self.all_duration += single_bag_info.duration_nanoseconds
|
||||
self.message_count += single_bag_info.message_count
|
||||
self.starting_time_nanoseconds = min(self.starting_time_nanoseconds, single_bag_info.starting_time_nanoseconds)
|
||||
|
||||
for topic in single_bag_info.topic_with_message_count:
|
||||
self.topic_with_message_count[topic] = self.topic_with_message_count.get(
|
||||
topic, 0) + single_bag_info.topic_with_message_count[topic]
|
||||
self.bag_info_list.append(single_bag_info)
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
self.insert_schema_version(conn, cursor)
|
||||
self.insert_metadata_table(conn, cursor)
|
||||
self.insert_topics_table(conn, cursor)
|
||||
self.update_messages_table(conn, cursor)
|
||||
except Exception as e:
|
||||
print(f"Error updating messages table: {e}")
|
||||
conn.rollback()
|
||||
print(f"Error sorting messages table: {e}")
|
||||
self.conn.rollback()
|
||||
|
||||
def trans(self):
|
||||
self.copy_file()
|
||||
self.parse_yaml()
|
||||
print(f"thers is : {len(self.files_list)} files")
|
||||
for db_path in self.files_list:
|
||||
trans_path = Path(self.output_dir_) / db_path['path']
|
||||
self.trans_single_db(trans_path)
|
||||
print(f"trans_path: {trans_path} done")
|
||||
self.update_rosbag_yaml()
|
||||
print(f"transing {self.input_dir} to {self.output_dir} \n")
|
||||
try:
|
||||
self.create_output_dir()
|
||||
for input_dir in self.input_dir:
|
||||
self.parse_yaml(input_dir)
|
||||
self.insert_topics_table()
|
||||
|
||||
for input_dir in self.input_dir:
|
||||
single_bag_trans = SingleBagTrans(
|
||||
input_dir,
|
||||
self.output_dir,
|
||||
self.conn,
|
||||
self.cursor,
|
||||
self.message_count
|
||||
)
|
||||
single_bag_trans.trans_single_bag(self.topic_map)
|
||||
self.message_count = single_bag_trans.id
|
||||
self.starting_time_nanoseconds = single_bag_trans.starting_time_nanoseconds
|
||||
self.end_time_nanoseconds = single_bag_trans.end_time_nanoseconds
|
||||
|
||||
self.sort_db_data()
|
||||
self.update_rosbag_yaml_data()
|
||||
finally:
|
||||
if self.db_manager:
|
||||
self.db_manager.close()
|
||||
|
||||
print(f"transing {self.input_dir} to {self.output_dir} done\n")
|
||||
|
Loading…
x
Reference in New Issue
Block a user