Update schmea; create client when registering

This commit is contained in:
Owen
2025-11-03 15:42:22 -08:00
parent 43590896e9
commit d30743a428
5 changed files with 200 additions and 33 deletions

View File

@@ -1,10 +1,22 @@
import { db, ExitNode } from "@server/db";
import {
Client,
db,
ExitNode,
orgs,
roleClients,
roles,
Transaction,
userClients,
userOrgs,
users
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!");
@@ -17,15 +29,62 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
const { publicKey, relay, olmVersion, orgId, deviceName } = message.data;
let client: Client;
if (orgId) {
if (!olm.userId) {
logger.warn("Olm has no user ID to verify org change!");
return;
}
try {
client = await getOrCreateOrgClient(orgId, olm.userId, deviceName);
} catch (err) {
logger.error(
`Error switching olm client ${olm.olmId} to org ${orgId}: ${err}`
);
return;
}
if (!client) {
logger.warn("Client not found");
return;
}
logger.debug(
`Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}`
);
await db
.update(olms)
.set({
clientId: client.clientId
})
.where(eq(olms.olmId, olm.olmId));
} else {
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
logger.debug(`Using last connected org for client ${olm.clientId}`);
[client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, olm.clientId))
.limit(1);
}
if (!client) {
logger.warn("Client ID not found");
return;
}
const clientId = olm.clientId;
const { publicKey, relay, olmVersion } = message.data;
logger.debug(
`Olm client ID: ${clientId}, Public Key: ${publicKey}, Relay: ${relay}`
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
);
if (!publicKey) {
@@ -33,18 +92,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return;
}
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Client not found");
return;
}
if (client.exitNodeId) {
// TODO: FOR NOW WE ARE JUST HOLEPUNCHING ALL EXIT NODES BUT IN THE FUTURE WE SHOULD HANDLE THIS BETTER
@@ -103,7 +150,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.set({
pubKey: publicKey
})
.where(eq(clients.clientId, olm.clientId));
.where(eq(clients.clientId, client.clientId));
// set isRelay to false for all of the client's sites to reset the connection metadata
await db
@@ -111,7 +158,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.set({
isRelayed: relay == true
})
.where(eq(clientSites.clientId, olm.clientId));
.where(eq(clientSites.clientId, client.clientId));
}
// Get all sites data
@@ -145,7 +192,9 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// Validate endpoint and hole punch status
if (!site.endpoint) {
logger.warn(`In olm register: site ${site.siteId} has no endpoint, skipping`);
logger.warn(
`In olm register: site ${site.siteId} has no endpoint, skipping`
);
continue;
}
@@ -240,3 +289,105 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
excludeSender: false
};
};
async function getOrCreateOrgClient(
orgId: string,
userId: string,
deviceName?: string,
trx: Transaction | typeof db = db
): Promise<Client> {
let client: Client;
// get the org
const [org] = await trx
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
throw new Error("Org not found");
}
if (!org.subnet) {
throw new Error("Org has no subnet defined");
}
// Verify that the user belongs to the org
const [userOrg] = await trx
.select()
.from(userOrgs)
.where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId)))
.limit(1);
if (!userOrg) {
throw new Error("User does not belong to org");
}
// check if the user has a client in the org and if not then create a client for them
const [existingClient] = await trx
.select()
.from(clients)
.where(and(eq(clients.orgId, orgId), eq(clients.userId, userId)))
.limit(1);
if (!existingClient) {
logger.debug(
`Client does not exist in org ${orgId}, creating new client for user ${userId}`
);
// TODO: more intelligent way to pick the exit node
const exitNodesList = await listExitNodes(orgId);
const randomExitNode =
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
const [adminRole] = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (!adminRole) {
throw new Error("Admin role not found");
}
const newSubnet = await getNextAvailableClientSubnet(orgId);
if (!newSubnet) {
throw new Error("No available subnet found");
}
const subnet = newSubnet.split("/")[0];
const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
const [newClient] = await trx
.insert(clients)
.values({
exitNodeId: randomExitNode.exitNodeId,
orgId,
name: deviceName || "User Device",
subnet: updatedSubnet,
type: "olm",
userId: userId
})
.returning();
await trx.insert(roleClients).values({
roleId: adminRole.roleId,
clientId: newClient.clientId
});
if (userOrg.roleId != adminRole.roleId) {
// make sure the user can access the client
trx.insert(userClients).values({
userId,
clientId: newClient.clientId
});
}
client = newClient;
} else {
client = existingClient;
}
return client;
}