Changed tlv8_t from struct to class and made variables private

Add getLen() and getTag() method to get length and tag of tlv8_t.

Also overrode subscript operator [] so you can access any element of internal uint8_t array.

Given previous additions (such as getVal()) there is now no reason to need to access the underlying std::unique_ptr directly.
This commit is contained in:
Gregg 2024-04-13 22:59:37 -05:00
parent 28990d6ed6
commit 48cab1f82b
3 changed files with 43 additions and 26 deletions

View File

@ -404,7 +404,7 @@ int HAPClient::postPairSetupURL(uint8_t *content, size_t len){
return(0); return(0);
}; };
srp->createSessionKey(*itPublicKey,(*itPublicKey).len); // create session key, K, from client Public Key, A srp->createSessionKey(*itPublicKey,(*itPublicKey).getLen()); // create session key, K, from client Public Key, A
if(!srp->verifyClientProof(*itClientProof)){ // verify client Proof, M1 if(!srp->verifyClientProof(*itClientProof)){ // verify client Proof, M1
LOG0("\n*** ERROR: SRP Proof Verification Failed\n\n"); LOG0("\n*** ERROR: SRP Proof Verification Failed\n\n");
@ -454,9 +454,9 @@ int HAPClient::postPairSetupURL(uint8_t *content, size_t len){
// use SessionKey to decrypt encryptedData TLV with padded nonce="PS-Msg05" // use SessionKey to decrypt encryptedData TLV with padded nonce="PS-Msg05"
TempBuffer<uint8_t> decrypted((*itEncryptedData).len-crypto_aead_chacha20poly1305_IETF_ABYTES); // temporary storage for decrypted data TempBuffer<uint8_t> decrypted((*itEncryptedData).getLen()-crypto_aead_chacha20poly1305_IETF_ABYTES); // temporary storage for decrypted data
if(crypto_aead_chacha20poly1305_ietf_decrypt(decrypted, NULL, NULL, *itEncryptedData, (*itEncryptedData).len, NULL, 0, (unsigned char *)"\x00\x00\x00\x00PS-Msg05", sessionKey)==-1){ if(crypto_aead_chacha20poly1305_ietf_decrypt(decrypted, NULL, NULL, *itEncryptedData, (*itEncryptedData).getLen(), NULL, 0, (unsigned char *)"\x00\x00\x00\x00PS-Msg05", sessionKey)==-1){
LOG0("\n*** ERROR: Exchange-Request Authentication Failed\n\n"); LOG0("\n*** ERROR: Exchange-Request Authentication Failed\n\n");
responseTLV.add(kTLVType_Error,tagError_Authentication); // set Error=Authentication responseTLV.add(kTLVType_Error,tagError_Authentication); // set Error=Authentication
tlvRespond(responseTLV); // send response to client tlvRespond(responseTLV); // send response to client
@ -492,7 +492,7 @@ int HAPClient::postPairSetupURL(uint8_t *content, size_t len){
// Concatenate iosDeviceX, IOS ID, and IOS PublicKey into iosDeviceInfo // Concatenate iosDeviceX, IOS ID, and IOS PublicKey into iosDeviceInfo
TempBuffer<uint8_t> iosDeviceInfo(iosDeviceX,iosDeviceX.len(),(*itIdentifier).val.get(),(*itIdentifier).len,(*itPublicKey).val.get(),(*itPublicKey).len,NULL); TempBuffer<uint8_t> iosDeviceInfo(iosDeviceX,iosDeviceX.len(),(uint8_t *)(*itIdentifier),(*itIdentifier).getLen(),(uint8_t *)(*itPublicKey),(*itPublicKey).getLen(),NULL);
if(crypto_sign_verify_detached(*itSignature, iosDeviceInfo, iosDeviceInfo.len(), *itPublicKey) != 0){ // verify signature of iosDeviceInfo using iosDeviceLTPK if(crypto_sign_verify_detached(*itSignature, iosDeviceInfo, iosDeviceInfo.len(), *itPublicKey) != 0){ // verify signature of iosDeviceInfo using iosDeviceLTPK
LOG0("\n*** ERROR: LPTK Signature Verification Failed\n\n"); LOG0("\n*** ERROR: LPTK Signature Verification Failed\n\n");
@ -668,9 +668,9 @@ int HAPClient::postPairVerifyURL(uint8_t *content, size_t len){
// use Session Curve25519 Key (from previous step) to decrypt encrypytedData TLV with padded nonce="PV-Msg03" // use Session Curve25519 Key (from previous step) to decrypt encrypytedData TLV with padded nonce="PV-Msg03"
TempBuffer<uint8_t> decrypted((*itEncryptedData).len-crypto_aead_chacha20poly1305_IETF_ABYTES); // temporary storage for decrypted data TempBuffer<uint8_t> decrypted((*itEncryptedData).getLen()-crypto_aead_chacha20poly1305_IETF_ABYTES); // temporary storage for decrypted data
if(crypto_aead_chacha20poly1305_ietf_decrypt(decrypted, NULL, NULL, *itEncryptedData, (*itEncryptedData).len, NULL, 0, (unsigned char *)"\x00\x00\x00\x00PV-Msg03", sessionKey)==-1){ if(crypto_aead_chacha20poly1305_ietf_decrypt(decrypted, NULL, NULL, *itEncryptedData, (*itEncryptedData).getLen(), NULL, 0, (unsigned char *)"\x00\x00\x00\x00PV-Msg03", sessionKey)==-1){
LOG0("\n*** ERROR: Verify Authentication Failed\n\n"); LOG0("\n*** ERROR: Verify Authentication Failed\n\n");
responseTLV.add(kTLVType_State,pairState_M4); // set State=<M4> responseTLV.add(kTLVType_State,pairState_M4); // set State=<M4>
responseTLV.add(kTLVType_Error,tagError_Authentication); // set Error=Authentication responseTLV.add(kTLVType_Error,tagError_Authentication); // set Error=Authentication

View File

@ -70,7 +70,7 @@ void tlv8_t::osprint(std::ostream& os){
TLV8_it TLV8::add(uint8_t tag, size_t len, const uint8_t* val){ TLV8_it TLV8::add(uint8_t tag, size_t len, const uint8_t* val){
if(!empty() && back().tag==tag) if(!empty() && back().getTag()==tag)
back().update(len,val); back().update(len,val);
else else
emplace_back(tag,len,val); emplace_back(tag,len,val);
@ -103,9 +103,9 @@ TLV8_it TLV8::add(uint8_t tag, uint64_t val){
TLV8_it TLV8::find(uint8_t tag, TLV8_it it1, TLV8_it it2){ TLV8_it TLV8::find(uint8_t tag, TLV8_it it1, TLV8_it it2){
auto it=it1; auto it=it1;
while(it!=it2 && (*it).tag!=tag) while(it!=it2 && (*it).getTag()!=tag)
it++; it++;
return(it==it2?end():it); return(it);
} }
///////////////////////////////////// /////////////////////////////////////
@ -115,9 +115,9 @@ size_t TLV8::pack_size(TLV8_it it1, TLV8_it it2){
size_t nBytes=0; size_t nBytes=0;
while(it1!=it2){ while(it1!=it2){
nBytes+=2+(*it1).len; nBytes+=2+(*it1).getLen();
if((*it1).len>255) if((*it1).getLen()>255)
nBytes+=2*(((*it1).len-1)/255); nBytes+=2*(((*it1).getLen()-1)/255);
it1++; it1++;
} }
@ -134,13 +134,13 @@ size_t TLV8::pack(uint8_t *buf, size_t bufSize){
switch(currentPackPhase){ switch(currentPackPhase){
case 0: case 0:
currentPackBuf=(*currentPackIt).val.get(); currentPackBuf=*currentPackIt;
endPackBuf=(*currentPackIt).val.get()+(*currentPackIt).len; endPackBuf=(*currentPackIt)+(*currentPackIt).getLen();
currentPackPhase=1; currentPackPhase=1;
break; break;
case 1: case 1:
*buf++=(*currentPackIt).tag; *buf++=(*currentPackIt).getTag();
nBytes++; nBytes++;
currentPackPhase=2; currentPackPhase=2;
break; break;
@ -228,7 +228,7 @@ int TLV8::unpack(TLV8_it it){
if(it==end()) if(it==end())
return(0); return(0);
return(unpack(*it,(*it).len)); return(unpack(*it,(*it).getLen()));
} }
///////////////////////////////////// /////////////////////////////////////
@ -251,19 +251,19 @@ const char *TLV8::getName(uint8_t tag){
void TLV8::print(TLV8_it it1, TLV8_it it2){ void TLV8::print(TLV8_it it1, TLV8_it it2){
while(it1!=it2){ while(it1!=it2){
const char *name=getName((*it1).tag); const char *name=getName((*it1).getTag());
if(name) if(name)
Serial.printf("%s",name); Serial.printf("%s",name);
else else
Serial.printf("%d",(*it1).tag); Serial.printf("%d",(*it1).getTag());
Serial.printf("(%d) ",(*it1).len); Serial.printf("(%d) ",(*it1).getLen());
for(int i=0;i<(*it1).len;i++) for(int i=0;i<(*it1).getLen();i++)
Serial.printf("%02X",(*it1).val.get()[i]); Serial.printf("%02X",(*it1)[i]);
if((*it1).len==0) if((*it1).getLen()==0)
Serial.printf(" [null]"); Serial.printf(" [null]");
else if((*it1).len<=4) else if((*it1).getLen()<=4)
Serial.printf(" [%u]",(*it1).getVal()); Serial.printf(" [%u]",(*it1).getVal());
else if((*it1).len<=8) else if((*it1).getLen()<=8)
Serial.printf(" [%llu]",(*it1).getVal<uint64_t>()); Serial.printf(" [%llu]",(*it1).getVal<uint64_t>());
Serial.printf("\n"); Serial.printf("\n");
it1++; it1++;

View File

@ -34,11 +34,16 @@
#include "PSRAM.h" #include "PSRAM.h"
struct tlv8_t { class tlv8_t {
private:
uint8_t tag; uint8_t tag;
size_t len; size_t len;
std::unique_ptr<uint8_t> val; std::unique_ptr<uint8_t> val;
public:
tlv8_t(uint8_t tag, size_t len, const uint8_t* val); tlv8_t(uint8_t tag, size_t len, const uint8_t* val);
void update(size_t addLen, const uint8_t *addVal); void update(size_t addLen, const uint8_t *addVal);
void osprint(std::ostream& os); void osprint(std::ostream& os);
@ -47,6 +52,18 @@ struct tlv8_t {
return(val.get()); return(val.get());
} }
uint8_t & operator[](int index){
return(val.get()[index]);
}
size_t getLen(){
return(len);
}
uint8_t getTag(){
return(tag);
}
template<class T=uint32_t> T getVal(){ template<class T=uint32_t> T getVal(){
T iVal=0; T iVal=0;
for(int i=0;i<len;i++) for(int i=0;i<len;i++)
@ -94,7 +111,7 @@ class TLV8 : public std::list<tlv8_t, Mallocator<tlv8_t>> {
TLV8_it find(uint8_t tag, TLV8_it it1){return(find(tag, it1, end()));} TLV8_it find(uint8_t tag, TLV8_it it1){return(find(tag, it1, end()));}
TLV8_it find(uint8_t tag){return(find(tag, begin(), end()));} TLV8_it find(uint8_t tag){return(find(tag, begin(), end()));}
int len(TLV8_it it){return(it==end()?-1:(*it).len);} int len(TLV8_it it){return(it==end()?-1:(*it).getLen());}
size_t pack_size(TLV8_it it1, TLV8_it it2); size_t pack_size(TLV8_it it1, TLV8_it it2);
size_t pack_size(){return(pack_size(begin(), end()));} size_t pack_size(){return(pack_size(begin(), end()));}